[{"data":1,"prerenderedAt":-1},["ShallowReactive",2],{"project-1245":3},{"id":4,"name":5,"fullName":6,"owner":7,"repo":5,"description":8,"homepage":9,"htmlUrl":10,"language":11,"languages":10,"totalLinesOfCode":10,"stars":12,"forks":13,"watchers":14,"openIssues":15,"contributorsCount":16,"subscribersCount":16,"size":16,"stars1d":17,"stars7d":18,"stars30d":19,"stars90d":16,"forks30d":16,"starsTrendScore":18,"compositeScore":20,"rankGlobal":10,"rankLanguage":10,"license":21,"archived":22,"fork":22,"defaultBranch":23,"hasWiki":22,"hasPages":24,"topics":25,"createdAt":10,"pushedAt":10,"updatedAt":33,"readmeContent":34,"aiSummary":35,"trendingCount":16,"starSnapshotCount":16,"syncStatus":17,"lastSyncTime":36,"discoverSource":37},1245,"pyptx","patrick-toulme\u002Fpyptx","patrick-toulme","A Python DSL to write Nvidia PTX for Hopper and Blackwell in JAX and PyTorch","https:\u002F\u002Fpyptx.dev\u002F",null,"Python",310,25,5,1,0,2,6,20,4.24,"Apache License 2.0",false,"main",true,[26,27,28,29,30,31,32],"blackwell","hopper","jax","nvidia","nvidia-gpu","ptx","pytorch","2026-06-12 02:00:25","\u003Cp align=\"center\">\n  \u003Cimg src=\"docs\u002Fassets\u002Fpyptx-logo.png\" alt=\"pyptx\" width=\"520\">\n\u003C\u002Fp>\n\n# pyptx\n\n> Write PTX kernels in Python. Launch them from `jax.jit`, PyTorch, and `torch.compile`.\n\n`pyptx` is a Python DSL for handwritten PTX on NVIDIA Ampere (sm_80),\nAda (sm_89), Hopper (sm_90a), Blackwell datacenter (sm_100a), and\nBlackwell workstation (sm_120). Pre-Ampere targets like Turing (sm_75,\nT4) work for kernels that stay within the sm_75 ISA — anything using\n`cp.async`, `mbarrier`, `bf16`, `wgmma`, `tcgen05`, or TMA needs an\nAmpere-or-newer card.\n\nOne call = one instruction. 
No optimizer, no autotuner, no tensor IR between\nthe Python function and the PTX it emits.\n\n- explicit registers, predicates, barriers, shared memory\n- Ampere: `mma.sync` (m16n8k{8,16,32}), `cp.async`, `ldmatrix`, SMEM staging\n- Hopper: WGMMA, TMA 2D\u002F3D with multicast, mbarriers, cluster launch\n- Blackwell: `tcgen05.mma` \u002F `.ld`, TMEM, SMEM descriptors, warp specialization\n- callable from JAX, PyTorch eager, and `torch.compile`\n- `arch=\"auto\"` picks the right target for the current GPU at trace time\n  (validated on T4, A100, L4, H100, B200, RTX Pro 6000 Blackwell)\n- real PTX parser + emitter + **transpiler** — round-trips 218+ real PTX files byte-identical\n\nDocs: [pyptx.dev](https:\u002F\u002Fpyptx.dev) · Examples:\n[`examples\u002Fampere\u002F`](examples\u002Fampere),\n[`examples\u002Fhopper\u002F`](examples\u002Fhopper),\n[`examples\u002Fblackwell\u002F`](examples\u002Fblackwell) ·\nAPI: [pyptx.dev\u002Fapi](https:\u002F\u002Fpyptx.dev\u002Fapi\u002F)\n\n---\n\n## Install\n\n| Command | What you get |\n| --- | --- |\n| `pip install pyptx` | DSL, parser, emitter, transpiler (no GPU runtime) |\n| `pip install 'pyptx[torch]'` | + PyTorch eager and `torch.compile` launch path |\n| `pip install 'pyptx[jax]'` | + `jax.jit` launch path via typed FFI |\n| `pip install 'pyptx[all]'` | + both PyTorch and JAX |\n\nTip: `pip install ninja` so the PyTorch C++ extension JIT-builds on first\nlaunch (drops dispatch overhead from ~34 µs to ~14 µs).\n\n## Performance\n\n### Blackwell (B200, bf16)\n\n| Kernel | Shape | pyptx | Reference | pyptx \u002F reference |\n| --- | --- | --- | --- | --- |\n| **GEMM** (`tcgen05.mma`, 4-stage pipeline, 1SM) | 8192³ | **1240 TFLOPS** | cuBLAS 1610 TFLOPS | 77% |\n| **GEMM** (1SM) | 4096³ | **1194 TFLOPS** | cuBLAS 1532 TFLOPS | 78% |\n| **GEMM 2SM** (`cta_group::2`, 5-stage) | 2048³ | **649 TFLOPS** (beats 1SM) | cuBLAS 1006 TFLOPS | 64% |\n| **Grouped GEMM** (tcgen05, MoE) | G=4 M=2048 N=256 K=2048 | **401 TFLOPS** | torch ref | **~10.0×** |\n| **RMS norm \u002F Layer 
norm \u002F SwiGLU** | maintained Blackwell ports | benchmarked | torch ref | see kernel suite |\n\n### Hopper (H100 SXM5, bf16 \u002F f32)\n\n| Kernel | Shape | pyptx | vs reference |\n| --- | --- | --- | --- |\n| **GEMM** (wgmma, warp-specialized) | 8192³ | **815 TFLOPS** | beats cuBLAS ≥ 6K |\n| **Grouped GEMM** (bf16→f32) | G=8 M=K=2048 | **104 TFLOPS** | — |\n| **RMS norm** (f32) | B=2048 N=8192 | 2.6 TB\u002Fs (88% HBM) | **3.9×** torch |\n| **Layer norm** (f32) | B=2048 N=8192 | 2.5 TB\u002Fs (83% HBM) | **1.5×** `F.layer_norm` |\n| **SwiGLU** (f32) | M=2048 F=8192 | 2.8 TB\u002Fs (94% HBM) | **1.6×** `F.silu(g)*u` |\n| **Softmax** (f32, row-wise) | B=2048 N=8192 | 2.8 TB\u002Fs (95% HBM) | **1.16×** `torch.softmax` |\n| **Flash attention** (bf16) | M=N=4096, HD=64 | 88 µs | **3.0×** naive torch |\n\n### Ampere (A100 80GB, bf16 \u002F f32)\n\n| Kernel | Shape | pyptx | vs reference |\n| --- | --- | --- | --- |\n| **GEMM** (`ldmatrix.x4` + `cp.async` 4-stage + register frag double-buffer + XOR swizzle + serpentine `mma.sync`) | 4096³ bf16 | **162 TFLOPS** | cuBLAS 223 TFLOPS (**73%**) |\n| **GEMM** (same kernel) | 2048³ bf16 | **108 TFLOPS** | cuBLAS 158 TFLOPS (68%) |\n| **GEMM** (simple `mma.sync` + 2-stage pipeline, teaching kernel) | 4096³ bf16 | 64 TFLOPS | cuBLAS 230 TFLOPS (28%) |\n| **RMS norm** (f32) | B=2048 N=8192 | 928 GB\u002Fs | **2.2×** torch |\n| **SwiGLU** (f32) | M=2048 F=8192 | 1.33 TB\u002Fs | **1.35×** `F.silu(g)*u` |\n| **Layer norm** (f32) | B=2048 N=8192 | 916 GB\u002Fs | 0.89× `F.layer_norm` (torch's fused kernel is hard to beat) |\n\nA100 numbers reproduce via `python benchmarks\u002Fbench_ampere_kernels.py`.\n\nThe high-perf A100 GEMM follows the CUTLASS SM80 \u002F MatmulTutorial v15\ndesign pattern:\n\n- 128×128×32 CTA tile; 4 warps in a 2×2 layout, each owning a 64×64\n  output sub-tile\n- warp-collective `ldmatrix.x4` for SMEM→register fragment loads\n- **4-stage** `cp.async` ring buffer (3 stages in flight)\n- **register fragment double-buffering** that pre-loads the next\n  K-iter's first K-block during the current iter's last `mma`\n- **CUTLASS XOR swizzle** (`atom ^= row & 3`) on all SMEM paths to\n  eliminate `ldmatrix` bank conflicts\n- **serpentine N-fragment order** for operand reuse between adjacent `mma`s\n- per-thread offset hoisting, so each inner-loop `ldmatrix` costs one\n  `add` instead of 5+ ops\n\nEach warp issues 64 `mma.sync.m16n8k16` per K-iter (its 64×64 tile is\n4×8 m16n8 fragments, times two k16 steps per 32-wide K-iter), i.e. 256\nper CTA per K-iter. We haven't spent much time tuning this kernel; the\nremaining 27% gap is addressable (persistent \u002F stream-K scheduling, more\naggressive instruction-level overlap, autotuned tile sizes). See\n[`examples\u002Fampere\u002Fgemm_highperf_ampere.py`](examples\u002Fampere\u002Fgemm_highperf_ampere.py)\nfor the full kernel.\n\nFull benchmark tables + reproduction commands:\n[pyptx.dev\u002Fperformance](https:\u002F\u002Fpyptx.dev\u002Fperformance\u002F).\n\nPyTorch dispatch tiers:\n\n- **CUDA graph replay**: ~4 µs per launch\n- **Turbo eager**: ~14 µs (cached C++ extension)\n- **`torch.compile`**: ~14–22 µs (custom_op path)\n\n---\n\n## What it looks like\n\n```python\nfrom pyptx import kernel, reg, smem, ptx, Tile\nfrom pyptx.types import bf16, f32\n\n@kernel(\n    in_specs=(Tile(\"M\", \"K\", bf16), Tile(\"K\", \"N\", bf16)),\n    out_specs=(Tile(\"M\", \"N\", f32),),\n    grid=lambda M, N, K: (N \u002F\u002F 64, M \u002F\u002F 64),\n    block=(128, 1, 1),\n    arch=\"sm_90a\",\n)\ndef gemm(A, B, C):\n    sA = smem.wgmma_tile(bf16, (64, 16), major=\"K\")\n    sB = smem.wgmma_tile(bf16, (16, 64), major=\"MN\")\n    acc = reg.array(f32, 32)  # 64×64 f32 accumulator tile over 128 threads = 32 regs each\n    # ... TMA loads + ptx.wgmma.mma_async(...) — each call emits exactly one PTX instruction\n```\n\nEvery `ptx.*` call is a single PTX instruction. 
`print(gemm.ptx())` shows\nexactly what you wrote.\n\n## One kernel, three runtime paths\n\nThe same kernel object works in JAX, PyTorch eager, and `torch.compile`:\n\n```python\nimport jax\nimport torch\n\n# `gemm` is the @kernel function defined above.\n\n# PyTorch eager (a, b: bf16 CUDA tensors)\nout = gemm(a, b)\n\n# torch.compile\nout = torch.compile(gemm)(a, b)\n\n# JAX jit, lowered through typed FFI (a, b: bf16 jax.Arrays on the GPU)\nout = jax.jit(gemm)(a, b)\n```\n\nUnder the hood the PTX is JIT-compiled through `cuModuleLoadData`, registered\nwith a ~150-line C++ launch shim, and dispatched from PyTorch via\n`torch.library.custom_op` or from JAX via `jax.ffi.ffi_call`.\n\n---\n\n## Transpile existing PTX into pyptx\n\n`pyptx` is also a real PTX-to-Python transpiler. Feed it output from\n`nvcc`, Triton, Pallas, or any other source:\n\n```bash\npython -m pyptx.codegen kernel.ptx --sugar --name my_kernel > my_kernel.py\n```\n\n`--sugar` demangles names, raises spin-loops into `ptx.loop(...)`, collapses\nmbarrier-wait blocks, and groups expression chains. Round-trips are\n**byte-identical** on 218+ corpus files (CUTLASS, Triton, fast.cu, DeepGEMM,\nThunderKittens, LLVM tests).\n\nThe **815 TFLOPS** Hopper GEMM in `examples\u002Fhopper\u002Fgemm_highperf_hopper.py` is\nexactly this workflow applied to\n[fast.cu's kernel12](https:\u002F\u002Fgithub.com\u002Fpranjalssh\u002Ffast.cu).\n\n---\n\n## Start here\n\nAmpere (sm_80):\n\n- `examples\u002Fampere\u002Frms_norm.py` \u002F `layer_norm.py` \u002F `swiglu.py` \u002F\n  `softmax.py` — maintained Hopper kernels retargeted to `sm_80`.\n- `examples\u002Fampere\u002Fgemm.py` — single-warp `mma.sync.aligned.m16n8k16` bf16\n  GEMM, no SMEM staging. The minimal end-to-end Ampere tensor-core path.\n- `examples\u002Fampere\u002Fgemm_pipelined.py` — `cp.async` 2-stage SMEM ring buffer\n  plus `mma.sync` on a 64×64 CTA tile (per-thread `ld.shared`, no `ldmatrix`).\n  The first-step pipelined kernel (~64 TFLOPS at 4096³).\n- `examples\u002Fampere\u002Fgemm_highperf_ampere.py` — production-leaning A100 GEMM\n  following CUTLASS SM80 + MatmulTutorial v15. 
128×128×32 CTA tile, 4\n  warps in 2×2 owning 64×64 each, `ldmatrix.x4`, **4-stage** `cp.async`\n  pipeline, **register frag double-buffering** across K-iters, **XOR\n  swizzle** + serpentine `mma`, 64 `mma.sync` per warp per K-iter.\n  **162 TFLOPS at 4096³ bf16** = 73% of cuBLAS (**2.5× the simpler\n  `gemm_pipelined.py`**). Bit-exact through 4096³.\n- `benchmarks\u002Fbench_ampere_kernels.py` — A100 RMSNorm, LayerNorm, SwiGLU,\n  and GEMM benchmark suite.\n\nHopper (sm_90a):\n\n- `examples\u002Fhopper\u002Frms_norm.py` — simplest real kernel, v4 loads + warp reduce\n- `examples\u002Fhopper\u002Fgrouped_gemm.py` — multi-k WGMMA for MoE shapes\n- `examples\u002Fhopper\u002Fgemm_highperf_hopper.py` — warp-specialized 815 TFLOPS GEMM\n\nBlackwell (sm_100a):\n\n- `examples\u002Fblackwell\u002Ftcgen05_suite.py` — 13 isolated tcgen05 primitives\n  (alloc, MMA, ld, commit\u002Ffence, GEMM probes). Run this first on a B200\n  to verify the runtime stack.\n- `examples\u002Fblackwell\u002Fgemm_highperf_blackwell.py` — `build_gemm`\n  (1SM, 4-stage ring buffer, 1.24 PFLOPS at 8192³ bf16) and\n  `build_gemm_2sm` (2SM `cta_group::2` cooperative MMA, 5-stage).\n- `examples\u002Fblackwell\u002Fgemm_experimental_blackwell.py` — persistent and\n  Pallas-style experimental GEMM paths, plus the no-TMA tcgen05 debug GEMM.\n- `examples\u002Fblackwell\u002Fgrouped_gemm.py` — G-problem MoE grouped GEMM on\n  top of the same `tcgen05.mma` mainloop, bit-exact against\n  `einsum(\"gmk,gkn->gmn\")` through G=8 M=1024 N=128 K=1024.\n- `examples\u002Fblackwell\u002Frms_norm.py` \u002F `layer_norm.py` \u002F `swiglu.py` —\n  Hopper kernels re-targeted to `sm_100a`.\n- `benchmarks\u002Fbench_blackwell_gemm.py` — reproduce the 1SM + 2SM +\n  cuBLAS table above.\n- `benchmarks\u002Fbench_blackwell_kernels.py` — Blackwell grouped GEMM,\n  RMSNorm, LayerNorm, and SwiGLU benchmark suite.\n\nDocs:\n\n- [Getting Started](https:\u002F\u002Fpyptx.dev\u002Fgetting-started\u002F)\n- 
[Performance](https:\u002F\u002Fpyptx.dev\u002Fperformance\u002F)\n- [Debugging](https:\u002F\u002Fpyptx.dev\u002Fguides\u002Fdebugging\u002F)\n- [vs Triton\u002FCUTLASS\u002FPallas](https:\u002F\u002Fpyptx.dev\u002Fcomparison\u002F)\n\n## Status\n\n0.1.0, pre-launch. Scope:\n\n- handwritten PTX DSL with full Hopper ISA (wgmma, TMA 2D\u002F3D, mbarriers, cluster)\n- Blackwell `tcgen05` ISA (alloc, `mma.kind::f16\u002Ftf32\u002Ff8`, `ld`\u002F`st`,\n  commit, fence) with instruction-descriptor + SMEM-descriptor helpers\n- PTX parser \u002F emitter with 218+ corpus round-trip tests\n- PTX → Python transpiler with sugar pass\n- JAX runtime integration (typed FFI)\n- PyTorch eager + `torch.compile` + CUDA graph replay\n- C++ dispatch extension for low-overhead launches\n- GMMA\u002FUMMA SMEM swizzle helpers (B32 \u002F B64 \u002F B128, CuTe-compatible `Swizzle\u003CB,4,3>`)\n- PyTorch autograd via `differentiable_kernel`\n\n## License\n\nApache-2.0. See [LICENSE](LICENSE).\n","pyptx is a domain-specific language for writing NVIDIA PTX kernels in Python, targeting GPUs from Ampere through Hopper and Blackwell, and its kernels can be called from JAX and PyTorch. Its core features include support for explicit registers, predicates, barriers, and shared memory, as well as architecture-specific instruction-set features for Ampere, Hopper, and Blackwell GPUs. It lets developers generate PTX directly from Python functions, with no intermediate optimization or autotuning step. It suits high-performance workloads that can take full advantage of the latest NVIDIA GPU features, such as deep learning model training and large-scale matrix operations. In the project's GEMM benchmarks, pyptx reaches roughly 64–78% of cuBLAS throughput on Blackwell and Ampere and beats cuBLAS on Hopper at sizes of 6K and above.","2026-06-11 02:42:34","CREATED_QUERY"]