[{"data":1,"prerenderedAt":-1},["ShallowReactive",2],{"project-78207":3},{"id":4,"name":5,"fullName":6,"owner":7,"repo":5,"description":8,"homepage":9,"htmlUrl":10,"language":11,"languages":10,"totalLinesOfCode":10,"stars":12,"forks":13,"watchers":14,"openIssues":15,"contributorsCount":16,"subscribersCount":16,"size":16,"stars1d":16,"stars7d":17,"stars30d":18,"stars90d":16,"forks30d":16,"starsTrendScore":16,"compositeScore":19,"rankGlobal":10,"rankLanguage":10,"license":10,"archived":20,"fork":20,"defaultBranch":21,"hasWiki":22,"hasPages":20,"topics":23,"createdAt":10,"pushedAt":10,"updatedAt":24,"readmeContent":25,"aiSummary":26,"trendingCount":16,"starSnapshotCount":16,"syncStatus":27,"lastSyncTime":28,"discoverSource":29},78207,"coda-kernels","HanGuo97\u002Fcoda-kernels","HanGuo97","CODA: Rewriting Transformer Blocks as GEMM-Epilogue Programs","https:\u002F\u002Farxiv.org\u002Fabs\u002F2605.19269",null,"Python",202,21,3,1,0,7,139,4.03,false,"main",true,[],"2026-06-12 02:03:46","# CODA: GPU Kernels as GEMM-plus-Epilogue Programs\n\n\u003Cp align=\"center\">\n  \u003Cimg src=\"figs\u002Ficon.jpg\" width=\"350\" \u002F>\n\u003C\u002Fp>\n\n\u003Cp align=\"center\">\n  \u003Ca href=\"https:\u002F\u002Farxiv.org\u002Fabs\u002F2605.19269\">\u003Cimg src=\"https:\u002F\u002Fimg.shields.io\u002Fbadge\u002FarXiv-2605.19269-b31b1b.svg\" alt=\"arXiv\">\u003C\u002Fa>\n\u003C\u002Fp>\n\n**CODA** is a GPU kernel abstraction that expresses Transformer operators as GEMM-plus-epilogue programs, fusing normalization, activations, residual updates, and reductions into the GEMM output tile before it is written to global memory, combining framework-level productivity with hardware-level efficiency. 
CODA is built on [CUTLASS CuTeDSL](https:\u002F\u002Fgithub.com\u002FNVIDIA\u002Fcutlass) and targets NVIDIA Hopper (H100) GPUs.\n\n\u003Cp align=\"center\">\n  \u003Cimg src=\"figs\u002Freparameterization.png\" width=\"700\" \u002F>\n\u003C\u002Fp>\n\n> **A note on naming.** The implementation of CODA has historically been called **Rapier**, a collection of GEMM-plus-epilogue primitives built on top of CuTeDSL. The name nods to CUTLASS: a slimmer, more focused blade of the same lineage, fitting for a constrained GEMM-plus-epilogue interface.\n\n## Quick Start\n\n> [!NOTE]\n> We autotune each kernel the first time it sees a new input configuration (shape, dtype, etc.), so the initial call may take a while.\n\n### Kernel level\n\nIndividual GEMM-plus-epilogue kernels are in `kernels\u002Fgens\u002Fepilogue\u002F`. The base pattern for `gemm_residual_rmsnorm_gemm` (no extra epilogue) uses two kernels in sequence:\n\n```python\nimport torch\nfrom kernels.gens import gpt as gens\nfrom models.ops import compute_rstd\n\nM, K, N = 4096, 4096, 4096\ndtype = torch.bfloat16\ndevice = \"cuda\"\n\n# tile size for partial reductions; autotuned when used inside a full block\nblock_size = 128\n\ny   = torch.randn(M, K, dtype=dtype, device=device)   # attention output\nx   = torch.randn(M, N, dtype=dtype, device=device)   # residual\nw_a = torch.randn(K, N, dtype=dtype, device=device)   # attention out-proj weight\nw_b = torch.randn(N, N, dtype=dtype, device=device)   # MLP gate+up weight\nw_n = torch.randn(N,    dtype=dtype, device=device)   # MLP RMSNorm weight\n\n# Kernel 1: attention out-proj + residual add + partial RMS norm + norm weight scaling\n#   D = y @ w_a + x             (M, N)           -- out-proj with residual add\n#   S = partial mean(D**2)      (M, num_blocks)  -- partial RMS norm stats\n#   O = D * w_n                 (M, N)           -- norm weight scaling\nD, S, O = gens.gemm_residual_partial_rmsnorm(A=y, B=w_a, C=x, W=w_n, block_size=block_size)\n\n# per-row 
rstd, shape (M,)\nR = compute_rstd(s=S, eps=1e-6, use_quack=False)\n\n# Kernel 2: RMS norm + MLP gate+up projection + SwiGLU\n#   D = O @ w_b * R             (M, N)          -- normalized gate+up pre-activation\n#   O = silu(gate) * up         (M, N \u002F\u002F 2)     -- SwiGLU output\nD, O = gens.gemm_rmsnorm_swiglu(A=O, B=w_b, R=R)\n```\n\n### Ops level\n\n`models\u002Fops.py` provides high-level fused ops that cover full Transformer blocks (excluding attention). Each op represents a reparameterized Transformer layer, spanning from the attention output projection through the MLP to the QKV projection of the next layer.\n```python\nimport torch\nfrom models import ops\n\n# Forward pass through a Transformer block (excluding attention)\nx_out, qkv = ops.layer(\n    x0=x0,          # residual stream input\n    y0=y0,          # attention output\n    w0=w0,          # attention out-proj weight\n    w1=w1,          # MLP gate+up weight\n    w2=w2,          # MLP down weight (next block)\n    w3=w3,          # QKV projection weight (next block)\n    wn0=wn0,        # RMS norm weight (post-attention)\n    wn1=wn1,        # RMS norm weight (pre-QKV)\n    cos_sin=cos_sin,\n    cos=cos,\n    sin=sin,\n    num_heads=num_heads,\n    head_dim=head_dim,\n    eps=1e-6,\n    transpose=True,\n    backend=\"rapier\",\n    use_compile=True,\n)\n```\n\n## Writing a New Epilogue\n\nThe CODA GEMM mainloop is fixed; an epilogue plugs into it by overriding a handful of callback methods on `EpilogueVisitorTree` (defined in [rapier\u002Fepilogue\u002Fbase.py](rapier\u002Fepilogue\u002Fbase.py)). The mainloop produces a GEMM accumulator tile `tRS_rAcc` in registers, walks it sub-tile by sub-tile, and invokes the epilogue at well-defined hook points before staging the result through shared memory and storing it out via TMA.\n\n### Mainloop \u002F epilogue interaction\n\n```python\n# once per output tile, after the GEMM mainloop produces tRS_rAcc\nevt.consumer_begin(...)              
# load per-tile inputs gmem -> smem\nevt.producer_begin(...)              # set up TMA producer state (if any)\n\nfor sub_tile in epi_tiles:\n    evt.consumer_begin_loop(...)     # load smem -> registers for this sub-tile\n    evt.producer_tma_load(...)       # issue async TMA loads (if any)\n    evt.consumer_visit(tRS_rD, ...)  # MUTATE the accumulator tile: the core op\n    # mainloop: cast and stage tRS_rD into smem\n    evt.consumer_smem_store(...)     # optional extra smem writes (e.g. partial reductions)\n    # mainloop: TMA-store smem -> gmem\n    evt.consumer_tma_store(...)      # optional post-store callback\n    evt.consumer_end_loop(...)\n\nevt.consumer_end(...)                # post-loop finalization\n```\n\nPer-tile and per-sub-tile state (smem views, register tensors) flows between these calls through return values that the mainloop threads forward as arguments.\n\n### Methods\n\n| Method | What it does |\n|--------|--------------|\n| `to_underlying_arguments` | Converts host-side `EpilogueArguments` into device-side `EpilogueParams` (adds alignment hints, etc.). Called before the kernel launch. |\n| `get_smem_struct` \u002F `get_smem_tensors` \u002F `get_smem_bytes_per_stage` | Declare the shared memory buffers this epilogue needs (dtypes + sizes), build CuTe tensor views over them, and report per-stage byte budgets. |\n| `consumer_begin` | Once per CTA output tile: load per-tile inputs (e.g. an `R` column vector for RMS norm) from global to shared memory and produce partitioned smem views. |\n| `producer_begin` \u002F `producer_tma_load` | Set up and drive the TMA producer pipeline for inputs loaded asynchronously per sub-tile (e.g. a residual matrix). No-ops by default. |\n| `consumer_begin_loop` | Per epilogue sub-tile: copy the relevant slice of smem into registers, ready to be combined with the accumulator. 
|\n| `consumer_visit` | **The core operation.** Mutates the accumulator register tile `tRS_rD` in place; this is where the actual elementwise \u002F reduction math happens. Receives `tRS_rD` in accumulator dtype (typically fp32); the cast to output dtype happens afterwards in the mainloop. |\n| `consumer_smem_store` | Optional extra writes to shared memory after `tRS_rD` has been staged into smem but before the TMA store (e.g. writing partial reduction results). |\n| `consumer_tma_store` | Callback fired immediately after the mainloop TMA-stores the tile to global memory; useful for chaining additional global writes. |\n| `consumer_end_loop` \u002F `consumer_end` | Per sub-tile and per CTA-tile cleanup hooks. |\n\n### Example: per-row scaling\n\n`EVTRMSNormScale` in [kernels\u002Fgens\u002Fepilogue\u002Fkernel_1.py](kernels\u002Fgens\u002Fepilogue\u002Fkernel_1.py) multiplies the GEMM accumulator by a per-row scalar `R` (the RMS norm reciprocal std dev). The load-and-multiply core looks like:\n\n```python\n@cute.jit\ndef consumer_begin(self, ..., epi_params, epi_tensors_smem):\n    sColVec = epi_tensors_smem.sColVec\n    # take this CTA's slice of the global R vector, then async-copy gmem -> smem\n    gColVec = cute.local_tile(epi_params.mColVec, (tile_M,), (m_idx,))\n    memory_utils.g2s_copy_1d(src=gColVec, dst=sColVec, ...)\n    # broadcast the column along N (stride 0), then partition across threads\n    sColVec_view = cute.make_tensor(\n        sColVec.iterator,\n        cute.make_layout((tile_M, tile_N), stride=(1, 0)),\n    )\n    tDsColVec = partition_for_epilogue(sColVec_view)\n    # wait for cp.async, then sync the consumer warps\n    cute.arch.cp_async_commit_group()\n    cute.arch.cp_async_wait_group(0)\n    epi_barrier.arrive_and_wait()\n    return self.EpilogueTensors(tDsColVec=tDsColVec)\n\n@cute.jit\ndef consumer_begin_loop(self, ..., epi_coord, epi_tensors):\n    # select this sub-tile's slice of the smem view, then copy smem -> registers (acc 
dtype)\n    tDsColVec_cur = epi_tensors.tDsColVec[..., epi_coord]\n    tDrColVec_cvt = memory_utils.s2r_copy_1d(tDsColVec_cur, dtype=self.acc_dtype)\n    return self.EpilogueTensorsLoop(tDrColVec_epi=tDrColVec_cvt), ...\n\n@cute.jit\ndef consumer_visit(self, tRS_rD, ..., epi_tensors_loop):\n    # per-row scaling: multiply each accumulator element by the matching R value\n    tDrColVec_epi = epi_tensors_loop.tDrColVec_epi\n    for i in cutlass.range_constexpr(cute.size(tDrColVec_epi)):\n        tRS_rD[i] = tRS_rD[i] * tDrColVec_epi[i]\n    return epi_tensors_loop\n```\n\n## Repository Structure\n\n```\ncoda-kernels\u002F\n├── models\u002F          # High-level API\n│   ├── ops.py       # CODA layer implementations (forward + backward)\n│   └── ops2.py      # Corresponding PyTorch implementations\n├── kernels\u002F\n│   ├── gens\u002F        # LLM-authored CuTeDSL kernel implementations\n│   ├── refs\u002F        # PyTorch reference implementations\n│   ├── tests\u002F\n│   └── benchmarks\u002F\n└── rapier\u002F          # CODA kernel infrastructure\n    ├── gemm\u002F        # WGMMA GEMM kernels and PyTorch wrapper\n    ├── epilogue\u002F    # Composable epilogue visitors\n    ├── ops\u002F         # Low-level utilities\n    ├── examples\u002F    # Standalone usage examples\n    └── docs\u002F        # Docs for LLM\n```\n\n### Epilogue Visitors (`rapier\u002Fepilogue\u002F`)\n\n| Module | Description |\n|--------|-------------|\n| `bias` | Row\u002Fcolumn bias addition |\n| `reduction` | Block-level row\u002Fcolumn reductions (store, store-2X, load variants) |\n| `activation` | Dual-output activations: elementwise, pairwise, contraction, expansion |\n| `matrix` | TMA-pipelined matrix load with residual add; 2X paired-tile variant |\n| `cross_entropy` | Online softmax + target logit selection, fused into the output tile |\n| `composite` | Chains multiple visitors into a single unified epilogue |\n","CODA 
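is a GPU kernel abstraction project that expresses Transformer operators as GEMM-plus-epilogue programs. It fuses normalization, activations, residual updates, and reductions before the GEMM output tile is written to global memory, combining framework-level productivity with hardware-level efficiency. Built on CUTLASS CuTeDSL, CODA specifically targets NVIDIA Hopper (H100) GPUs. Its core features include autotuning each kernel for new input configurations, and it provides implementations at several levels, from basic GEMM-plus-epilogue kernels up to full Transformer blocks (excluding attention). The project is well suited to workloads that need efficient Transformer inference or training, and it has broad potential among developers who want high performance while keeping their code concise.",2,"2026-06-06 03:57:39","CREATED_QUERY"]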