Stage B: two MMAs + identity softmax — crash fixed, softmax output still wrong

Key fixes:
- PipelineUmmaAsync consumer group: 32*4=128 threads (not 4 warps)
- TMEM offsets computed from find_tmem_tensor_col_offset (not hardcoded)
- P fragment from p_tmem_s.outer + make_fragment_A (matching fmha.py)
- V SMEM aliasing via recast_ptr

Status:
- Stage A: cosine 0.999999 
- Stage B: runs without crash, identity softmax cosine -0.02 
- Diagnostics: TMEM layout inspection, bisection results
This commit is contained in:
2026-05-20 20:26:25 +00:00
parent a5b48be7d5
commit 97656a5cd1
27 changed files with 8434 additions and 533 deletions

516
README.md
View File

@@ -1,319 +1,245 @@
# NVFP4 MegaMoE Kernel
# DeepSeek-V4 NVFP4 Kernel Suite
Native NVFP4 inference stack for DeepSeek-V4 on NVIDIA Blackwell (SM100). CuTeDSL kernels for the entire model — MoE experts, shared experts, attention projections — running in native NVFP4 with zero dequantization overhead.
CuTeDSL kernels for DeepSeek-V4 (Blackwell B200, SM100). All kernels use `cutlass.cute` (CuTeDSL) with Blackwell tensor cores.
## ⚠️ THE #1 RULE
**WE OWN ALL OUR KERNELS. WE DO NOT PATCH vLLM.**
vLLM's internal kernels (FlashMLA, fp8_ds_mla, fused compressor, Triton indexer) are deeply coupled. You cannot swap one piece and expect the rest to work. We build our own CuTeDSL kernels, test standalone, then wire into vLLM as an attention backend.
---
## Repository Layout
**This repo (`nvfp4-megamoe-kernel`):** The kernel library — CuTeDSL kernels, bridge layer, standalone tests.
**vLLM fork (`vllm-deepseekv4-nvfp4`):** The vLLM integration — model definition, weight loading, attention backend. Lives at `/root/dsv4-nvfp4-workspace/vllm` on the B200.
**Workspace (`/root/dsv4-nvfp4-workspace`):**
- `kernel/` — clone of this repo
- `vllm/` — clone of the vLLM fork
---
## Kernel Status
### ✅ CuTeDSL NVFP4 Grouped GEMM (the building block)
`ScaledGroupedGemmKernel` in `cutedsl/kernel/moe/torch_scaled_grouped_mm.py`:
- 2D×3D scenario: A(M,K) × B(E,K,N) → C(M,N)
- Block-scaled: per-16-element FP8 scales on both A and B sides
- Global scales (per-expert) for full dynamic range
- Persistent scheduler, TMA pipelining, SMEM swizzle
- CUDAGraph-safe (workspace pre-allocated, no runtime allocations)
### ✅ Fused SwiGLU GEMM (L1 gate+up with SwiGLU in registers)
`FusedSwiGLUScaledGroupedGemmKernel` in `cutedsl/kernel/moe/fused_swiglu_grouped_mm.py`:
- Extends the base GEMM with an in-epilogue SwiGLU
- **Weight interleave**: `interleave_l1_weights()` interleaves gate/up at granularity 8 BF16
- **epi_tile=(128, 8)**: each 8-wide subtile is pure gate or pure up
- **Subtile-level pairing**: even subtiles = gate (SiLU in FP32, save to register buffer), odd subtiles = up (load silu(gate) from buffer, compute silu(gate)*up)
- Output: BF16 with interleaved [silu(gate), silu(gate)*up] at granularity 8
- **Cosine 0.988** vs BF16 reference (full MoE pipeline)
### ✅ Custom CUDA De-interleave + NVFP4 Quantize
`cutedsl/kernels/deinterleave_quantize.cu`:
- Single GPU kernel: reads fused L1 BF16 output, extracts SwiGLU from odd 8-col groups, quantizes to NVFP4
- Replaces the Python `deinterleave_l1_weights()` + `quantize_activation_nvfp4()` path
- **4.3x faster** (0.043ms vs 0.184ms for 128 tokens)
- **99.97% cosine match** with Python reference, 99.7% FP4 byte match
- Saves ~8.5ms over 60 MoE layers
### ✅ NVFP4 Linear (`cutedsl/nvfp4_linear.py`)
`CuTeDSLNvfp4Linear` — single-expert NVFP4 GEMM for shared experts and attention projections.
### ✅ GPU-Native SWA Decode Attention (CuTeDSL)
`cutedsl/native_swa_decode.py``BlackwellSWADecodeKernel`:
- CTA mapping: 1 CTA per (decode_token, q_head_group) — 8 groups × T tokens
- Q loaded into registers, KV streamed in 16-token tiles through smem
- Online softmax (max/exp/rescale/sum) across tiles
- Pre-dequantized bf16 KV (fp8 dequant done on host, fused dequant is future work)
- **Cosine 0.9999+** vs PyTorch batched SDPA reference
### ✅ GPU-Native Sparse + SWA Decode Attention (CuTeDSL)
`cutedsl/native_sparse_decode.py``BlackwellSparseDecodeKernel`:
- Same CTA mapping as SWA kernel
- Concatenated SWA + compressed KV in a single attention pass
- Sink weight merge applied on host side
- **Cosine 0.9999+** vs combined SDPA reference
- Supports both CSA (cr=4) and HCA (cr=128) layers
### ✅ Sparse Topk Metadata Kernels (C128A + C4A)
`cutedsl/kernels/sparse_topk_metadata.cu` + `cutedsl/sparse_topk_metadata.py`:
- **`build_c128a_topk_metadata`**: position-based compressed KV slot lookup via block table for C128A (cr=128) decode tokens. Maps `(position, block_table) → global compressed KV slot IDs + lengths`
- **`compute_c4a_global_topk`**: local topk index → global KV cache slot mapping via block table for C4A (cr=4) decode tokens
- Both tested: correct block table lookups, proper padding, valid length counts
- **No FlashMLA, no vLLM Triton dependency** — own CUDA kernels
### ✅ Blackwell Attention (standalone tests)
- `cutedsl/blackwell_attention.py` — KV cache write/read, full attention pipeline
- `cutedsl/csa_attention.py` — CSA (cr=4) and HCA (cr=128) sparse attention
- All standalone tests pass: KV cache (0.9997), CSA/HCA, prefill+decode (0.9998)
### ✅ CuTeDSL Warmup Compilation
`warmup_compilation()` and `warmup_fused_swiglu_compilation()` in `bridge.py`:
- Eagerly JIT-compiles GEMM kernels before model forward pass
- Uses **quantized random BF16** (via `quantize_to_nvfp4`) for warmup data
- Zero-filled FP4/FP8 causes `cudaErrorIllegalInstruction` — random bytes produce NaN in MMA dequant
- All three shapes compile successfully: L1 (48 experts, 3584×3072), L2 (48 experts, 3072×3584), Fused L1
---
## Bridge Layer (`cutedsl/bridge.py`)
Quantization, layout, kernel launch utilities:
| Function | Purpose |
|----------|---------|
| `quantize_to_nvfp4()` | BF16 → NVFP4 with global scale |
| `quantize_activation_nvfp4()` | CUDAGraph-safe quantize (pre-computed gs) |
| `quantize_weight_to_nvfp4()` | Weight quantization along K dim |
| `interleave_l1_weights()` | Gate/up interleave at granularity 8 BF16 |
| `deinterleave_l1_weights()` | Reverse the interleave |
| `deinterleave_quantize_nvfp4_cuda()` | Custom CUDA: de-interleave + quantize in one kernel |
| `make_b_k_major()` | B tensor stride conversion |
| `assemble_scales_2d_side()` / `assemble_scales_3d_side()` | Scale assembly + swizzle |
| `warmup_compilation()` | Eager JIT compilation with quantized random data (base GEMM) |
| `warmup_fused_swiglu_compilation()` | Eager JIT compilation with quantized random data (fused SwiGLU) |
| `run_nvfp4_grouped_gemm()` | Base GEMM entry point |
| `run_fused_swiglu_grouped_gemm()` | Fused SwiGLU GEMM entry point |
---
## MoE Pipeline
### Non-Fused Path
`CuTeDSLMoERunner` / `run_nvfp4_moe()`:
1. Quantize input BF16 → NVFP4 (pre-computed gs)
2. L1 GEMM: NVFP4 × NVFP4 → BF16 (gate+up interleaved)
3. De-interleave, split gate/up
4. SiLU(gate) * up → BF16 (PyTorch)
5. Re-quantize BF16 → NVFP4
6. L2 GEMM: NVFP4 × NVFP4 → BF16 (down_proj)
7. Scatter with routing weights
### Fused Path
`run_nvfp4_moe_fused()` / `CuTeDSLMoERunner(fused_swiglu=True)`:
1. Quantize input BF16 → NVFP4 (pre-computed gs)
2. **Fused L1 GEMM + SwiGLU** in kernel registers → BF16 TMA store
3. **Custom CUDA kernel**: de-interleave + NVFP4 quantize (0.043ms)
4. L2 GEMM: NVFP4 × NVFP4 → BF16 (down_proj)
5. Scatter with routing weights
**Both paths: cosine 0.988 vs BF16 reference.** Fused path is marginally more accurate (FP32 SiLU in registers vs PyTorch BF16 SiLU).
---
## Blackwell Decode Path (vLLM Integration)
The Blackwell decode path in `attention.py` routes through our own kernels:
**SWA-only layers (cr=0):** `native_swa_decode_attention` — CuTeDSL kernel
**CSA layers (cr=4):** `native_sparse_decode_attention` with topk indices from `compute_c4a_global_topk` — our CUDA kernel maps indexer local topk → global KV cache slots
**HCA layers (cr=128):** `native_sparse_decode_attention` with topk indices from `build_c128a_topk_metadata` — our CUDA kernel maps positions → compressed KV slot IDs via block table lookup
**Metadata flow:**
- `DeepseekSparseSWAMetadataBuilder` builds SWA indices + C128A buffers
- `attention.py` detects FlashMLA vs Indexer metadata at runtime
- Blackwell path reads `indexer_metadata.decode.block_table` for block table access
- No FlashMLA dependency on Blackwell
---
## Correctness Bugs Fixed (May 20, 2026)
| Bug | Issue | Fix |
|-----|-------|-----|
| C128A topk missing | `DeepseekSparseSWAMetadataBuilder` returned None for C128A topk → SWA-only fallback | `build_c128a_topk_metadata` CUDA kernel computes global slot IDs from positions + block table |
| C4A topk missing | Relied on vLLM's Triton `compute_global_topk_indices_and_lens` (not ours) | `compute_c4a_global_topk` CUDA kernel replaces it on Blackwell |
| Warmup crash | Zero-filled FP4/FP8 → `cudaErrorIllegalInstruction` in MMA hardware | Quantize random BF16 through `quantize_to_nvfp4` for mathematically consistent warmup data |
| Warmup disabled | Was commented out → lazy JIT on first forward → OOM competing with model | Re-enabled in runner.py; L1/L2/fused all compile eagerly |
| `_fused_swiglu` not initialized | `CuTeDSLMoERunner.__init__` missing `self._fused_swiglu = False` | Added initialization |
| FlashMLA assert crash | `assert flashmla_metadata is not None` crashes on Blackwell where indexer_metadata is used instead | Fixed assert to accept either |
| `_needs_token_refill` myth | cute.compile doesn't corrupt GPU memory | Removed hack |
| Zero block FP8 scale | `clamp(min=1e-8)` gives nonzero scale for zero blocks | Detect zero blocks, force FP8 scale to exact 0 |
| Underflow blocks | amax < 6×2⁻⁹ gets nonzero FP4 | Detect underflow, zero x_norm before division |
| Expert counting | Materializes 18M bool tensor | `torch.bincount` replaces O(n×E) comparison |
| Dequantize→requantize | "Supposedly lossy" | Verified 100% byte-identical round-trip |
---
## Fused SwiGLU — How It Works
### The Problem
The L1 GEMM produces (M, 2×intermediate) BF16 output with gate and up columns side by side. SwiGLU needs silu(gate)*up, producing (M, intermediate). In the unfused path, this requires:
- ~580MB BF16 write to GMEM (L1 output)
- ~290MB BF16 read back (for gate/up split + SiLU)
- 3 kernel launches + 12 quantize ops
### The Solution: Granularity-8 Weight Interleave + Subtile Pairing
**Key insight**: With `interleave_l1_weights()`, gate and up weight columns are interleaved at granularity 8 BF16. In the GEMM output, every 8 BF16 columns alternate: [gate₀-₇, up₀-₇, gate₈-₁₅, up₈-₁₅, ...].
With `epi_tile_n=8`, each epilogue subtile covers exactly 8 BF16 N-columns. So each subtile is **pure gate or pure up** — no mixing. Even subtile indices = gate, odd = up.
**The epilogue loop** processes gate/up pairs:
```
for subtile_idx in range(subtile_cnt):
acc_vec = load_accumulator(subtile_idx)
acc_vec_bf16 = acc_vec.to(bf16) # init before dynamic if
if even (gate):
silu_result = silu(acc_vec) # FP32 math
silu_gate_buf = silu_result # save to register buffer
acc_vec_bf16 = silu_result
if odd (up):
gate_vals = silu_gate_buf # from previous iteration
acc_vec_bf16 = gate_vals * acc_vec # SwiGLU
store_to_smem(acc_vec_bf16)
tma_store_to_gmem()
```
Both branches produce `acc_vec_bf16` of the same BF16 type. No runtime conditional affects tensor structure. The `silu_gate_buf` is a register buffer initialized before the loop.
**The output** has interleaved [silu(gate), silu(gate)*up] at granularity 8. The custom CUDA kernel extracts odd 8-col groups (the SwiGLU result) and quantizes to NVFP4 for the L2 GEMM.
### The `//2` Bug
`interleave_l1_weights` had `g = granularity_bf16 // 2`, correct for K-axis interleave (FP4 packing along K). But we interleave along N, where each N-column = 1 BF16 column. The `//2` was a K-axis leftover that silently gave g=4 instead of g=8. **Fixed**: `g = granularity_bf16` (no `//2`).
### CuTeDSL Runtime Conditionals
CuTeDSL **does** support runtime conditionals on register tensors — both branches must produce the same tensor type (shape, layout, dtype). The earlier "blocked by type system" framing was wrong. The real issue: the old code applied SiLU to ALL positions (just SiLU, not SwiGLU) and the mask-blending approach (`silu(both)*0.5`) is mathematically wrong. With epi_tile_n=8 and subtile-level pairing, the conditional is clean.
### The Global Scale Gotcha
The custom CUDA quantize kernel needs the **L2 activation global scale** (from the SwiGLU output), NOT the L1 input global scale. The L1 gs is based on the input magnitude (~0.1), while the SwiGLU output can be orders of magnitude larger. Passing the wrong gs causes the FP8 block scale to overflow, producing NaN. The runner pre-computes the L2 gs in `compute_activation_global_scales()` before CUDAGraph capture.
---
## Remaining Work
| What | Status | Notes |
|------|--------|-------|
| In-epilogue NVFP4 quantize (replace BF16 TMA with FP4 TMA) | 🔨 Future | Saves ~0.14ms/layer; requires register→GMEM mapping for FP4 output |
| Fuse fp8→bf16 dequant into CuTeDSL kernel | 🔨 Future | Currently pre-dequantized on host; need vectorized fp8 loads |
| CSA/HCA sink weight merge in CuTeDSL | 🔨 Future | Applied on host for now; fuse into kernel for perf |
---
## DeepSeek-V4 Architecture Notes
**NOT MLA.** DeepSeek-V4 uses:
- **CSA** (Compressed Sparse Attention, cr=4): KV compressed 4x, indexer finds top-k
- **HCA** (Heavily Compressed Attention, cr=128): KV compressed 128x, pre-computed indices
- **SWA**: Standard sliding window (window=128, last layer only)
- **mHC**: Manifold-Constrained Hyper-Connections — replaces residual connections
- **384 experts, top-6, intermediate=3072**
Compress ratios by layer: alternating 128/4, layer 60 = 0 (SWA).
---
## File Structure
## File Map
```
cutedsl/
├── bridge.py # Quantization, layout, kernel launch
├── nvfp4_linear.py # Single-expert NVFP4 GEMM runner
├── runner.py # MoE grouped GEMM runner (fused + non-fused)
├── blackwell_attention.py # KV cache + attention (standalone)
├── csa_attention.py # CSA/HCA attention
├── custom_ops.py # torch.autograd wrappers
├── moe_pipeline.py # Standalone test pipeline (fused + non-fused)
── sparse_topk_metadata.py # C128A + C4A topk metadata (Python wrapper)
├── native_swa_decode.py # GPU-native SWA decode (CuTeDSL)
├── native_sparse_decode.py # GPU-native sparse+SWA decode (CuTeDSL)
├── kernels/
│ ├── deinterleave_quantize.cu # Custom CUDA: de-interleave + NVFP4 quantize
│ └── sparse_topk_metadata.cu # Custom CUDA: C128A + C4A topk metadata
└── kernel/moe/
├── torch_scaled_grouped_mm.py # ScaledGroupedGemmKernel (the GEMM)
└── fused_swiglu_grouped_mm.py # FusedSwiGLUScaledGroupedGemmKernel
├── native_swa_decode.py # SWA decode attention — IN PROGRESS (v3 tcgen05 rewrite)
├── native_sparse_decode.py # Sparse (CSA/HCA) decode — NOT YET REWRITTEN
├── nvfp4_cutedsl.py # NVFP4 MoE runner (CuTeDSL) — WORKING
├── moe_pipeline.py # MoE fused SwiGLU pipeline — WORKING
├── blackwell_attention.py # vLLM bridge for Blackwell attention path
├── csa_attention.py # CSA/HCA sparse attention bridge
├── custom_ops.py # Custom CUDA ops registration
── kernel/
└── blockscaled_gemm/
└── dense_blockscaled_gemm_persistent.py # REFERENCE: Blackwell TMEM/tcgen05 GEMM
tests/
├── layertest.py # MoE layer test — fused + non-fused (PASS, 0.988)
├── cudagraph_test.py # CUDAGraph test (PASS)
├── test_full_layer_b200.py # All NVFP4 projections (PASS, 0.994+)
├── test_v4_attention_b200.py # All 3 attention types (PASS)
├── test_kv_cache_b200.py # KV cache (PASS, 0.9997)
├── test_sparse_attn_b200.py # CSA/HCA (PASS)
├── test_decode_attention_b200.py # Prefill+decode (PASS, 0.9998)
├── test_stage_a_v2.py # ✅ Stage A: bare Q@K^T via tcgen05.mma → TMEM → GMEM
├── test_stage_b_v7.py # 🔨 Stage B: two MMAs + identity softmax (runs, wrong output)
├── test_stage_b_minimal.py # ✅ Stage B minimal: two MMAs, no softmax (runs, NaN expected)
├── test_stage_b_pipeline_only.py # ✅ Stage B pipeline-only: PipelineUmmaAsync, no ld/st (runs, NaN expected)
├── diag_tmem.py # Diagnostic: TMEM layout inspection
├── test_stage_b_v6.py # ❌ Stage B v6 (hardcoded offsets, crashes)
├── test_stage_a_qk.py # ❌ Stage A v1 (broken, superseded by v2)
├── test_stage_a_minimal.py # ❌ Stage A minimal (broken, superseded by v2)
├── test_attention_path_b200.py # Full attention path test (uses naive BF16 attn)
└── ...
```
---
## Current Status
## Key Lessons
### ✅ Stage A: Bare Q@K^T via tcgen05.mma — COMPLETE (May 20)
1. **⛔ NEVER assume CuTeDSL GPU tensors survive JIT compilation.** `cute.compile` zeroes GPU memory. Keep index/mapping tensors on CPU.
**File**: `tests/test_stage_a_v2.py`
**Result**: Q(128,128) @ K^T(128,128) → S(128,128), cosine 0.999999
2. **⛔ NEVER nuke working code without understanding why it exists.** CUDAGraph-safe functions exist because vLLM requires CUDAGraph.
Validates the full tcgen05.mma → TMEM → epilogue → GMEM path:
- tcgen05.mma with BF16 inputs, FP32 TMEM accumulator
- TMA load for A and B (cute.nvgpu.make_tiled_tma_atom_A/B)
- TMA store for C (cpasync.CopyBulkTensorTileS2GOp)
- Warp specialization: 4 epilogue warps + 1 MMA warp + 1 TMA warp = 192 threads
- PipelineTmaUmma for AB pipeline, PipelineUmmaAsync for acc pipeline
- TmemAllocator for TMEM allocation/deallocation
- utils.gemm.sm100.epilogue_tma_store for the TMEM→reg→SMEM→TMA→GMEM epilogue
3. **⛔ NEVER fabricate facts from MEMORY.md.** Verify what "works" means before citing it.
### 🔨 Stage B: Two MMAs + Identity Softmax — IN PROGRESS (May 20)
4. **⛔ NEVER quantize a padded buffer and slice the output.** Quantize compact data, scatter into padded layout.
**Latest**: `tests/test_stage_b_v7.py`
**Status**: Kernel compiles and runs without crashing. Identity softmax produces wrong output (cosine ≈ -0.02).
5. **⛔ Silent weight drops are deadly.** vLLM's `if name not in params_dict: continue` skips weights with no warning. Replace with hard RuntimeError.
**What was fixed today:**
6. **⛔ NVFP4 is NOT suitable for attention Q×K^T.** Per-element dot products are too sensitive. Keep attention in BF16.
1. **PipelineUmmaAsync consumer group size crash (THE bug):**
`PipelineUmmaAsync` with `Agent.Thread` requires **thread count** (128), NOT warp count (4), for the consumer group. fmha.py uses `32 * len(softmax_warp_ids) = 128`. Using 4 caused `CUDA_ERROR_LAUNCH_FAILED` (not a deadlock — the barrier reached wrong threshold causing illegal TMEM access).
```python
# WRONG (caused CUDA_ERROR_LAUNCH_FAILED):
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 4)
# CORRECT (matches fmha.py):
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 128)
```
7. **⛔ NEVER touch drivers, kernels, firmware, or system packages on the B200.** The cluster costs millions. Always confirm with Mike.
2. **TMEM offset computation (no more hardcoding):**
- `s_cols = find_tmem_tensor_col_offset(tStS) = 128` — QK accumulator physical TMEM columns
- `o_cols = find_tmem_tensor_col_offset(tOtO) = 128` — PV accumulator physical TMEM columns
- `tmem_s0_offset = 0, tmem_p0_offset = 32, tmem_o0_offset = 128` — matches fmha.py
- `find_tmem_tensor_col_offset(tOrP_sliced) = 32800 = 0x8020` — 0x8000 is TMEM space tag, column offset = 32
- Total: 256 TMEM cols (verified by `get_num_tmem_alloc_cols`)
8. **⛔ CuTeDSL `if` branches must produce the same tensor type.** Both branches must yield identical (shape, layout, dtype). Initialize variables before the `if` — using values defined only inside a branch is not supported.
3. **P fragment construction (matching fmha.py):**
```python
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) # A-layout from PV MMA
tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]
tOrP0 = cute.make_tensor(tOrP.iterator + 2 * tmem_p0_offset, tOrP.layout)
```
Previously used `cute.composition` on C-layout — wrong, must use PV MMA's A-layout.
9. **⛔ The `//2` in interleave was a K-axis leftover.** FP4 packing is along K, not N. When interleaving along N, `g = granularity_bf16` (no `//2`). The bug silently gave granularity 4 instead of 8.
4. **V SMEM aliasing:**
V shares the same SMEM as K with a different layout interpretation:
```python
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
tCrV = pv_mma.make_fragment_B(sV) # Uses MN-major V layout
```
10. **⛔ "SiLU on all positions" is NOT SwiGLU.** SwiGLU pairs silu(gate)*up. Applying SiLU to the full (M, 2×intermediate) output is just SiLU. The pairing must be explicit.
**What's still broken:**
11. **⛔ The global scale must match the data being quantized.** Passing the L1 input gs to the SwiGLU quantize causes FP8 overflow → NaN. The gs must come from the SwiGLU output's magnitude.
The identity softmax C→A layout transform produces garbage output (cosine ≈ -0.02). The kernel runs, Stage A (Q@K^T) gives cosine 0.999999, but the full (Q@K^T)@V pipeline is wrong. The issue is in the tcgen05.ld/st identity softmax path — either the ld/st copy atoms, the register conversion, or the A-layout write positions are incorrect.
12. **⛔ NEVER use zero-filled or random-byte data for CuTeDSL warmup.** Zeros cause division-by-zero in scale dequant. Random uint8 bytes as FP4 produce NaN/Inf in MMA → `cudaErrorIllegalInstruction`. Always quantize random BF16 through `quantize_to_nvfp4` for mathematically consistent warmup data.
**Bisection results:**
- ✅ Stage B minimal (no pipeline, no softmax): runs, NaN (expected — no C→A transform)
- ✅ Stage B pipeline-only (PipelineUmmaAsync, no ld/st): runs, NaN (expected)
- 🔨 Stage B full (identity softmax): runs, cosine -0.02 (wrong — softmax transform is broken)
- All three crash with consumer_group=4, all run with consumer_group=128
13. **⛔ NEVER borrow kernels from vLLM or FlashMLA.** We own all our kernels. If we need a kernel that exists in vLLM's Triton or FlashMLA's C++, we build our own CUDA/CuTeDSL equivalent from scratch.
**TMEM layout diagnostic data:**
```
QK accumulator C fragment:
tStS.layout = ((128,128),1,1):((65536,1),0,0)
cute.size = 16384, cute.cosize = 8323200
find_tmem_tensor_col_offset = 128
PV A-fragment (P operand):
tOrP_sliced.layout = ((128,16),1,4):((65536,1),0,16)
cute.size = 8192, cute.cosize = 8323136
find_tmem_tensor_col_offset = 32800 = 0x8020 (0x8000 tag + col 32)
```
### 🔨 Stage C: Online Softmax — AFTER B
The hard part. Per the pseudocode:
- Epilogue warps tcgen05.ld scores from TMEM into register fragments
- Compute per-row: tile_max, new_max, rescale = exp(old_max - new_max)
- Apply rescale to tmem_output in place (tmem_output *= rescale)
- Compute exp(scores - new_max), tcgen05.st back to TMEM as P operand for MMA2
- Update row_sum = row_sum * rescale + new_tile_sum
**The register fragment layout from tcgen05.ld is NOT (row, col).** It's determined by the MMA instruction's partition of the accumulator. Need to figure out the mapping from fragment indices to logical (head, kv_pos) positions for per-row softmax operations. fmha.py uses `tTMEM_LOADrS.load().reduce(cute.ReductionOp.MAX, row_max, 0)` for the row max — a built-in reduction that handles the layout.
### 🔨 Stage D: FP8 Paged KV Gather — AFTER C
Replace BF16 TMA load of KV with:
- Indexed cp.async gather from paged KV cache (fp8)
- Per-position dequant scale (inv_scale) applied during or after gather
- Keep KV in fp8 in SMEM, let the MMA's per-row scale handle dequant (like blockscaled GEMM)
### Architecture: Per-Tile Flow (from /root/fragile-kernel-example/README.md)
```
For each KV tile:
1. Load warp writes sKV[stage] (paged FP8 gather via indexed cp.async)
2. MMA warp issues MMA1: sQ @ sKV[stage]^T → tmem_scores (accumulate=False)
Signals scores_full_mbar (via PipelineUmmaAsync commit)
3. Epilogue warps wait on mma_si consumer (scores ready), then:
a. tcgen05.ld scores from TMEM → register fragments
b. Compute tile_max, new_max, rescale = exp(old_max - new_max)
c. Apply rescale to tmem_output IN PLACE (tmem_output *= rescale)
d. tcgen05.st exp(scores - new_max) back to TMEM → now it's the P operand
e. Release mma_si (softmax_done — MMA warp can re-acquire and issue PV MMA)
4. MMA warp waits on mma_si acquire (softmax done), then MMA2: P @ sKV[stage] → tmem_output (accumulate=True)
5. Stage released, load warp can refill it
After all tiles: epilogue warps tcgen05.ld tmem_output, divide by row_sum, cast to BF16, store to GMEM
```
### ✅ NVFP4 MoE (CuTeDSL) — WORKING
- `nvfp4_cutedsl.py` + `moe_pipeline.py`
- CuTeDSL NVFP4 Linear (q_a, kv, q_b, o_b) — cosine 0.994+
- CuTeDSL NVFP4 MoE (L1 gate+up, SiLU, L2 down) — cosine 0.988
- Fused SwiGLU epilogue (granularity-8 weight interleave) — cosine 0.988
### ✅ FP8 KV Quantize/Dequant — WORKING
- FP8 KV: cosine 0.9997
- NVFP4 KV: cosine 0.9943 (2x smaller than FP8)
- Paged KV cache read/write: cosine 1.0
### ❌ Sparse Decode Attention — NOT YET REWRITTEN
`native_sparse_decode.py` still has the scalar FMA bug. Needs the same tcgen05.mma rewrite.
### ✅ Full Attention Pipeline (standalone tests) — WORKING
- FP8 KV → full attention: cosine 0.9997
- CSA sparse attention (cr=4): works
- HCA sparse attention (cr=128): works
- Merged CSA+SWA attention: works
## Critical APIs & Lessons
### PipelineUmmaAsync consumer group size — THE MAY 20 BUG
**For `Agent.Thread` groups in `PipelineUmmaAsync`: use thread count, NOT warp count.**
```python
# WRONG (caused CUDA_ERROR_LAUNCH_FAILED):
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 4) # warp count
# CORRECT (matches fmha.py):
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(softmax_warp_ids)) # thread count
```
This applies to ALL PipelineUmmaAsync consumers where the consumer is multiple warps. fmha.py line 671: `self.threads_per_warp * len(self.softmax0_warp_ids) = 32 * 4 = 128`.
**Note:** The earlier README incorrectly stated that warp count was correct. That was wrong. The `Agent.Thread` agent type measures group size in threads.
### TMEM offset arithmetic
- `find_tmem_tensor_col_offset(fragment)` — returns physical TMEM column count (with 0x8000 tag for A-fragments)
- QK accumulator C fragment: 128 TMEM columns
- PV A-fragment: offset 0x8020 = tag(0x8000) + col(32) — the 0x8000 is a TMEM memory-space identifier
- P OVERLAPS S in TMEM — P is written at column 32 within the S region (C-layout columns 0..127)
- `tOrP0 = cute.make_tensor(tOrP.iterator + acc_dtype.width // q_dtype.width * tmem_p0_offset, tOrP.layout)` — A-fragment offset scaled by dtype width ratio (F32/BF16 = 2)
### `make_trivial_tiled_mma` has two overloads
```python
# New (preferred):
make_trivial_tiled_mma(a_dtype, b_dtype, a_leading_mode, b_leading_mode,
acc_dtype, cta_group, mma_tiler_mn, a_source=SMEM)
# Deprecated (still works, used by Stage A):
make_trivial_tiled_mma(ab_dtype, a_leading_mode, b_leading_mode,
acc_dtype, cta_group, mma_tiler_mn, a_source=SMEM)
```
### V SMEM aliasing (K and V share SMEM)
```python
# K and V share the same SMEM buffer, but with different layouts:
v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, pv_mma_tiler, b_dtype, 1)
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
tCrV = pv_mma.make_fragment_B(sV)
```
### Other APIs discovered from Stage A
1. **`cute.Tensor` API** — `cutlass_torch.from_dlpack(t).mark_layout_dynamic(leading_dim=...)`
2. **3D tensors** — Tensors must be 3D (M, K, L) for `cute.local_tile` — add L=1 dimension
3. **`PipelineTmaUmma.create(...).make_participants()`** — returns `(producer, consumer)` pair
4. **`utils.gemm.sm100.epilogue_tma_store`** — handles transform + partition/dcopy. DO NOT hand-roll.
5. **`get_num_tmem_alloc_cols`** — correct TMEM allocation (accepts list of fragments, sums cols, rounds to power of 2)
6. **`smem.allocate_tensor()`** — for SMEM tensors (not SharedStorage struct for A/B/C)
7. **`LayoutEnum.from_tensor(a).mma_major_mode()`** — major mode from cute tensor
8. **Minimum valid N tile for tcgen05.mma BF16**: 32 (step 32, range 32-256)
## Environment
- **Server**: root@45.76.247.107 (B200, 180 GiB HBM3e per GPU)
- **venv**: `source /root/dsv4-nvfp4-workspace/venv/bin/activate`
- **PYTHONPATH**: `/root/dsv4-nvfp4-workspace/kernel`
- **Model**: `/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4`
- **vLLM repo**: `/root/dsv4-nvfp4-workspace/vllm` (modified for Blackwell)
- **Pseudocode**: `/root/fragile-kernel-example/README.md` — authoritative per-tile attention flow
- **fmha.py reference**: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py`
## 4-Stage Build Plan
| Stage | Goal | Status |
|-------|------|--------|
| A | Bare Q@K^T via tcgen05.mma → TMEM → GMEM | ✅ COMPLETE |
| B | Two MMAs + identity softmax (validates TMEM A operand, shared KV, layout transform, barrier ordering) | 🔨 Runs without crash, identity softmax produces wrong output |
| C | Online softmax between MMA1 and MMA2 (the hard part) | ⬜ TODO |
| D | FP8 paged KV gather + dequant (replace BF16 TMA load) | ⬜ TODO |

View File

@@ -1,17 +1,18 @@
"""
Native CuTeDSL SWA Decode Attention Kernel for DeepSeek-V4 on Blackwell (SM100).
Blackwell SM100 Tensor-Core SWA Decode Attention Kernel for DeepSeek-V4.
FUSED kernel: paged KV read + Q*K^T + online softmax + V accumulation.
fp8 dequant is done in a batched pre-step on the host side (fast with torch ops).
Future optimization: fuse the fp8 dequant into the kernel using vectorized loads.
CTA mapping: one CTA per (decode_token, q_head_group).
- 128 Q heads / 16 per group = 8 groups per token
- Grid: (num_head_groups, num_decode_tokens, 1)
Architecture: Two GEMMs back-to-back sharing TMEM, softmax in registers between them.
Following dense_blockscaled_gemm_persistent.py for all Blackwell idioms:
- tcgen05.mma with TMEM accumulators
- TmemAllocator with holding buffer and dealloc barrier
- Warp specialization: 1 MMA warp + 2 epilogue warps
- Online softmax in epilogue warps between the two GEMMs
- Final normalize in epilogue (divide by row_sum)
"""
import torch
from typing import Optional
import math
try:
import cutlass
@@ -19,313 +20,356 @@ try:
import cutlass.torch as cutlass_torch
import cutlass.utils as utils
import cuda.bindings.driver as cuda
from cutlass.cute.nvgpu import tcgen05, warp
import cutlass.pipeline as pipeline
from cutlass.utils import blackwell_helpers as sm100_utils
from cutlass import BFloat16, Float32
HAS_CUTEDSL = True
except ImportError:
HAS_CUTEDSL = False
_compiled_kernel_cache = {}
HEAD_GROUP = 16
KV_TILE = 16
HEAD_DIM = 512
NUM_THREADS = 128
KV_TILE = 16
WINDOW_SIZE = 128
LOG2_E = 1.4426950408889634074
def native_swa_decode_attention(
q, swa_kv_cache, swa_inv_scale, swa_indices, swa_lens,
block_size, scale, window_size=128,
):
"""Native SWA decode attention.
Pre-dequantizes fp8 KV cache to bf16 in a batched operation,
then launches the CuTeDSL attention kernel on bf16 data.
"""
num_tokens, NH, HD = q.shape
device = q.device
if not HAS_CUTEDSL:
return _fallback_batched_sdp(q, swa_kv_cache, swa_inv_scale,
swa_indices, swa_lens, block_size,
scale, window_size)
q = q.contiguous()
swa_indices = swa_indices.contiguous()
swa_lens = swa_lens.contiguous()
# Pre-dequantize fp8 KV cache to bf16
# This is a batched gather + dequant: fast on GPU
if swa_indices.dim() == 3:
swa_indices_2d = swa_indices.squeeze(0)[:num_tokens]
else:
swa_indices_2d = swa_indices[:num_tokens]
swa_indices, swa_lens, block_size, scale, window_size)
q = q.contiguous(); swa_indices = swa_indices.contiguous(); swa_lens = swa_lens.contiguous()
if swa_indices.dim() == 3: swa_indices_2d = swa_indices.squeeze(0)[:num_tokens]
else: swa_indices_2d = swa_indices[:num_tokens]
max_len = swa_lens[:num_tokens].max().item()
if max_len <= 0:
return torch.zeros(num_tokens, NH, HD, dtype=torch.bfloat16, device=device)
# Clamp to window_size
if max_len <= 0: return torch.zeros(num_tokens, NH, HD, dtype=torch.bfloat16, device=device)
max_len = min(max_len, window_size)
# Gather all KV indices: (num_tokens, max_len)
safe_indices = swa_indices_2d[:, :max_len].clamp(min=0)
block_indices = safe_indices // block_size
offsets = safe_indices % block_size
# Batched gather + dequant
kv_raw = swa_kv_cache[block_indices, offsets] # (T, max_len, HD) fp8
if swa_kv_cache.dtype == torch.uint8:
kv_raw = kv_raw.view(torch.float8_e4m3fn)
inv_scales = swa_inv_scale[safe_indices] # (T, max_len, 1)
block_indices = safe_indices // block_size; offsets = safe_indices % block_size
kv_raw = swa_kv_cache[block_indices, offsets]
if swa_kv_cache.dtype == torch.uint8: kv_raw = kv_raw.view(torch.float8_e4m3fn)
inv_scales = swa_inv_scale[safe_indices]
kv_bf16 = (kv_raw.to(torch.bfloat16) * inv_scales).to(torch.bfloat16)
# Pad to window_size if needed
if max_len < window_size:
pad = torch.zeros(num_tokens, window_size - max_len, HD,
dtype=torch.bfloat16, device=device)
kv_bf16 = torch.cat([kv_bf16, pad], dim=1)
# kv_bf16 is now (num_tokens, window_size, HD) bf16
kv_bf16 = torch.cat([kv_bf16, torch.zeros(num_tokens, window_size-max_len, HD, dtype=torch.bfloat16, device=device)], dim=1)
output = torch.zeros(num_tokens, NH, HD, dtype=torch.bfloat16, device=device)
cache_key = (num_tokens, NH, HD, window_size, str(device))
if cache_key not in _compiled_kernel_cache:
def to_cute(t):
ct = cutlass_torch.from_dlpack(t)
return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
q_c = to_cute(q)
kv_c = to_cute(kv_bf16)
len_c = to_cute(swa_lens[:num_tokens])
out_c = to_cute(output)
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
scale_tensor = torch.tensor([scale], dtype=torch.float32, device=device)
scale_c = to_cute(scale_tensor)
kernel = BlackwellSWADecodeKernel(
head_dim=HD, head_group=HEAD_GROUP, kv_tile=KV_TILE,
window_size=window_size,
)
compiled = cute.compile(
kernel, q_c, kv_c, len_c, out_c, scale_c, stream,
)
compiled(q_c, kv_c, len_c, out_c, scale_c, stream)
torch.cuda.synchronize()
_compiled_kernel_cache[cache_key] = {'compiled': compiled}
entry = _compiled_kernel_cache[cache_key]
compiled = entry['compiled']
def to_cute(t):
ct = cutlass_torch.from_dlpack(t)
return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
q_c = to_cute(q)
kv_c = to_cute(kv_bf16)
len_c = to_cute(swa_lens[:num_tokens])
out_c = to_cute(output)
def to_cute(t): return cutlass_torch.from_dlpack(t).mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
q_c, kv_c, len_c, out_c = to_cute(q), to_cute(kv_bf16), to_cute(swa_lens[:num_tokens]), to_cute(output)
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
scale_tensor = torch.tensor([scale], dtype=torch.float32, device=device)
scale_c = to_cute(scale_tensor)
scale_c = to_cute(torch.tensor([scale], dtype=torch.float32, device=device))
kernel = BlackwellSWADecodeKernel(head_dim=HD, num_heads=NH, kv_tile=KV_TILE, window_size=window_size)
compiled = cute.compile(kernel, q_c, kv_c, len_c, out_c, scale_c, stream)
compiled(q_c, kv_c, len_c, out_c, scale_c, stream)
return output
def _fallback_batched_sdp(
q, swa_kv_cache, swa_inv_scale, swa_indices, swa_lens,
block_size, scale, window_size,
):
num_tokens, NH, HD = q.shape
device = q.device
if swa_indices.dim() == 3:
swa_indices = swa_indices.squeeze(0)
def _fallback_batched_sdp(q, swa_kv_cache, swa_inv_scale, swa_indices, swa_lens,
block_size, scale, window_size):
num_tokens, NH, HD = q.shape; device = q.device
if swa_indices.dim() == 3: swa_indices = swa_indices.squeeze(0)
safe_indices = swa_indices[:num_tokens].clamp(min=0)
block_indices = safe_indices // block_size
offsets = safe_indices % block_size
block_indices = safe_indices // block_size; offsets = safe_indices % block_size
kv_raw = swa_kv_cache[block_indices, offsets]
if swa_kv_cache.dtype == torch.uint8:
kv_raw = kv_raw.view(torch.float8_e4m3fn)
if swa_kv_cache.dtype == torch.uint8: kv_raw = kv_raw.view(torch.float8_e4m3fn)
inv_scales = swa_inv_scale[safe_indices]
kv_bf16 = (kv_raw.to(torch.bfloat16) * inv_scales).to(torch.bfloat16)
pos_range = torch.arange(window_size, device=device).unsqueeze(0)
len_mask = pos_range >= swa_lens[:num_tokens].unsqueeze(1)
invalid_mask = swa_indices[:num_tokens] < 0
invalid_mask = swa_indices[:num_tokens, :window_size] < 0
attn_mask = len_mask | invalid_mask
float_mask = torch.zeros(attn_mask.shape, dtype=torch.bfloat16, device=device)
float_mask[attn_mask] = float('-inf')
q_t = q.permute(1, 0, 2)
q_batch = q_t.reshape(NH * num_tokens, 1, HD)
kv_expanded = kv_bf16.unsqueeze(0).expand(NH, num_tokens, window_size, HD)
k_batch = kv_expanded.reshape(NH * num_tokens, window_size, HD)
v_batch = k_batch
mask_batch = float_mask.unsqueeze(0).unsqueeze(2).expand(
NH, num_tokens, 1, window_size
).reshape(NH * num_tokens, 1, window_size)
out = torch.nn.functional.scaled_dot_product_attention(
q_batch, k_batch, v_batch,
attn_mask=mask_batch,
is_causal=False,
scale=scale,
)
return out.reshape(NH, num_tokens, HD).permute(1, 0, 2)
kv_batch = kv_bf16.expand(NH, window_size, HD)
mask_batch = float_mask.unsqueeze(0).expand(NH, -1, -1)
out = torch.nn.functional.scaled_dot_product_attention(q_t, kv_batch, kv_batch, attn_mask=mask_batch, is_causal=False, scale=scale)
return out.permute(1, 0, 2)
if HAS_CUTEDSL:
class BlackwellSWADecodeKernel:
def __init__(self, head_dim=HEAD_DIM, head_group=HEAD_GROUP,
kv_tile=KV_TILE, window_size=128):
def __init__(self, head_dim=HEAD_DIM, num_heads=128,
kv_tile=KV_TILE, window_size=WINDOW_SIZE):
self._head_dim = head_dim
self._head_group = head_group
self._num_heads = num_heads
self._kv_tile = kv_tile
self._window_size = window_size
self._num_threads = NUM_THREADS
self._mma_m = 128
self._num_threads = 96
self._cta_group = tcgen05.CtaGroup.ONE
# Warp IDs: 0,1 = epilogue, 2 = MMA
self._mma_warp_id = 2
self._epi_warp_ids = [0, 1]
@cute.jit
def __call__(self, mQ, mKV, mLens, mO, mScale, stream):
num_tokens = mQ.shape[0]
num_head_groups = mQ.shape[1] // self._head_group
grid_dim = (num_head_groups, num_tokens, 1)
M = self._mma_m; HD = self._head_dim; KT = self._kv_tile
# TiledMma for Q @ K^T: Q(M,HD) x K^T(HD,KT) → S(M,KT)
tiled_mma_qk = sm100_utils.make_trivial_tiled_mma(
BFloat16, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.K,
Float32, self._cta_group, (M, KT))
# TiledMma for P @ V: P(M,KT) x V(KT,HD) → O(M,HD)
tiled_mma_pv = sm100_utils.make_trivial_tiled_mma(
BFloat16, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.MN,
Float32, self._cta_group, (M, KT),
tcgen05.OperandSource.TMEM)
# SMEM layouts
sA_layout_atom = cute.make_composed_layout(
cute.make_swizzle(3, 3, 3), 0,
cute.make_layout((8, 64), stride=(64, 1)))
sQ_layout = cute.tile_to_shape(sA_layout_atom, (M, HD), (0, 1))
sK_layout = cute.tile_to_shape(sA_layout_atom, (KT, HD), (0, 1))
sV_layout = cute.tile_to_shape(sA_layout_atom, (KT, HD), (0, 1))
sO_layout = cute.tile_to_shape(sA_layout_atom, (M, HD), (0, 1))
# Named barriers for TMEM allocation and MMA↔epilogue sync
tmem_alloc_barrier = pipeline.NamedBarrier(
barrier_id=2, num_threads=96) # all 3 warps
acc_full_barrier = pipeline.NamedBarrier(
barrier_id=3, num_threads=96) # all 3 warps
@cute.struct
class SharedStorage:
sQ: cute.struct.Align[cute.struct.MemRange[BFloat16, cute.cosize(sQ_layout)], 1024]
sK: cute.struct.Align[cute.struct.MemRange[BFloat16, cute.cosize(sK_layout)], 1024]
sV: cute.struct.Align[cute.struct.MemRange[BFloat16, cute.cosize(sV_layout)], 1024]
sO: cute.struct.Align[cute.struct.MemRange[BFloat16, cute.cosize(sO_layout)], 1024]
tmem_dealloc_mbar: cutlass.Int64
tmem_holding_buf: cutlass.Int32
self._kernel(
mQ, mKV, mLens, mO, mScale,
).launch(
grid=grid_dim,
block=[self._num_threads, 1, 1],
stream=stream,
)
sQ_layout, sK_layout, sV_layout, sO_layout,
tiled_mma_qk, tiled_mma_pv, SharedStorage,
tmem_alloc_barrier, acc_full_barrier,
).launch(grid=(1, num_tokens, 1), block=[self._num_threads, 1, 1], stream=stream)
@cute.kernel
def _kernel(
self,
mQ: cute.Tensor, # (T, NH, HD) bf16
mKV: cute.Tensor, # (T, WS, HD) bf16 - pre-dequantized
mLens: cute.Tensor, # (T,) int64
mO: cute.Tensor, # (T, NH, HD) bf16
mScale: cute.Tensor, # (1,) f32
):
def _kernel(self, mQ, mKV, mLens, mO, mScale,
sQ_layout, sK_layout, sV_layout, sO_layout,
tiled_mma_qk, tiled_mma_pv, SharedStorage: cutlass.Constexpr,
tmem_alloc_barrier, acc_full_barrier):
tidx, _, _ = cute.arch.thread_idx()
hg_idx, tok_idx, _ = cute.arch.block_idx()
_, tok_idx, _ = cute.arch.block_idx()
HG = self._head_group
HD = self._head_dim
KT = self._kv_tile
WS = self._window_size
M = self._mma_m; HD = self._head_dim; KT = self._kv_tile
softmax_scale = mScale[0]
swa_len = mLens[tok_idx]
# ── Shared memory ──────────────────────────────────────
@cute.struct
class SharedStorage:
kv_tile: cute.struct.MemRange[cutlass.BFloat16, KT * HD]
warp_idx = tidx // 32
is_mma_warp = warp_idx == self._mma_warp_id
is_epi_warp = warp_idx in self._epi_warp_ids
smem = utils.SmemAllocator()
storage = smem.allocate(SharedStorage)
sQ = storage.sQ.get_tensor(sQ_layout)
sK = storage.sK.get_tensor(sK_layout)
sV = storage.sV.get_tensor(sV_layout)
sO = storage.sO.get_tensor(sO_layout)
sKV = cute.make_tensor(
storage.kv_tile.data_ptr(),
cute.make_layout((KT, HD), stride=(HD, 1)),
# TMEM allocator (all warps participate in alloc/dealloc)
tmem = utils.TmemAllocator(
storage.tmem_holding_buf.ptr,
barrier_for_retrieve=tmem_alloc_barrier,
allocator_warp_id=self._epi_warp_ids[0],
)
# ── Read valid KV length ───────────────────────────────
swa_len = mLens[tok_idx]
has_kv = swa_len > 0
# Allocate TMEM for score and output accumulators
# Score: (M=128, KT=16) = 2048 FP32
# Output: (M=128, HD=512) = 65536 FP32
# Total: 67584 FP32 = 264 KB TMEM (within 1 MB budget)
num_tmem_cols = M * KT + M * HD # TMEM columns
tmem.allocate(num_tmem_cols)
# ── Load Q into registers: (HG, HD) ───────────────────
q_reg = cute.make_rmem_tensor((HG, HD), cutlass.BFloat16)
for h in cutlass.range_constexpr(HG):
qh = hg_idx * HG + h
for d in range(HD):
q_reg[h, d] = mQ[tok_idx, qh, d]
# TMEM layout for scores and output
# The TMEM layout comes from the TiledMma's C operand layout
tCtScores = tiled_mma_qk.make_fragment_C(tiled_mma_qk.partition_C_shape((M, KT)))
tCtOutput = tiled_mma_pv.make_fragment_C(tiled_mma_pv.partition_C_shape((M, HD)))
# ── Output accumulator: (HG, HD) f32 ──────────────────
acc_O = cute.make_rmem_tensor((HG, HD), cutlass.Float32)
acc_O.fill(0.0)
# Zero the output accumulator (MMA with ACCUMULATE=False does this for scores)
# For the output accumulator, we zero it before the KV loop
# TODO: zero tCtOutput via tcgen05.st or first PV GEMM with ACCUMULATE=False
# ── Online softmax state: (HG,) f32 ───────────────────
row_max = cute.make_rmem_tensor((HG,), cutlass.Float32)
row_sum = cute.make_rmem_tensor((HG,), cutlass.Float32)
row_max.fill(-1e30)
row_sum.fill(0.0)
# ── Stream KV tiles ────────────────────────────────────
max_tiles = (WS + KT - 1) // KT
for tile_idx in range(max_tiles):
tile_start = tile_idx * KT
# Load bf16 KV from contiguous tensor to smem
for kv_pos in range(KT):
global_kv = tile_start + kv_pos
# ─── MMA WARP ────────────────────────────────────────
if is_mma_warp:
# Load Q to SMEM (once, reused across all KV tiles)
for h in range(M):
for d in range(HD):
valid = global_kv < swa_len
val = cutlass.BFloat16(0.0)
if valid:
val = mKV[tok_idx, global_kv, d]
sKV[kv_pos, d] = val
sQ[h, d] = mQ[tok_idx, h, d]
cute.arch.sync_threads()
# Q * K^T: (HG, KT) scores
scores = cute.make_rmem_tensor((HG, KT), cutlass.Float32)
scores.fill(0.0)
# Partition Q and K for QK GEMM
thr_qk = tiled_mma_qk.get_slice(tidx)
tCrQ = thr_qk.make_fragment_A(thr_qk.partition_A(sQ))
tCrK = thr_qk.make_fragment_B(thr_qk.partition_B(sK))
for h in cutlass.range_constexpr(HG):
smem_copy_Q = cute.make_tiled_copy_A(
cute.make_copy_atom(warp.LdMatrix8x8x16bOp(False, 4), BFloat16), tiled_mma_qk)
thr_smem_Q = smem_copy_Q.get_slice(tidx)
tCsQ = thr_smem_Q.partition_S(sQ); tCrQ_cv = thr_smem_Q.retile(tCrQ)
smem_copy_K = cute.make_tiled_copy_B(
cute.make_copy_atom(warp.LdMatrix8x8x16bOp(False, 4), BFloat16), tiled_mma_qk)
thr_smem_K = smem_copy_K.get_slice(tidx)
tCsK = thr_smem_K.partition_S(sK); tCrK_cv = thr_smem_K.retile(tCrK)
# Load Q to registers (once)
for k in cutlass.range_constexpr(cute.size(tCsQ.shape[2])):
cute.copy(smem_copy_Q, tCsQ[None, None, k], tCrQ_cv[None, None, k])
# Partition V for PV GEMM
thr_pv = tiled_mma_pv.get_slice(tidx)
sVt = cute.composition(sV, cute.make_layout((HD, KT), stride=(KT, 1)))
tOrVt = thr_pv.make_fragment_B(thr_pv.partition_B(sVt))
smem_copy_V = cute.make_tiled_copy_B(
cute.make_copy_atom(warp.LdMatrix8x8x16bOp(True, 4), BFloat16), tiled_mma_pv)
thr_smem_V = smem_copy_V.get_slice(tidx)
tOsVt = thr_smem_V.partition_S(sVt); tOrVt_cv = thr_smem_V.retile(tOrVt)
# ── KV tile loop ─────────────────────────────────
n_block_max = (self._window_size + KT - 1) // KT
for n_block in range(n_block_max):
tile_start = n_block * KT
# Load K and V to SMEM
for kv_pos in range(KT):
dot = cutlass.Float32(0.0)
global_kv = tile_start + kv_pos
for d in range(HD):
q_val = q_reg[h, d].to(cutlass.Float32)
k_val = sKV[kv_pos, d].to(cutlass.Float32)
dot = dot + q_val * k_val
scores[h, kv_pos] = dot * softmax_scale
val = cutlass.BFloat16(0.0)
if global_kv < swa_len:
val = mKV[tok_idx, global_kv, d]
sK[kv_pos, d] = val
sV[kv_pos, d] = val
# Online softmax update
for h in cutlass.range_constexpr(HG):
tile_max = cutlass.Float32(-1e30)
for kv_pos in range(KT):
s = scores[h, kv_pos]
if s > tile_max:
tile_max = s
cute.arch.sync_threads()
new_max = row_max[h]
if tile_max > new_max:
new_max = tile_max
# ── Q @ K^T via tcgen05.mma ──────────────────
for k in cutlass.range_constexpr(cute.size(tCsK.shape[2])):
cute.copy(smem_copy_K, tCsK[None, None, k], tCrK_cv[None, None, k])
tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, k > 0)
cute.gemm(tiled_mma_qk, tCtScores, tCrQ[None, None, k], tCrK[None, None, k], tCtScores)
rescale = cutlass.Float32(0.0)
if row_max[h] > cutlass.Float32(-1e29):
rescale = cute.exp(row_max[h] - new_max)
# Signal epilogue: scores in TMEM are ready
cute.arch.fence_view_async_tmem_store()
acc_full_barrier.arrive()
for d in range(HD):
acc_O[h, d] = acc_O[h, d] * rescale
row_sum[h] = row_sum[h] * rescale
# Wait for epilogue to finish softmax
acc_full_barrier.wait()
for kv_pos in range(KT):
exp_score = cute.exp(scores[h, kv_pos] - new_max)
row_sum[h] = row_sum[h] + exp_score
for d in range(HD):
v_val = sKV[kv_pos, d].to(cutlass.Float32)
acc_O[h, d] = acc_O[h, d] + exp_score * v_val
# ── P @ V via tcgen05.mma ─────────────────────
# P from TMEM (softmax output), V from SMEM
for k in cutlass.range_constexpr(cute.size(tOsVt.shape[2])):
cute.copy(smem_copy_V, tOsVt[None, None, k], tOrVt_cv[None, None, k])
# For the first KV tile, first k tile: ACCUMULATE=False (zero the output)
# Otherwise, ACCUMULATE=True (accumulate into output)
is_first_output_tile = (n_block == 0) and (k == 0)
tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, not is_first_output_tile)
cute.gemm(tiled_mma_pv, tCtOutput, None, tOrVt[None, None, k], tCtOutput)
row_max[h] = new_max
cute.arch.fence_view_async_tmem_store()
acc_full_barrier.arrive()
cute.arch.sync_threads()
# Free TMEM
tmem.relinquish_alloc_permit()
acc_full_barrier.arrive() # Sync before dealloc
tmem.free(num_tmem_cols)
# ── Normalize and write output ─────────────────────────
for h in cutlass.range_constexpr(HG):
qh = hg_idx * HG + h
for d in range(HD):
val_f32 = cutlass.Float32(0.0)
if has_kv and row_sum[h] > cutlass.Float32(1e-30):
val_f32 = acc_O[h, d] / row_sum[h]
mO[tok_idx, qh, d] = val_f32.to(cutlass.BFloat16)
# ─── EPILOGUE WARPS ──────────────────────────────────
if is_epi_warp:
my_row_start = warp_idx * 64
num_my_rows = 64
# Online softmax state
row_max = cute.make_rmem_tensor((num_my_rows,), Float32)
row_sum = cute.make_rmem_tensor((num_my_rows,), Float32)
row_max.fill(-1e30)
row_sum.fill(0.0)
# TMEM→register copy for scores (tcgen05.ld pattern)
tiled_copy_t2r_scores = tcgen05.make_tmem_copy(
sm100_utils.get_tmem_load_op(self._cta_group), tCtScores)
# TMEM→register for output
tiled_copy_t2r_output = tcgen05.make_tmem_copy(
sm100_utils.get_tmem_load_op(self._cta_group), tCtOutput)
# Register fragments
tRrScores = cute.make_fragment_like(
tiled_copy_t2r_scores.partition_D(
cute.make_tensor(tCtScores.iterator, tCtScores.layout)), Float32)
tRrOutput = cute.make_fragment_like(
tiled_copy_t2r_output.partition_D(
cute.make_tensor(tCtOutput.iterator, tCtOutput.layout)), Float32)
n_block_max = (self._window_size + KT - 1) // KT
for n_block in range(n_block_max):
# Wait for MMA to finish Q@K^T
acc_full_barrier.wait()
# ── tcgen05.ld scores from TMEM to registers ──
cute.copy(tiled_copy_t2r_scores, tCtScores, tRrScores)
cute.arch.fence_view_async_tmem_load()
# ── Softmax in registers ──────────────────────
# For each row this warp owns (64 rows per warp):
# 1. tile_max = max(scores * scale) — reduce across KT positions
# 2. new_max = max(row_max_prev, tile_max)
# 3. prev_exp = exp((row_max_prev - new_max) * scale * log2e)
# 4. Rescale output accumulator in TMEM: O *= prev_exp
# (via tcgen05.st with scaled values, or defer to PV GEMM)
# 5. row_sum *= prev_exp
# 6. exp_scores = exp((scores * scale - new_max * scale) * log2e)
# 7. row_sum += sum(exp_scores)
# 8. Write exp_scores back to TMEM (as P operand for PV GEMM)
# 9. row_max = new_max
# The register fragment layout after tcgen05.ld is
# (EPI_TILE_M=128, EPI_TILE_N=16) partitioned per epilogue warp.
# Each epilogue warp's fragment covers 64 rows and 16 columns.
# We can iterate over rows and compute softmax per row.
# TODO: Implement per-row softmax on the register fragment.
# This requires understanding the exact tRrScores layout
# from tcgen05.make_tmem_copy + partition_D.
# The dense GEMM's epilogue shows the pattern for iterating
# over the register fragment.
# For now, write the P values (softmax output) back to TMEM
# via tcgen05.st (register → TMEM copy)
# tcgen05.st: tRrScores → tCtScores
cute.copy(tiled_copy_t2r_scores, tRrScores, tCtScores)
cute.arch.fence_view_async_tmem_store()
# Signal MMA: softmax done, P in TMEM ready for PV GEMM
acc_full_barrier.arrive()
# Wait for MMA to finish P@V
acc_full_barrier.wait()
# ── Final normalize ──────────────────────────────
# tcgen05.ld output from TMEM, divide by row_sum, store to GMEM
cute.copy(tiled_copy_t2r_output, tCtOutput, tRrOutput)
cute.arch.fence_view_async_tmem_load()
# Divide each element by row_sum
# TODO: implement per-row normalization on register fragment
# Cast to BF16 and store to SMEM, then GMEM
# (following dense GEMM's epilogue pattern)
# Relinquish TMEM
tmem.relinquish_alloc_permit()
acc_full_barrier.arrive()
tmem.free(num_tmem_cols)

59
tests/debug_stages.py Normal file
View File

@@ -0,0 +1,59 @@
"""
Debug: test each stage independently.
1. Run Stage A (Q @ K^T only) — should give cosine 0.999
2. Run Stage B minimal (two MMAs, no softmax) — should give NaN or garbage
3. Run Stage B pipeline-only (pipeline but no ld/st) — should give NaN or garbage
4. Run Stage B full (identity softmax) — should give correct (Q@K^T)@V
"""
import torch
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
ref_qkt = qf @ kvf.T
ref_qktv = ref_qkt @ kvf
print(f"Q shape: {q.shape}, KV shape: {kv.shape}")
print(f"Q@K^T shape: {ref_qkt.shape}, (Q@K^T)@V shape: {ref_qktv.shape}")
print(f"Q@K^T range: [{ref_qkt.min():.2f}, {ref_qkt.max():.2f}]")
print(f"(Q@K^T)@V range: [{ref_qktv.min():.2f}, {ref_qktv.max():.2f}]")
# Test Stage A first
from test_stage_a_v2 import StageAQKTKernel
c_a = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c_a).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_a))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel_a = StageAQKTKernel(mma_tiler_mn=(128, 128))
compiled_a = cute.compile(kernel_a, mQ, mK, mC, stream)
compiled_a(mQ, mK, mC, stream)
torch.cuda.synchronize()
out_a = c_a[:,:,0].float()
cos_a = torch.nn.functional.cosine_similarity(out_a.flatten().unsqueeze(0), ref_qkt.flatten().unsqueeze(0)).item()
print(f"\nStage A (Q@K^T): cosine = {cos_a:.6f} {'' if cos_a > 0.99 else ''}")
# Test Stage B v7 (identity softmax)
from test_stage_b_v7 import StageBIdentitySoftmax
c_b = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
mC2 = ct.from_dlpack(c_b).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_b))
kernel_b = StageBIdentitySoftmax(mma_tiler_mn=(128, 128))
compiled_b = cute.compile(kernel_b, mQ, mK, mC2, stream)
compiled_b(mQ, mK, mC2, stream)
torch.cuda.synchronize()
out_b = c_b[:,:,0].float()
cos_b = torch.nn.functional.cosine_similarity(out_b.flatten().unsqueeze(0), ref_qktv.flatten().unsqueeze(0)).item()
has_nan = torch.isnan(out_b).any().item()
print(f"Stage B (identity softmax): cosine = {cos_b:.6f}, has_nan = {has_nan} {'' if cos_b > 0.99 else ''}")
# Check: is the output close to Q@K^T (not Q@K^T@V)?
cos_b_qkt = torch.nn.functional.cosine_similarity(out_b.flatten().unsqueeze(0), ref_qkt.flatten().unsqueeze(0)).item()
print(f" vs Q@K^T: cosine = {cos_b_qkt:.6f} (should be ~0 if it's Q@K^T@V)")
print(f" Output range: [{out_b.nan_to_num().min():.2f}, {out_b.nan_to_num().max():.2f}]")

91
tests/diag_tmem.py Normal file
View File

@@ -0,0 +1,91 @@
"""Diagnostic: Q1, Q2 for Stage B TMEM debugging.
Uses cute.compile with a dummy kernel that prints layout info at JIT time."""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
@cute.jit
def diag_tmem(stream: cuda.CUstream):
a_dtype = BFloat16; b_dtype = BFloat16
a_major = cute.nvgpu.OperandMajorMode.K
b_major = cute.nvgpu.OperandMajorMode.K
qk_mma = utils.sm100.make_trivial_tiled_mma(
a_dtype, b_dtype, a_major, b_major,
Float32, tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(
a_dtype, b_dtype, cute.nvgpu.OperandMajorMode.K, b_major,
Float32, tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.TMEM)
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
mma_tiler = (128, 128, qk_inst_k * 4)
pv_mma_tiler = (128, 128, pv_inst_k * 4)
qk_thr = qk_mma.get_slice(0)
pv_thr = pv_mma.get_slice(0)
# Q1: QK accumulator C fragment
qk_acc_shape = qk_thr.partition_shape_C(mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
print(f"=== Q1: QK accumulator C fragment ===")
print(f" tStS.layout = {tStS.layout}")
print(f" cute.size(tStS.layout) = {cute.size(tStS.layout)}")
print(f" cute.cosize(tStS.layout) = {cute.cosize(tStS.layout)}")
print(f" cute.size(mode=[0]) = {cute.size(tStS.layout, mode=[0])}")
print(f" cute.size(mode=[1]) = {cute.size(tStS.layout, mode=[1])}")
s_tmem_cols = find_tmem_tensor_col_offset(tStS)
print(f" find_tmem_tensor_col_offset(tStS) = {s_tmem_cols}")
# PV accumulator O fragment
pv_acc_shape = pv_thr.partition_shape_C(mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
print(f"=== PV accumulator O fragment ===")
print(f" tOtO.layout = {tOtO.layout}")
print(f" cute.size(tOtO.layout) = {cute.size(tOtO.layout)}")
print(f" cute.cosize(tOtO.layout) = {cute.cosize(tOtO.layout)}")
print(f" cute.size(mode=[0]) = {cute.size(tOtO.layout, mode=[0])}")
print(f" cute.size(mode=[1]) = {cute.size(tOtO.layout, mode=[1])}")
o_tmem_cols = find_tmem_tensor_col_offset(tOtO)
print(f" find_tmem_tensor_col_offset(tOtO) = {o_tmem_cols}")
# Q2: PV A-fragment (P operand from TMEM)
p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_thr.make_fragment_A(tP)
tOrP_sliced = tOrP_base[(None, None, None, 0)]
print(f"=== Q2: PV A-fragment (P operand) ===")
print(f" tP.layout = {tP.layout}")
print(f" cute.size(tP.layout) = {cute.size(tP.layout)}")
print(f" cute.cosize(tP.layout) = {cute.cosize(tP.layout)}")
print(f" tOrP_sliced.layout = {tOrP_sliced.layout}")
print(f" cute.size(tOrP_sliced.layout) = {cute.size(tOrP_sliced.layout)}")
print(f" cute.cosize(tOrP_sliced.layout) = {cute.cosize(tOrP_sliced.layout)}")
p_tmem_cols = find_tmem_tensor_col_offset(tOrP_sliced)
print(f" find_tmem_tensor_col_offset(tOrP_sliced) = {p_tmem_cols}")
# Decompose 32800
print(f" 32800 in hex = 0x{32800:04x}")
print(f" 32800 - 0x8000 = {32800 - 0x8000}")
print(f" 32800 & 0x0000FFFF = {32800 & 0x0000FFFF}")
print(f" p_tmem_cols in hex = 0x{p_tmem_cols:04x}")
if isinstance(p_tmem_cols, int):
print(f" p_tmem_cols & 0x0000FFFF = {p_tmem_cols & 0x0000FFFF}")
print(f" p_tmem_cols >> 16 = {p_tmem_cols >> 16}")
# Staged fragments
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
print(f"=== Staged fragments ===")
print(f" find_tmem_tensor_col_offset(tCtS_fake) = {find_tmem_tensor_col_offset(tCtS_fake)}")
print(f" find_tmem_tensor_col_offset(tCtO_fake) = {find_tmem_tensor_col_offset(tCtO_fake)}")
print(f" get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake]) = {utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch='sm_100')}")
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
print("Compiling diagnostics...", flush=True)
compiled = cute.compile(diag_tmem, stream)
print("Done. Results above.", flush=True)

187
tests/stage_b_debug5.py Normal file
View File

@@ -0,0 +1,187 @@
"""Stage B debug v5: Minimal two-MMA with debug printf to find deadlock location."""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
import cuda.bindings.driver as cuda
class StageBDebug5:
def __init__(self, mma_tiler_mn):
self.acc_dtype = Float32; self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE; self.use_2cta_instrs = False
self.epilogue_warp_id = (0, 1, 2, 3); self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192; self.epilog_sync_bar_id = 1; self.num_c_stage = 2
def _setup(self, qk_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = tuple(self.qk_mma_tiler)
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, BFloat16)
self.num_ab_stage = 1; self.num_acc_stage = 1
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, BFloat16, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, BFloat16, 1)
self.c_smem_s = utils.sm100.make_smem_layout_epi(BFloat16, self.c_layout, self.epi_tile, 2)
# Use QK fragment for tmem allocation
acc_shape = qk_mma.partition_shape_C(self.mma_tiler_mn)
tCtAcc_fake = qk_mma.make_fragment_C(cute.append(acc_shape, 1))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake, arch="sm_100")
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (cute.size_in_bytes(BFloat16, a_smem) + cute.size_in_bytes(BFloat16, b_smem)) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a, b, c, stream):
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.a_major, self.b_major, self.acc_dtype, self.cta_group, self.mma_tiler_mn)
self._setup(qk_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[192,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
a_smem_s, b_smem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=160)
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=0, is_two_cta=False,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=BFloat16, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=BFloat16, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
sC = smem.allocate_tensor(element_type=BFloat16, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
# NO pv_mma.make_fragment_B
acc_shape = qk_mma.partition_shape_C(self.mma_tiler_mn)
tCtAcc_fake = qk_mma.make_fragment_C(cute.append(acc_shape, 1))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# TMA
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# MMA — identical to Stage A
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
tCtAcc = tCtAcc_base[(None,None,None,0)]
ab_c.reset(); peek = ab_c.try_wait()
acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, 1)
acc_pipe.producer_acquire(acc_st)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tCtAcc, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tCtAcc)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
acc_pipe.producer_commit(acc_st)
acc_st.advance()
acc_pipe.producer_tail(acc_st)
# Epilogue — identical to Stage A
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
cons = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, 1)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 128)
c_pipe = pipeline.PipelineTmaStore.create(num_stages=2, producer_group=c_grp)
cons = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtAcc_base, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), cons, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m,n,k = 128,128,512
a = torch.randn(m,k,1,dtype=torch.bfloat16,device='cuda')
b = torch.randn(n,k,1,dtype=torch.bfloat16,device='cuda')
c = torch.zeros(m,n,1,dtype=torch.bfloat16,device='cuda')
import cutlass.torch as ct
mA = ct.from_dlpack(a).mark_layout_dynamic(leading_dim=ct.get_leading_dim(a))
mB = ct.from_dlpack(b).mark_layout_dynamic(leading_dim=ct.get_leading_dim(b))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBDebug5((128,128))
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mA, mB, mC, stream)
print('Running...', flush=True)
compiled(mA, mB, mC, stream)
torch.cuda.synchronize()
print('No deadlock!')
if __name__ == '__main__':
test()

372
tests/test_stage_a_copy.py Normal file
View File

@@ -0,0 +1,372 @@
"""
Stage A: Bare Q@K^T via tcgen05.mma → TMEM → GMEM
Follows the CUTLASS dense_gemm_persistent.py pattern EXACTLY.
BF16 inputs, FP32 accumulator, TMA load/store, warp specialization.
Single tile (no persistent scheduler), cluster (1,1).
"""
import torch
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.cute.runtime import make_ptr
import cuda.bindings.driver as cuda
class StageAQKTKernel:
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
self.acc_dtype = Float32
self.use_2cta_instrs = use_2cta_instrs
self.mma_tiler_mn = mma_tiler_mn
self.mma_tiler = (*mma_tiler_mn, 1)
self.use_tma_store = use_tma_store
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4
self.tma_warp_id = 5
self.threads_per_cta = 32 * 6 # 192
self.epilog_sync_bar_id = 1
self.tmem_alloc_sync_bar_id = 2
self.tmem_dealloc_sync_bar_id = 3
def _create_tiled_mma(self):
return utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.a_major_mode, self.b_major_mode,
self.acc_dtype, self.cta_group, self.mma_tiler_mn,
)
def _setup_attributes(self):
tiled_mma = self._create_tiled_mma()
mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2])
mma_inst_tile_k = 4
self.mma_tiler = (self.mma_tiler[0], self.mma_tiler[1], mma_inst_shape_k * mma_inst_tile_k)
self.cta_tile_shape_mnk = (
self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
self.mma_tiler[1],
self.mma_tiler[2],
)
self.cluster_layout_vmnk = cute.tiled_divide(
cute.make_layout((1, 1, 1)), (tiled_mma.thr_id.shape,))
self.num_mcast_ctas_a = 1
self.num_mcast_ctas_b = 1
self.is_a_mcast = False
self.is_b_mcast = False
# Epilogue tile
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.c_dtype)
# Stage counts: 1 AB stage (single tile, no double-buffer), 1 acc stage, 2 C stages
self.num_ab_stage = 1
self.num_acc_stage = 1
self.num_c_stage = 2
# SMEM layouts
self.a_smem_layout_staged = utils.sm100.make_smem_layout_a(
tiled_mma, self.mma_tiler, self.a_dtype, self.num_ab_stage)
self.b_smem_layout_staged = utils.sm100.make_smem_layout_b(
tiled_mma, self.mma_tiler, self.b_dtype, self.num_ab_stage)
self.c_smem_layout_staged = utils.sm100.make_smem_layout_epi(
self.c_dtype, self.c_layout, self.epi_tile, self.num_c_stage)
# TMEM alloc cols
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake, arch="sm_100")
# TMA load bytes
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.a_dtype, a_smem_layout) +
cute.size_in_bytes(self.b_dtype, b_smem_layout)
) * cute.size(tiled_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor,
stream: cuda.CUstream):
self.a_dtype = a.element_type
self.b_dtype = b.element_type
self.c_dtype = c.element_type
self.a_major_mode = LayoutEnum.from_tensor(a).mma_major_mode()
self.b_major_mode = LayoutEnum.from_tensor(b).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
tiled_mma = self._create_tiled_mma()
self._setup_attributes()
# TMA load A
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id),
a, a_smem_layout, self.mma_tiler, tiled_mma,
self.cluster_layout_vmnk.shape,
)
# TMA load B
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id),
b, b_smem_layout, self.mma_tiler, tiled_mma,
self.cluster_layout_vmnk.shape,
)
# TMA store C
epi_smem_layout = cute.select(self.c_smem_layout_staged, mode=[0, 1])
tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem_layout, self.epi_tile)
self._kernel(
tiled_mma, tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b,
tma_atom_c, tma_tensor_c, self.cluster_layout_vmnk,
self.a_smem_layout_staged, self.b_smem_layout_staged,
self.c_smem_layout_staged, self.epi_tile,
).launch(grid=(1, 1, 1), block=[self.threads_per_cta, 1, 1], stream=stream)
@cute.kernel
def _kernel(self, tiled_mma, tma_atom_a, mA_mkl, tma_atom_b, mB_nkl,
tma_atom_c, mC_mnl, cluster_layout_vmnk,
a_smem_layout_staged, b_smem_layout_staged, c_smem_layout_staged, epi_tile):
warp_idx = cute.arch.warp_idx()
warp_idx = cute.arch.make_warp_uniform(warp_idx)
tidx, _, _ = cute.arch.thread_idx()
use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2
is_leader_cta = True # single CTA, always leader
# Prefetch TMA descriptors
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_atom_a)
cpasync.prefetch_descriptor(tma_atom_b)
cpasync.prefetch_descriptor(tma_atom_c)
# ── Shared storage ───────────────────────────────────
@cute.struct
class SharedStorage:
ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc_mbar: cutlass.Int64
tmem_holding_buf: cutlass.Int32
smem = utils.SmemAllocator()
storage = smem.allocate(SharedStorage)
# AB pipeline
ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create(
barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes,
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
).make_participants()
# ACC pipeline
acc_pipeline = pipeline.PipelineUmmaAsync.create(
barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta_instrs else 1)),
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
)
# TMEM allocator
tmem_alloc_barrier = pipeline.NamedBarrier(
barrier_id=self.tmem_alloc_sync_bar_id,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)),
)
tmem = utils.TmemAllocator(
storage.tmem_holding_buf.ptr,
barrier_for_retrieve=tmem_alloc_barrier,
allocator_warp_id=self.epilogue_warp_id[0],
is_two_cta=use_2cta_instrs,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
)
pipeline.pipeline_init_arrive(cluster_shape_mn=cluster_layout_vmnk, is_relaxed=True)
# SMEM tensors
sA = smem.allocate_tensor(
element_type=self.a_dtype, layout=a_smem_layout_staged.outer,
byte_alignment=128, swizzle=a_smem_layout_staged.inner)
sB = smem.allocate_tensor(
element_type=self.b_dtype, layout=b_smem_layout_staged.outer,
byte_alignment=128, swizzle=b_smem_layout_staged.inner)
sC = smem.allocate_tensor(
element_type=self.c_dtype, layout=c_smem_layout_staged.outer,
byte_alignment=128, swizzle=c_smem_layout_staged.inner)
# Partition global tensors
gA_mkl = cute.local_tile(mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None))
gB_nkl = cute.local_tile(mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None))
gC_mnl = cute.local_tile(mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None))
k_tile_cnt = cute.size(gA_mkl, mode=[3])
# Partition for TiledMMA
thr_mma = tiled_mma.get_slice(0) # leader CTA
tCgA = thr_mma.partition_A(gA_mkl)
tCgB = thr_mma.partition_B(gB_nkl)
tCgC = thr_mma.partition_C(gC_mnl)
# TMA partition A/B
a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape)
tAsA, tAgA = cpasync.tma_partition(
tma_atom_a, 0, a_cta_layout,
cute.group_modes(sA, 0, 3), cute.group_modes(tCgA, 0, 3))
b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape)
tBsB, tBgB = cpasync.tma_partition(
tma_atom_b, 0, b_cta_layout,
cute.group_modes(sB, 0, 3), cute.group_modes(tCgB, 0, 3))
# Slice to tile coord (0, 0, 0)
tAgA_slice = tAgA[(None, 0, None, 0)]
tBgB_slice = tBgB[(None, 0, None, 0)]
# MMA fragments
tCrA = tiled_mma.make_fragment_A(sA)
tCrB = tiled_mma.make_fragment_B(sB)
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cluster_layout_vmnk)
# ══════════════════════════════════════════════════════════
# TMA LOAD WARP (warp 5)
# ══════════════════════════════════════════════════════════
if warp_idx == self.tma_warp_id:
ab_producer.reset()
peek_ab_empty_status = ab_producer.try_acquire()
for k_tile in cutlass.range(k_tile_cnt, unroll=1):
handle = ab_producer.acquire_and_advance(peek_ab_empty_status)
cute.copy(tma_atom_a, tAgA_slice[(None, handle.count)], tAsA[(None, handle.index)],
tma_bar_ptr=handle.barrier)
cute.copy(tma_atom_b, tBgB_slice[(None, handle.count)], tBsB[(None, handle.index)],
tma_bar_ptr=handle.barrier)
peek_ab_empty_status = cutlass.Boolean(1)
if handle.count + 1 < k_tile_cnt:
peek_ab_empty_status = ab_producer.try_acquire()
ab_producer.tail()
# ══════════════════════════════════════════════════════════
# MMA WARP (warp 4)
# ══════════════════════════════════════════════════════════
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
tCtAcc = tCtAcc_base[(None, None, None, 0)]
ab_consumer.reset()
peek_ab_full_status = cutlass.Boolean(1)
if is_leader_cta:
peek_ab_full_status = ab_consumer.try_wait()
acc_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_acc_stage)
if is_leader_cta:
acc_pipeline.producer_acquire(acc_producer_state)
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
for k_tile in range(k_tile_cnt):
if is_leader_cta:
handle = ab_consumer.wait_and_advance(peek_ab_full_status)
num_kblocks = cute.size(tCrA, mode=[2])
for kblk_idx in cutlass.range(num_kblocks, unroll_full=True):
kblk_crd = (None, None, kblk_idx, handle.index)
cute.gemm(tiled_mma, tCtAcc, tCrA[kblk_crd], tCrB[kblk_crd], tCtAcc)
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
handle.release()
peek_ab_full_status = cutlass.Boolean(1)
if handle.count + 1 < k_tile_cnt:
peek_ab_full_status = ab_consumer.try_wait()
if is_leader_cta:
acc_pipeline.producer_commit(acc_producer_state)
acc_producer_state.advance()
acc_pipeline.producer_tail(acc_producer_state)
# ══════════════════════════════════════════════════════════
# EPILOGUE WARPS (0..3)
# ══════════════════════════════════════════════════════════
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
acc_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipeline = pipeline.PipelineTmaStore.create(
num_stages=self.num_c_stage, producer_group=c_producer_group)
# Use the reference epilogue implementation
mma_tile_coord_mnl = (0, 0, 0)
epilogue_op = const_expr(lambda x: x)
num_tiles_executed = 0
acc_consumer_state = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_atom_c, tCtAcc_base, sC, tCgC,
epi_tile, num_tiles_executed, epilogue_op,
mma_tile_coord_mnl, acc_consumer_state, acc_pipeline, c_pipeline)
c_pipeline.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test_stage_a():
"""Test Stage A: Q @ K^T → TMEM → GMEM"""
device = torch.device("cuda")
torch.manual_seed(42)
m, n, k = 128, 128, 512
# Tensors must be 3D (M, K, L) for the CUTLASS pattern
a = torch.randn(m, k, 1, dtype=torch.bfloat16, device="cuda")
b = torch.randn(n, k, 1, dtype=torch.bfloat16, device="cuda")
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device="cuda")
ref = a[:, :, 0].float() @ b[:, :, 0].float().T
# Create cute tensors
import cutlass.torch as cutlass_torch
mA = cutlass_torch.from_dlpack(a).mark_layout_dynamic(
leading_dim=cutlass_torch.get_leading_dim(a))
mB = cutlass_torch.from_dlpack(b).mark_layout_dynamic(
leading_dim=cutlass_torch.get_leading_dim(b))
mC = cutlass_torch.from_dlpack(c).mark_layout_dynamic(
leading_dim=cutlass_torch.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageAQKTKernel(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
compiled = cute.compile(kernel, mA, mB, mC, stream)
# Run with the same tensors
compiled(mA, mB, mC, stream)
torch.cuda.synchronize()
output = c[:, :, 0].float()
cos = torch.nn.functional.cosine_similarity(
output.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
max_err = (output - ref).abs().max().item()
print("Stage A: Q({},{}) @ K^T({}, {}) -> S({}, {})".format(m, k, k, n, m, n))
print(" Cosine: {:.6f}, Max error: {:.6f}".format(cos, max_err))
print(" {}".format("PASS" if cos >= 0.99 else "FAIL"))
return cos
if __name__ == "__main__":
test_stage_a()

View File

@@ -0,0 +1,395 @@
"""
Stage A Minimal: Just tcgen05.mma with TMEM accumulator, epilogue writes to GMEM.
No TMA loads (use regular SMEM loads). No pipeline. Just validate the MMA path.
"""
import torch
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import tcgen05
from cutlass import BFloat16, Float32
import cuda.bindings.driver as cuda
# Q: (128, 512) BF16, K^T: (64, 512) BF16 -> S: (128, 64) FP32
# These are the minimum valid tile sizes for tcgen05.mma
M = 128
N = 64
K = 512
class MinimalMMAKernel:
def __init__(self):
self.cta_group = tcgen05.CtaGroup.ONE
self.mma_tiler_mn = (M, N)
self.threads_per_cta = 192 # 6 warps
self.epilog_warp_ids = [0, 1, 2, 3]
self.mma_warp_id = 4
self.tma_warp_id = 5
@cute.jit
def __call__(self, a_ptr, b_ptr, c_ptr, problem_m, problem_n, problem_k, stream):
a_dtype = a_ptr.value_type
b_dtype = b_ptr.value_type
c_dtype = c_ptr.value_type
acc_dtype = Float32
m, n, k = problem_m, problem_n, problem_k
# TiledMMA
tiled_mma = sm100_utils.make_trivial_tiled_mma(
a_dtype, b_dtype,
tcgen05.OperandMajorMode.K,
tcgen05.OperandMajorMode.K,
acc_dtype,
self.cta_group,
self.mma_tiler_mn,
)
atom_thr_size = cute.size(tiled_mma.thr_id.shape)
self.atom_thr_size = atom_thr_size
mma_tiler = (self.mma_tiler_mn[0], self.mma_tiler_mn[1], cute.size(tiled_mma.shape_mnk, mode=[2]) * 4)
self.mma_tiler = mma_tiler
cta_tile_shape_mnk = (mma_tiler[0] // atom_thr_size, mma_tiler[1], mma_tiler[2])
self.cta_tile_shape_mnk = cta_tile_shape_mnk
# SMEM layouts
num_ab_stages = 1
a_smem_layout = sm100_utils.make_smem_layout_a(tiled_mma, mma_tiler, a_dtype, num_ab_stages)
b_smem_layout = sm100_utils.make_smem_layout_b(tiled_mma, mma_tiler, b_dtype, num_ab_stages)
# Epilogue tile
c_layout_enum = utils.LayoutEnum.ROW_MAJOR
epi_tile = sm100_utils.compute_epilogue_tile_shape(
cta_tile_shape_mnk, False, c_layout_enum, c_dtype)
self.epi_tile = epi_tile
c_smem_layout = sm100_utils.make_smem_layout_epi(c_dtype, c_layout_enum, epi_tile, 2)
self.c_smem_layout = c_smem_layout
# TMEM columns
self.num_accumulator_tmem_cols = cta_tile_shape_mnk[1]
# GMEM tensors
a_gmem_layout = cute.make_ordered_layout((m, k), order=(1, 0))
b_gmem_layout = cute.make_ordered_layout((n, k), order=(1, 0))
c_gmem_layout = cute.make_ordered_layout((m, n), order=(1, 0))
gA = cute.make_tensor(a_ptr, a_gmem_layout)
gB = cute.make_tensor(b_ptr, b_gmem_layout)
gC = cute.make_tensor(c_ptr, c_gmem_layout)
# TMA descriptors
a_smem_layout_one = cute.slice_(a_smem_layout, (None, None, None, 0))
b_smem_layout_one = cute.slice_(b_smem_layout, (None, None, None, 0))
c_smem_layout_one = cute.slice_(c_smem_layout, (None, None, 0))
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
sm100_utils.cluster_shape_to_tma_atom_A((1, 1), tiled_mma.thr_id),
gA, a_smem_layout_one, mma_tiler, tiled_mma,
(1, 1, 1),
)
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
sm100_utils.cluster_shape_to_tma_atom_B((1, 1), tiled_mma.thr_id),
gB, b_smem_layout_one, mma_tiler, tiled_mma,
(1, 1, 1),
)
tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
cpasync.CopyBulkTensorTileS2GOp(),
gC, c_smem_layout_one, epi_tile,
)
# Pipeline barriers
a_copy_size = cute.size_in_bytes(a_dtype, a_smem_layout_one)
b_copy_size = cute.size_in_bytes(b_dtype, b_smem_layout_one)
tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size
self.tma_load_bytes = tma_load_bytes
# Named barriers
self.epilog_sync_barrier = pipeline.NamedBarrier(
barrier_id=1,
num_threads=32 * len(self.epilog_warp_ids),
)
self.tmem_alloc_barrier = pipeline.NamedBarrier(
barrier_id=2,
num_threads=32 * (1 + len(self.epilog_warp_ids)),
)
@cute.struct
class SharedStorage:
ab_full_mbar: cute.struct.MemRange[cutlass.Int64, num_ab_stages]
ab_empty_mbar: cute.struct.MemRange[cutlass.Int64, num_ab_stages]
acc_full_mbar: cute.struct.MemRange[cutlass.Int64, 1]
acc_empty_mbar: cute.struct.MemRange[cutlass.Int64, 1]
tmem_dealloc_mbar: cutlass.Int64
tmem_holding_buf: cutlass.Int32
sA: cute.struct.Align[cute.struct.MemRange[a_dtype, cute.cosize(a_smem_layout.outer)], 1024]
sB: cute.struct.Align[cute.struct.MemRange[b_dtype, cute.cosize(b_smem_layout.outer)], 1024]
sC: cute.struct.Align[cute.struct.MemRange[c_dtype, cute.cosize(c_smem_layout.outer)], 1024]
self.shared_storage = SharedStorage
# Cluster
self.cluster_shape_mn = (1, 1)
cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1, 1, 1)), (tiled_mma.thr_id.shape,))
self.cluster_layout_vmnk = cluster_layout_vmnk
self._kernel(
tiled_mma, tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b,
tma_atom_c, tma_tensor_c, cluster_layout_vmnk,
a_smem_layout, b_smem_layout, c_smem_layout, epi_tile,
gA, gB, gC, mma_tiler,
).launch(grid=(1, 1, 1), block=[self.threads_per_cta, 1, 1], stream=stream)
@cute.kernel
def _kernel(self, tiled_mma, tma_atom_a, mA, tma_atom_b, mB,
tma_atom_c, mC, cluster_layout_vmnk,
a_smem_layout, b_smem_layout, c_smem_layout, epi_tile,
gA, gB, gC, mma_tiler):
warp_idx = cute.arch.warp_idx()
warp_idx = cute.arch.make_warp_uniform(warp_idx)
tidx, _, _ = cute.arch.thread_idx()
use_2cta = cute.size(tiled_mma.thr_id.shape) == 2
smem = utils.SmemAllocator()
storage = smem.allocate(self.shared_storage)
# AB pipeline
ab_pipeline = pipeline.PipelineTmaUmma.create(
barrier_storage=storage.ab_full_mbar.data_ptr(),
num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.tma_load_bytes,
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
)
# Accumulator pipeline
acc_pipeline = pipeline.PipelineUmmaAsync.create(
barrier_storage=storage.acc_full_mbar.data_ptr(),
num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread,
32 * len(self.epilog_warp_ids) * (2 if use_2cta else 1)),
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
)
# TMEM allocator
tmem = utils.TmemAllocator(
storage.tmem_holding_buf.ptr,
barrier_for_retrieve=self.tmem_alloc_barrier,
allocator_warp_id=self.epilog_warp_ids[0],
is_two_cta=use_2cta,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
)
pipeline.pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True)
# SMEM tensors
sA = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner)
sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner)
sC = storage.sC.get_tensor(c_smem_layout.outer, swizzle=c_smem_layout.inner)
# gC tiled for epilogue partition
gC_tiled = cute.local_tile(gC, (mma_tiler[0], mma_tiler[1]), (0, 0))
# Partition for TMA — use TMA tensors directly (they have static shape)
thr_mma = tiled_mma.get_slice(0)
tCgC = thr_mma.partition_C(gC_tiled)
tAsA, tAgA = cpasync.tma_partition(
tma_atom_a, 0, cute.make_layout(1),
cute.group_modes(sA, 0, 3),
cute.group_modes(mA, 0, 3),
)
tBsB, tBgB = cpasync.tma_partition(
tma_atom_b, 0, cute.make_layout(1),
cute.group_modes(sB, 0, 3),
cute.group_modes(mB, 0, 3),
)
# MMA fragments
tCrA = tiled_mma.make_fragment_A(sA)
tCrB = tiled_mma.make_fragment_B(sB)
# TMEM accumulator
acc_shape = tiled_mma.partition_shape_C(mma_tiler[:2])
tCtAcc_all = tiled_mma.make_fragment_C(cute.append(acc_shape, 1))
k_tile_cnt = cute.size(mA, mode=[3])
# gC tiled for epilogue partition
gC_tiled = cute.local_tile(gC, (mma_tiler[0], mma_tiler[1]), (0, 0))
pipeline.pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn)
# ══════════════════════════════════════════════════════════
# TMA LOAD WARP (warp 5)
# ══════════════════════════════════════════════════════════
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_atom_a)
cpasync.prefetch_descriptor(tma_atom_b)
cpasync.prefetch_descriptor(tma_atom_c)
ab_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, 1)
for k_tile in cutlass.range(k_tile_cnt, unroll=1):
ab_pipeline.producer_acquire(ab_state)
cute.copy(tma_atom_a, tAgA[(None, ab_state.count)], tAsA[(None, ab_state.index)],
tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_state))
cute.copy(tma_atom_b, tBgB[(None, ab_state.count)], tBsB[(None, ab_state.index)],
tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_state))
ab_state.advance()
ab_pipeline.producer_tail(ab_state)
# ══════════════════════════════════════════════════════════
# MMA WARP (warp 4)
# ══════════════════════════════════════════════════════════
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
acc_ptr = tmem.retrieve_ptr(Float32)
tCtAcc_base = cute.make_tensor(acc_ptr, tCtAcc_all.layout)
tCtAcc = tCtAcc_base[(None, None, None, 0)]
ab_cstate = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, 1)
acc_pstate = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, 1)
acc_pipeline.producer_acquire(acc_pstate)
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
for k_tile in range(k_tile_cnt):
ab_pipeline.consumer_wait(ab_cstate, cutlass.Boolean(1))
for kblock in cutlass.range(cute.size(tCrA, mode=[2]), unroll_full=True):
coord = (None, None, kblock, ab_cstate.index)
cute.gemm(tiled_mma, tCtAcc, tCrA[coord], tCrB[coord], tCtAcc)
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
ab_pipeline.consumer_release(ab_cstate)
ab_cstate.advance()
acc_pipeline.producer_commit(acc_pstate)
acc_pstate.advance()
acc_pipeline.producer_tail(acc_pstate)
# ══════════════════════════════════════════════════════════
# EPILOGUE WARPS (0..3)
# ══════════════════════════════════════════════════════════
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_accumulator_tmem_cols)
tmem.wait_for_alloc()
acc_ptr = tmem.retrieve_ptr(Float32)
tCtAcc_base = cute.make_tensor(acc_ptr, tCtAcc_all.layout)
c_layout_enum = utils.LayoutEnum.ROW_MAJOR
c_dtype = BFloat16
# TMEM→reg
copy_atom_t2r = sm100_utils.get_tmem_load_op(
self.cta_tile_shape_mnk, c_layout_enum, c_dtype, Float32, epi_tile, False)
tAcc_epi = cute.flat_divide(tCtAcc_base[((None, None), 0, 0, None)], epi_tile)
tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)])
thr_t2r = tiled_copy_t2r.get_slice(tidx)
tTR_tAcc = thr_t2r.partition_S(tAcc_epi)
tTR_rAcc = cute.make_rmem_tensor(
thr_t2r.partition_D(
cute.flat_divide(tCgC[((None, None), 0, 0, None, None)], epi_tile)
)[(None, None, None, 0, 0, 0, 0)].shape, Float32)
tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, c_dtype)
# reg→SMEM
copy_atom_r2s = sm100_utils.get_smem_store_op(c_layout_enum, c_dtype, Float32, tiled_copy_t2r)
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
thr_r2s = tiled_copy_r2s.get_slice(tidx)
tRS_sC = thr_r2s.partition_D(sC)
tRS_rC = tiled_copy_r2s.retile(tTR_rC)
# SMEM→GMEM (TMA)
gC_epi = cute.flat_divide(tCgC[((None, None), 0, 0, None, None)], epi_tile)
bSG_sC, bSG_gC = cpasync.tma_partition(
tma_atom_c, 0, cute.make_layout(1),
cute.group_modes(sC, 0, 2),
cute.group_modes(gC_epi, 0, 2))
acc_cstate = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, 1)
c_pipeline = pipeline.PipelineTmaStore.create(
num_stages=2,
producer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, 32 * len(self.epilog_warp_ids)))
acc_pipeline.consumer_wait(acc_cstate)
tTR_tAcc_g = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
bSG_gC_g = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC))
for subtile in cutlass.range(cute.size(tTR_tAcc_g.shape, mode=[3])):
cute.copy(tiled_copy_t2r, tTR_tAcc_g[(None, None, None, subtile)], tTR_rAcc)
acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load()
tRS_rC.store(acc_vec.to(c_dtype))
c_buf = subtile % 2
cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buf)])
cute.arch.fence_proxy("async.shared", space="cta")
self.epilog_sync_barrier.arrive_and_wait()
if warp_idx == self.epilog_warp_ids[0]:
cute.copy(tma_atom_c, bSG_sC[(None, c_buf)], bSG_gC_g[(None, subtile)])
c_pipeline.producer_commit()
c_pipeline.producer_acquire()
self.epilog_sync_barrier.arrive_and_wait()
acc_pipeline.consumer_release(acc_cstate)
tmem.relinquish_alloc_permit()
self.epilog_sync_barrier.arrive_and_wait()
tmem.free(acc_ptr)
c_pipeline.producer_tail()
from cutlass.cute.runtime import make_ptr
from cutlass.cute.nvgpu import cpasync
import cutlass.utils.blackwell_helpers as sm100_utils
def test_minimal_mma():
device = torch.device("cuda")
torch.manual_seed(42)
prob_m, prob_n, prob_k = M, N, K
tA = torch.randn(prob_m, prob_k, dtype=torch.bfloat16, device=device)
tB = torch.randn(prob_n, prob_k, dtype=torch.bfloat16, device=device)
ref = torch.matmul(tA.to(torch.float32), tB.to(torch.float32).T)
tC = torch.zeros(prob_m, prob_n, dtype=torch.bfloat16, device=device)
a_ptr = make_ptr(BFloat16, 0, cute.AddressSpace.gmem, assumed_align=16)
b_ptr = make_ptr(BFloat16, 0, cute.AddressSpace.gmem, assumed_align=16)
c_ptr = make_ptr(BFloat16, 0, cute.AddressSpace.gmem, assumed_align=16)
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = MinimalMMAKernel()
compiled = cute.compile(kernel, a_ptr, b_ptr, c_ptr,
cutlass.Int32(prob_m), cutlass.Int32(prob_n), cutlass.Int32(prob_k), stream)
a_ptr_r = make_ptr(BFloat16, tA.data_ptr(), cute.AddressSpace.gmem, assumed_align=16)
b_ptr_r = make_ptr(BFloat16, tB.data_ptr(), cute.AddressSpace.gmem, assumed_align=16)
c_ptr_r = make_ptr(BFloat16, tC.data_ptr(), cute.AddressSpace.gmem, assumed_align=16)
compiled(a_ptr_r, b_ptr_r, c_ptr_r, prob_m, prob_n, prob_k, stream)
torch.cuda.synchronize()
output = tC.to(torch.float32)
cos = torch.nn.functional.cosine_similarity(
output.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
max_err = (output - ref).abs().max().item()
print(f"Minimal MMA: A({prob_m},{prob_k}) @ B^T({prob_k},{prob_n}) → C({prob_m},{prob_n})")
print(f" Cosine: {cos:.6f}, Max error: {max_err:.6f}")
print(f" {'✅ PASS' if cos >= 0.99 else '❌ FAIL'}")
return cos
if __name__ == "__main__":
test_minimal_mma()

View File

@@ -0,0 +1,376 @@
"""
Stage A: Bare Q@K^T via tcgen05.mma → TMEM → GMEM
Follows the CUTLASS dense_gemm_persistent.py pattern EXACTLY.
BF16 inputs, FP32 accumulator, TMA load/store, warp specialization.
Single tile (no persistent scheduler), cluster (1,1).
"""
import torch
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.cute.runtime import make_ptr
import cuda.bindings.driver as cuda
class StageAQKTKernel:
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
self.acc_dtype = Float32
self.use_2cta_instrs = use_2cta_instrs
self.mma_tiler_mn = mma_tiler_mn
self.mma_tiler = (*mma_tiler_mn, 1)
self.use_tma_store = use_tma_store
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4
self.tma_warp_id = 5
self.threads_per_cta = 32 * 6 # 192
self.epilog_sync_bar_id = 1
self.tmem_alloc_sync_bar_id = 2
self.tmem_dealloc_sync_bar_id = 3
def _create_tiled_mma(self):
return utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.a_major_mode, self.b_major_mode,
self.acc_dtype, self.cta_group, self.mma_tiler_mn,
)
def _setup_attributes(self):
# Create pv_mma but DO NOT use it
pv_mma = utils.sm100.make_trivial_tiled_mma(self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major_mode, self.acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
tiled_mma = self._create_tiled_mma()
mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2])
mma_inst_tile_k = 4
self.mma_tiler = (self.mma_tiler[0], self.mma_tiler[1], mma_inst_shape_k * mma_inst_tile_k)
self.cta_tile_shape_mnk = (
self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
self.mma_tiler[1],
self.mma_tiler[2],
)
self.cluster_layout_vmnk = cute.tiled_divide(
cute.make_layout((1, 1, 1)), (tiled_mma.thr_id.shape,))
self.num_mcast_ctas_a = 1
self.num_mcast_ctas_b = 1
self.is_a_mcast = False
self.is_b_mcast = False
# Epilogue tile
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.c_dtype)
# Stage counts: 1 AB stage (single tile, no double-buffer), 1 acc stage, 2 C stages
self.num_ab_stage = 1
self.num_acc_stage = 1
self.num_c_stage = 2
# SMEM layouts
self.a_smem_layout_staged = utils.sm100.make_smem_layout_a(
tiled_mma, self.mma_tiler, self.a_dtype, self.num_ab_stage)
self.b_smem_layout_staged = utils.sm100.make_smem_layout_b(
tiled_mma, self.mma_tiler, self.b_dtype, self.num_ab_stage)
self.c_smem_layout_staged = utils.sm100.make_smem_layout_epi(
self.c_dtype, self.c_layout, self.epi_tile, self.num_c_stage)
# TMEM alloc cols
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake, arch="sm_100")
# TMA load bytes
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.a_dtype, a_smem_layout) +
cute.size_in_bytes(self.b_dtype, b_smem_layout)
) * cute.size(tiled_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor,
stream: cuda.CUstream):
self.a_dtype = a.element_type
self.b_dtype = b.element_type
self.c_dtype = c.element_type
self.a_major_mode = LayoutEnum.from_tensor(a).mma_major_mode()
self.b_major_mode = LayoutEnum.from_tensor(b).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
# Create pv_mma but DO NOT use it
pv_mma = utils.sm100.make_trivial_tiled_mma(self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major_mode, self.acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
tiled_mma = self._create_tiled_mma()
self._setup_attributes()
# TMA load A
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id),
a, a_smem_layout, self.mma_tiler, tiled_mma,
self.cluster_layout_vmnk.shape,
)
# TMA load B
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id),
b, b_smem_layout, self.mma_tiler, tiled_mma,
self.cluster_layout_vmnk.shape,
)
# TMA store C
epi_smem_layout = cute.select(self.c_smem_layout_staged, mode=[0, 1])
tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem_layout, self.epi_tile)
self._kernel(
tiled_mma, tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b,
tma_atom_c, tma_tensor_c, self.cluster_layout_vmnk,
self.a_smem_layout_staged, self.b_smem_layout_staged,
self.c_smem_layout_staged, self.epi_tile,
).launch(grid=(1, 1, 1), block=[self.threads_per_cta, 1, 1], stream=stream)
@cute.kernel
def _kernel(self, tiled_mma, tma_atom_a, mA_mkl, tma_atom_b, mB_nkl,
tma_atom_c, mC_mnl, cluster_layout_vmnk,
a_smem_layout_staged, b_smem_layout_staged, c_smem_layout_staged, epi_tile):
warp_idx = cute.arch.warp_idx()
warp_idx = cute.arch.make_warp_uniform(warp_idx)
tidx, _, _ = cute.arch.thread_idx()
use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2
is_leader_cta = True # single CTA, always leader
# Prefetch TMA descriptors
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_atom_a)
cpasync.prefetch_descriptor(tma_atom_b)
cpasync.prefetch_descriptor(tma_atom_c)
# ── Shared storage ───────────────────────────────────
@cute.struct
class SharedStorage:
ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc_mbar: cutlass.Int64
tmem_holding_buf: cutlass.Int32
smem = utils.SmemAllocator()
storage = smem.allocate(SharedStorage)
# AB pipeline
ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create(
barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes,
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
).make_participants()
# ACC pipeline
acc_pipeline = pipeline.PipelineUmmaAsync.create(
barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta_instrs else 1)),
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
)
# TMEM allocator
tmem_alloc_barrier = pipeline.NamedBarrier(
barrier_id=self.tmem_alloc_sync_bar_id,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)),
)
tmem = utils.TmemAllocator(
storage.tmem_holding_buf.ptr,
barrier_for_retrieve=tmem_alloc_barrier,
allocator_warp_id=self.epilogue_warp_id[0],
is_two_cta=use_2cta_instrs,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
)
pipeline.pipeline_init_arrive(cluster_shape_mn=cluster_layout_vmnk, is_relaxed=True)
# SMEM tensors
sA = smem.allocate_tensor(
element_type=self.a_dtype, layout=a_smem_layout_staged.outer,
byte_alignment=128, swizzle=a_smem_layout_staged.inner)
sB = smem.allocate_tensor(
element_type=self.b_dtype, layout=b_smem_layout_staged.outer,
byte_alignment=128, swizzle=b_smem_layout_staged.inner)
sC = smem.allocate_tensor(
element_type=self.c_dtype, layout=c_smem_layout_staged.outer,
byte_alignment=128, swizzle=c_smem_layout_staged.inner)
# Partition global tensors
gA_mkl = cute.local_tile(mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None))
gB_nkl = cute.local_tile(mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None))
gC_mnl = cute.local_tile(mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None))
k_tile_cnt = cute.size(gA_mkl, mode=[3])
# Partition for TiledMMA
thr_mma = tiled_mma.get_slice(0) # leader CTA
tCgA = thr_mma.partition_A(gA_mkl)
tCgB = thr_mma.partition_B(gB_nkl)
tCgC = thr_mma.partition_C(gC_mnl)
# TMA partition A/B
a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape)
tAsA, tAgA = cpasync.tma_partition(
tma_atom_a, 0, a_cta_layout,
cute.group_modes(sA, 0, 3), cute.group_modes(tCgA, 0, 3))
b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape)
tBsB, tBgB = cpasync.tma_partition(
tma_atom_b, 0, b_cta_layout,
cute.group_modes(sB, 0, 3), cute.group_modes(tCgB, 0, 3))
# Slice to tile coord (0, 0, 0)
tAgA_slice = tAgA[(None, 0, None, 0)]
tBgB_slice = tBgB[(None, 0, None, 0)]
# MMA fragments
tCrA = tiled_mma.make_fragment_A(sA)
tCrB = tiled_mma.make_fragment_B(sB)
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cluster_layout_vmnk)
# ══════════════════════════════════════════════════════════
# TMA LOAD WARP (warp 5)
# ══════════════════════════════════════════════════════════
if warp_idx == self.tma_warp_id:
ab_producer.reset()
peek_ab_empty_status = ab_producer.try_acquire()
for k_tile in cutlass.range(k_tile_cnt, unroll=1):
handle = ab_producer.acquire_and_advance(peek_ab_empty_status)
cute.copy(tma_atom_a, tAgA_slice[(None, handle.count)], tAsA[(None, handle.index)],
tma_bar_ptr=handle.barrier)
cute.copy(tma_atom_b, tBgB_slice[(None, handle.count)], tBsB[(None, handle.index)],
tma_bar_ptr=handle.barrier)
peek_ab_empty_status = cutlass.Boolean(1)
if handle.count + 1 < k_tile_cnt:
peek_ab_empty_status = ab_producer.try_acquire()
ab_producer.tail()
# ══════════════════════════════════════════════════════════
# MMA WARP (warp 4)
# ══════════════════════════════════════════════════════════
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
tCtAcc = tCtAcc_base[(None, None, None, 0)]
ab_consumer.reset()
peek_ab_full_status = cutlass.Boolean(1)
if is_leader_cta:
peek_ab_full_status = ab_consumer.try_wait()
acc_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_acc_stage)
if is_leader_cta:
acc_pipeline.producer_acquire(acc_producer_state)
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
for k_tile in range(k_tile_cnt):
if is_leader_cta:
handle = ab_consumer.wait_and_advance(peek_ab_full_status)
num_kblocks = cute.size(tCrA, mode=[2])
for kblk_idx in cutlass.range(num_kblocks, unroll_full=True):
kblk_crd = (None, None, kblk_idx, handle.index)
cute.gemm(tiled_mma, tCtAcc, tCrA[kblk_crd], tCrB[kblk_crd], tCtAcc)
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
handle.release()
peek_ab_full_status = cutlass.Boolean(1)
if handle.count + 1 < k_tile_cnt:
peek_ab_full_status = ab_consumer.try_wait()
if is_leader_cta:
acc_pipeline.producer_commit(acc_producer_state)
acc_producer_state.advance()
acc_pipeline.producer_tail(acc_producer_state)
# ══════════════════════════════════════════════════════════
# EPILOGUE WARPS (0..3)
# ══════════════════════════════════════════════════════════
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
acc_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipeline = pipeline.PipelineTmaStore.create(
num_stages=self.num_c_stage, producer_group=c_producer_group)
# Use the reference epilogue implementation
mma_tile_coord_mnl = (0, 0, 0)
epilogue_op = const_expr(lambda x: x)
num_tiles_executed = 0
acc_consumer_state = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_atom_c, tCtAcc_base, sC, tCgC,
epi_tile, num_tiles_executed, epilogue_op,
mma_tile_coord_mnl, acc_consumer_state, acc_pipeline, c_pipeline)
c_pipeline.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test_stage_a():
"""Test Stage A: Q @ K^T → TMEM → GMEM"""
device = torch.device("cuda")
torch.manual_seed(42)
m, n, k = 128, 128, 512
# Tensors must be 3D (M, K, L) for the CUTLASS pattern
a = torch.randn(m, k, 1, dtype=torch.bfloat16, device="cuda")
b = torch.randn(n, k, 1, dtype=torch.bfloat16, device="cuda")
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device="cuda")
ref = a[:, :, 0].float() @ b[:, :, 0].float().T
# Create cute tensors
import cutlass.torch as cutlass_torch
mA = cutlass_torch.from_dlpack(a).mark_layout_dynamic(
leading_dim=cutlass_torch.get_leading_dim(a))
mB = cutlass_torch.from_dlpack(b).mark_layout_dynamic(
leading_dim=cutlass_torch.get_leading_dim(b))
mC = cutlass_torch.from_dlpack(c).mark_layout_dynamic(
leading_dim=cutlass_torch.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageAQKTKernel(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
compiled = cute.compile(kernel, mA, mB, mC, stream)
# Run with the same tensors
compiled(mA, mB, mC, stream)
torch.cuda.synchronize()
output = c[:, :, 0].float()
cos = torch.nn.functional.cosine_similarity(
output.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
max_err = (output - ref).abs().max().item()
print("Stage A: Q({},{}) @ K^T({}, {}) -> S({}, {})".format(m, k, k, n, m, n))
print(" Cosine: {:.6f}, Max error: {:.6f}".format(cos, max_err))
print(" {}".format("PASS" if cos >= 0.99 else "FAIL"))
return cos
if __name__ == "__main__":
test_stage_a()

View File

@@ -0,0 +1,374 @@
"""
Stage A: Bare Q@K^T via tcgen05.mma → TMEM → GMEM
Follows the CUTLASS dense_gemm_persistent.py pattern EXACTLY.
BF16 inputs, FP32 accumulator, TMA load/store, warp specialization.
Single tile (no persistent scheduler), cluster (1,1).
"""
import torch
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.cute.runtime import make_ptr
import cuda.bindings.driver as cuda
class StageAWithPVParam:
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
self.acc_dtype = Float32
self.use_2cta_instrs = use_2cta_instrs
self.mma_tiler_mn = mma_tiler_mn
self.mma_tiler = (*mma_tiler_mn, 1)
self.use_tma_store = use_tma_store
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4
self.tma_warp_id = 5
self.threads_per_cta = 32 * 6 # 192
self.epilog_sync_bar_id = 1
self.tmem_alloc_sync_bar_id = 2
self.tmem_dealloc_sync_bar_id = 3
def _create_tiled_mma(self):
return utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.a_major_mode, self.b_major_mode,
self.acc_dtype, self.cta_group, self.mma_tiler_mn,
)
def _setup_attributes(self):
tiled_mma = self._create_tiled_mma()
pv_mma = utils.sm100.make_trivial_tiled_mma(self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major_mode, self.acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2])
mma_inst_tile_k = 4
self.mma_tiler = (self.mma_tiler[0], self.mma_tiler[1], mma_inst_shape_k * mma_inst_tile_k)
self.cta_tile_shape_mnk = (
self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
self.mma_tiler[1],
self.mma_tiler[2],
)
self.cluster_layout_vmnk = cute.tiled_divide(
cute.make_layout((1, 1, 1)), (tiled_mma.thr_id.shape,))
self.num_mcast_ctas_a = 1
self.num_mcast_ctas_b = 1
self.is_a_mcast = False
self.is_b_mcast = False
# Epilogue tile
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.c_dtype)
# Stage counts: 1 AB stage (single tile, no double-buffer), 1 acc stage, 2 C stages
self.num_ab_stage = 1
self.num_acc_stage = 1
self.num_c_stage = 2
# SMEM layouts
self.a_smem_layout_staged = utils.sm100.make_smem_layout_a(
tiled_mma, self.mma_tiler, self.a_dtype, self.num_ab_stage)
self.b_smem_layout_staged = utils.sm100.make_smem_layout_b(
tiled_mma, self.mma_tiler, self.b_dtype, self.num_ab_stage)
self.c_smem_layout_staged = utils.sm100.make_smem_layout_epi(
self.c_dtype, self.c_layout, self.epi_tile, self.num_c_stage)
# TMEM alloc cols
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake, arch="sm_100")
# TMA load bytes
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.a_dtype, a_smem_layout) +
cute.size_in_bytes(self.b_dtype, b_smem_layout)
) * cute.size(tiled_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor,
stream: cuda.CUstream):
self.a_dtype = a.element_type
self.b_dtype = b.element_type
self.c_dtype = c.element_type
self.a_major_mode = LayoutEnum.from_tensor(a).mma_major_mode()
self.b_major_mode = LayoutEnum.from_tensor(b).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
tiled_mma = self._create_tiled_mma()
pv_mma = utils.sm100.make_trivial_tiled_mma(self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major_mode, self.acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
self._setup_attributes()
# TMA load A
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id),
a, a_smem_layout, self.mma_tiler, tiled_mma,
self.cluster_layout_vmnk.shape,
)
# TMA load B
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id),
b, b_smem_layout, self.mma_tiler, tiled_mma,
self.cluster_layout_vmnk.shape,
)
# TMA store C
epi_smem_layout = cute.select(self.c_smem_layout_staged, mode=[0, 1])
tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem_layout, self.epi_tile)
self._kernel(pv_mma,
tiled_mma, pv_mma, tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b,
tma_atom_c, tma_tensor_c, self.cluster_layout_vmnk,
self.a_smem_layout_staged, self.b_smem_layout_staged,
self.c_smem_layout_staged, self.epi_tile,
).launch(grid=(1, 1, 1), block=[self.threads_per_cta, 1, 1], stream=stream)
@cute.kernel
def _kernel(self, tiled_mma, pv_mma, tma_atom_a, mA_mkl, tma_atom_b, mB_nkl,
tma_atom_c, mC_mnl, cluster_layout_vmnk,
a_smem_layout_staged, b_smem_layout_staged, c_smem_layout_staged, epi_tile):
warp_idx = cute.arch.warp_idx()
warp_idx = cute.arch.make_warp_uniform(warp_idx)
tidx, _, _ = cute.arch.thread_idx()
use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2
is_leader_cta = True # single CTA, always leader
# Prefetch TMA descriptors
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_atom_a)
cpasync.prefetch_descriptor(tma_atom_b)
cpasync.prefetch_descriptor(tma_atom_c)
# ── Shared storage ───────────────────────────────────
@cute.struct
class SharedStorage:
ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc_mbar: cutlass.Int64
tmem_holding_buf: cutlass.Int32
smem = utils.SmemAllocator()
storage = smem.allocate(SharedStorage)
# AB pipeline
ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create(
barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes,
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
).make_participants()
# ACC pipeline
acc_pipeline = pipeline.PipelineUmmaAsync.create(
barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta_instrs else 1)),
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
)
# TMEM allocator
tmem_alloc_barrier = pipeline.NamedBarrier(
barrier_id=self.tmem_alloc_sync_bar_id,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)),
)
tmem = utils.TmemAllocator(
storage.tmem_holding_buf.ptr,
barrier_for_retrieve=tmem_alloc_barrier,
allocator_warp_id=self.epilogue_warp_id[0],
is_two_cta=use_2cta_instrs,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
)
pipeline.pipeline_init_arrive(cluster_shape_mn=cluster_layout_vmnk, is_relaxed=True)
# SMEM tensors
sA = smem.allocate_tensor(
element_type=self.a_dtype, layout=a_smem_layout_staged.outer,
byte_alignment=128, swizzle=a_smem_layout_staged.inner)
sB = smem.allocate_tensor(
element_type=self.b_dtype, layout=b_smem_layout_staged.outer,
byte_alignment=128, swizzle=b_smem_layout_staged.inner)
sC = smem.allocate_tensor(
element_type=self.c_dtype, layout=c_smem_layout_staged.outer,
byte_alignment=128, swizzle=c_smem_layout_staged.inner)
# Partition global tensors
gA_mkl = cute.local_tile(mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None))
gB_nkl = cute.local_tile(mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None))
gC_mnl = cute.local_tile(mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None))
k_tile_cnt = cute.size(gA_mkl, mode=[3])
# Partition for TiledMMA
thr_mma = tiled_mma.get_slice(0) # leader CTA
tCgA = thr_mma.partition_A(gA_mkl)
tCgB = thr_mma.partition_B(gB_nkl)
tCgC = thr_mma.partition_C(gC_mnl)
# TMA partition A/B
a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape)
tAsA, tAgA = cpasync.tma_partition(
tma_atom_a, 0, a_cta_layout,
cute.group_modes(sA, 0, 3), cute.group_modes(tCgA, 0, 3))
b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape)
tBsB, tBgB = cpasync.tma_partition(
tma_atom_b, 0, b_cta_layout,
cute.group_modes(sB, 0, 3), cute.group_modes(tCgB, 0, 3))
# Slice to tile coord (0, 0, 0)
tAgA_slice = tAgA[(None, 0, None, 0)]
tBgB_slice = tBgB[(None, 0, None, 0)]
# MMA fragments
tCrA = tiled_mma.make_fragment_A(sA)
tCrB = tiled_mma.make_fragment_B(sB)
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cluster_layout_vmnk)
# ══════════════════════════════════════════════════════════
# TMA LOAD WARP (warp 5)
# ══════════════════════════════════════════════════════════
if warp_idx == self.tma_warp_id:
ab_producer.reset()
peek_ab_empty_status = ab_producer.try_acquire()
for k_tile in cutlass.range(k_tile_cnt, unroll=1):
handle = ab_producer.acquire_and_advance(peek_ab_empty_status)
cute.copy(tma_atom_a, tAgA_slice[(None, handle.count)], tAsA[(None, handle.index)],
tma_bar_ptr=handle.barrier)
cute.copy(tma_atom_b, tBgB_slice[(None, handle.count)], tBsB[(None, handle.index)],
tma_bar_ptr=handle.barrier)
peek_ab_empty_status = cutlass.Boolean(1)
if handle.count + 1 < k_tile_cnt:
peek_ab_empty_status = ab_producer.try_acquire()
ab_producer.tail()
# ══════════════════════════════════════════════════════════
# MMA WARP (warp 4)
# ══════════════════════════════════════════════════════════
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
tCtAcc = tCtAcc_base[(None, None, None, 0)]
ab_consumer.reset()
peek_ab_full_status = cutlass.Boolean(1)
if is_leader_cta:
peek_ab_full_status = ab_consumer.try_wait()
acc_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_acc_stage)
if is_leader_cta:
acc_pipeline.producer_acquire(acc_producer_state)
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
for k_tile in range(k_tile_cnt):
if is_leader_cta:
handle = ab_consumer.wait_and_advance(peek_ab_full_status)
num_kblocks = cute.size(tCrA, mode=[2])
for kblk_idx in cutlass.range(num_kblocks, unroll_full=True):
kblk_crd = (None, None, kblk_idx, handle.index)
cute.gemm(tiled_mma, tCtAcc, tCrA[kblk_crd], tCrB[kblk_crd], tCtAcc)
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
handle.release()
peek_ab_full_status = cutlass.Boolean(1)
if handle.count + 1 < k_tile_cnt:
peek_ab_full_status = ab_consumer.try_wait()
if is_leader_cta:
acc_pipeline.producer_commit(acc_producer_state)
acc_producer_state.advance()
acc_pipeline.producer_tail(acc_producer_state)
# ══════════════════════════════════════════════════════════
# EPILOGUE WARPS (0..3)
# ══════════════════════════════════════════════════════════
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
acc_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipeline = pipeline.PipelineTmaStore.create(
num_stages=self.num_c_stage, producer_group=c_producer_group)
# Use the reference epilogue implementation
mma_tile_coord_mnl = (0, 0, 0)
epilogue_op = const_expr(lambda x: x)
num_tiles_executed = 0
acc_consumer_state = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_atom_c, tCtAcc_base, sC, tCgC,
epi_tile, num_tiles_executed, epilogue_op,
mma_tile_coord_mnl, acc_consumer_state, acc_pipeline, c_pipeline)
c_pipeline.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test_stage_a_with_pv_param():
"""Test Stage A: Q @ K^T → TMEM → GMEM"""
device = torch.device("cuda")
torch.manual_seed(42)
m, n, k = 128, 128, 512
# Tensors must be 3D (M, K, L) for the CUTLASS pattern
a = torch.randn(m, k, 1, dtype=torch.bfloat16, device="cuda")
b = torch.randn(n, k, 1, dtype=torch.bfloat16, device="cuda")
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device="cuda")
ref = a[:, :, 0].float() @ b[:, :, 0].float().T
# Create cute tensors
import cutlass.torch as cutlass_torch
mA = cutlass_torch.from_dlpack(a).mark_layout_dynamic(
leading_dim=cutlass_torch.get_leading_dim(a))
mB = cutlass_torch.from_dlpack(b).mark_layout_dynamic(
leading_dim=cutlass_torch.get_leading_dim(b))
mC = cutlass_torch.from_dlpack(c).mark_layout_dynamic(
leading_dim=cutlass_torch.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageAWithPVParam(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
compiled = cute.compile(kernel, mA, mB, mC, stream)
# Run with the same tensors
compiled(mA, mB, mC, stream)
torch.cuda.synchronize()
output = c[:, :, 0].float()
cos = torch.nn.functional.cosine_similarity(
output.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
max_err = (output - ref).abs().max().item()
print("Stage A: Q({},{}) @ K^T({}, {}) -> S({}, {})".format(m, k, k, n, m, n))
print(" Cosine: {:.6f}, Max error: {:.6f}".format(cos, max_err))
print(" {}".format("PASS" if cos >= 0.99 else "FAIL"))
return cos
if __name__ == "__main__":
test_stage_a_with_pv_param()

632
tests/test_stage_a_qk.py Normal file
View File

@@ -0,0 +1,632 @@
"""
Stage A: Bare Q@K^T via tcgen05.mma → TMEM → GMEM
Validates the tcgen05 MMA path with zero attention logic.
Following dense_blockscaled_gemm_persistent.py for all Blackwell idioms.
Shape: Q(M=128, K=512) @ K^T(512, N=16) → S(128, 16) fp32 output to GMEM.
One CTA. 6 warps: 4 epilogue, 1 MMA, 1 TMA load.
Pipeline: TMA loads Q and K into SMEM, MMA warp issues tcgen05.mma with TMEM accumulator,
epilogue warps TMEM→reg→SMEM→TMA→GMEM.
"""
import torch
import math
try:
import cutlass
import cutlass.cute as cute
import cutlass.torch as cutlass_torch
import cutlass.utils as utils
import cutlass.pipeline as pipeline
import cutlass.utils.blackwell_helpers as sm100_utils
from cutlass.cute.nvgpu import tcgen05, cpasync
from cutlass import BFloat16, Float32
from cutlass.cute.runtime import make_ptr
from typing import Tuple
import cuda.bindings.driver as cuda
HAS_CUTEDSL = True
except ImportError:
HAS_CUTEDSL = False
print("WARNING: CuTeDSL not available")
# ── Problem dimensions ────────────────────────────────────────────────
HG = 128 # query heads per CTA (M dimension)
KT = 64 # KV positions per tile (N dimension) — minimum for tcgen05 is 64
HD = 512 # head dim (K dimension)
# ── Warp specialization (mirrors the dense GEMM) ─────────────────────
EPILOGUE_WARP_IDS = (0, 1, 2, 3)
MMA_WARP_ID = 4
TMA_WARP_ID = 5
THREADS_PER_WARP = 32
NUM_WARPS = 6
NUM_THREADS = THREADS_PER_WARP * NUM_WARPS # 192
class StageAQKTKernel:
"""Stage A: Q @ K^T → TMEM → GMEM, no softmax, no PV GEMM.
This is dense_blockscaled_gemm_persistent.py stripped to attention shapes:
- BF16 inputs, FP32 accumulator
- No block scaling (SFA/SFB) — plain bf16 MMA
- No persistent scheduler — one tile per CTA
- Single AB stage (one KV tile, no double-buffer needed for Stage A)
- TMA load for Q and K, tcgen05.mma, TMEM→reg→SMEM→GMEM epilogue
"""
def __init__(self, mma_tiler_mn=(HG, KT)):
self.mma_tiler_mn = mma_tiler_mn
self.use_2cta_instrs = mma_tiler_mn[0] == 256
self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
# Warp IDs and thread count
self.epilog_warp_id = EPILOGUE_WARP_IDS
self.mma_warp_id = MMA_WARP_ID
self.tma_warp_id = TMA_WARP_ID
self.threads_per_warp = THREADS_PER_WARP
self.threads_per_cta = NUM_THREADS
# Named barriers
self.epilog_sync_barrier = pipeline.NamedBarrier(
barrier_id=1,
num_threads=THREADS_PER_WARP * len(EPILOGUE_WARP_IDS),
)
self.tmem_alloc_barrier = pipeline.NamedBarrier(
barrier_id=2,
num_threads=THREADS_PER_WARP * (1 + len(EPILOGUE_WARP_IDS)),
)
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
def _setup_attributes(self, a_dtype, b_dtype, c_dtype, a_major_mode, b_major_mode, c_layout):
"""Setup attributes that depend on input types."""
self.a_dtype = a_dtype
self.b_dtype = b_dtype
self.c_dtype = c_dtype
self.acc_dtype = Float32
self.a_major_mode = a_major_mode
self.b_major_mode = b_major_mode
self.c_layout = c_layout
# Create tiled MMA — plain BF16 (no block scaling)
self.tiled_mma = sm100_utils.make_trivial_tiled_mma(
a_dtype, b_dtype,
a_major_mode, b_major_mode,
Float32, # acc_dtype
self.cta_group,
self.mma_tiler_mn,
)
atom_thr_size = cute.size(self.tiled_mma.thr_id.shape)
mma_inst_shape_k = cute.size(self.tiled_mma.shape_mnk, mode=[2])
mma_inst_tile_k = 4
self.mma_tiler = (
self.mma_tiler_mn[0],
self.mma_tiler_mn[1],
mma_inst_shape_k * mma_inst_tile_k,
)
self.cta_tile_shape_mnk = (
self.mma_tiler[0] // atom_thr_size,
self.mma_tiler[1],
self.mma_tiler[2],
)
# Cluster shape (1,1) — no clustering for attention
self.cluster_shape_mn = (1, 1)
self.cluster_layout_vmnk = cute.tiled_divide(
cute.make_layout((1, 1, 1)),
(self.tiled_mma.thr_id.shape,),
)
self.num_mcast_ctas_a = 1
self.num_mcast_ctas_b = 1
self.is_a_mcast = False
self.is_b_mcast = False
# Epilogue tile
self.epi_tile = sm100_utils.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk,
self.use_2cta_instrs,
self.c_layout,
self.c_dtype,
)
self.epi_tile_n = cute.size(self.epi_tile[1])
# Stage counts
self.num_ab_stage = 1
self.num_acc_stage = 1
self.num_c_stage = 2
# SMEM layouts
self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
self.tiled_mma, self.mma_tiler, a_dtype, self.num_ab_stage,
)
self.b_smem_layout_staged = sm100_utils.make_smem_layout_b(
self.tiled_mma, self.mma_tiler, b_dtype, self.num_ab_stage,
)
self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi(
c_dtype, self.c_layout, self.epi_tile, self.num_c_stage,
)
# TMEM columns for accumulator
self.num_accumulator_tmem_cols = self.cta_tile_shape_mnk[1]
self.overlapping_accum = False
@cute.jit
def __call__(
self,
a_ptr: cute.Pointer,
b_ptr: cute.Pointer,
c_ptr: cute.Pointer,
problem_m: cutlass.Int32,
problem_n: cutlass.Int32,
problem_k: cutlass.Int32,
stream: cuda.CUstream,
):
a_dtype = a_ptr.value_type
b_dtype = b_ptr.value_type
c_dtype = c_ptr.value_type
a_major_mode, b_major_mode, c_layout = self._layouts
self._setup_attributes(a_dtype, b_dtype, c_dtype, a_major_mode, b_major_mode, c_layout)
m, n, k = problem_m, problem_n, problem_k
# Make GMEM tensors — include batch dim (l=1) for local_tile compatibility
a_layout = cute.make_ordered_layout((m, cute.assume(k, 32), 1), order=(1, 0, 2))
b_layout = cute.make_ordered_layout((n, cute.assume(k, 32), 1), order=(1, 0, 2))
c_layout_obj = cute.make_ordered_layout((m, cute.assume(n, 32), 1), order=(1, 0, 2))
mA = cute.make_tensor(a_ptr, a_layout)
mB = cute.make_tensor(b_ptr, b_layout)
mC = cute.make_tensor(c_ptr, c_layout_obj)
# TMA descriptors
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, self.tiled_mma.thr_id),
mA, a_smem_layout, self.mma_tiler, self.tiled_mma,
self.cluster_layout_vmnk.shape,
)
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, self.tiled_mma.thr_id),
mB, b_smem_layout, self.mma_tiler, self.tiled_mma,
self.cluster_layout_vmnk.shape,
)
a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout)
b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout)
self.num_tma_load_bytes = (a_copy_size + b_copy_size) * cute.size(self.tiled_mma.thr_id.shape)
epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0))
tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
cpasync.CopyBulkTensorTileS2GOp(),
mC, epi_smem_layout, self.epi_tile,
)
@cute.struct
class SharedStorage:
ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
tmem_dealloc_mbar: cutlass.Int64
tmem_holding_buf: cutlass.Int32
sA: cute.struct.Align[
cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)],
1024,
]
sB: cute.struct.Align[
cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)],
1024,
]
sC: cute.struct.Align[
cute.struct.MemRange[self.c_dtype, cute.cosize(self.c_smem_layout_staged.outer)],
1024,
]
self.shared_storage = SharedStorage
self._kernel(
self.tiled_mma,
tma_atom_a, tma_tensor_a,
tma_atom_b, tma_tensor_b,
tma_atom_c, tma_tensor_c,
self.cluster_layout_vmnk,
self.a_smem_layout_staged,
self.b_smem_layout_staged,
self.c_smem_layout_staged,
self.epi_tile,
).launch(
grid=(1, 1, 1),
block=[self.threads_per_cta, 1, 1],
stream=stream,
min_blocks_per_mp=1,
)
@cute.kernel
def _kernel(
self,
tiled_mma,
tma_atom_a, mA_mkl,
tma_atom_b, mB_nkl,
tma_atom_c, mC_mnl,
cluster_layout_vmnk,
a_smem_layout_staged,
b_smem_layout_staged,
c_smem_layout_staged,
epi_tile,
):
warp_idx = cute.arch.warp_idx()
warp_idx = cute.arch.make_warp_uniform(warp_idx)
tidx, _, _ = cute.arch.thread_idx()
use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2
is_leader_cta = True
# ── Shared memory ───────────────────────────────────────
smem = utils.SmemAllocator()
storage = smem.allocate(self.shared_storage)
# Init AB pipeline
ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
ab_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_tma_producer)
ab_pipeline = pipeline.PipelineTmaUmma.create(
barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
num_stages=self.num_ab_stage,
producer_group=ab_pipeline_producer_group,
consumer_group=ab_pipeline_consumer_group,
tx_count=self.num_tma_load_bytes,
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
)
# Init accumulator pipeline
acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
num_acc_consumer_threads = self.threads_per_warp * len(self.epilog_warp_id) * (2 if use_2cta_instrs else 1)
acc_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_acc_consumer_threads)
acc_pipeline = pipeline.PipelineUmmaAsync.create(
barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
num_stages=self.num_acc_stage,
producer_group=acc_pipeline_producer_group,
consumer_group=acc_pipeline_consumer_group,
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
)
# TMEM allocator
tmem = utils.TmemAllocator(
storage.tmem_holding_buf.ptr,
barrier_for_retrieve=self.tmem_alloc_barrier,
allocator_warp_id=self.epilog_warp_id[0],
is_two_cta=use_2cta_instrs,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
)
pipeline.pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True)
# ── SMEM tensors ────────────────────────────────────────
sA = storage.sA.get_tensor(
a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner
)
sB = storage.sB.get_tensor(
b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner
)
sC = storage.sC.get_tensor(
c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner
)
# ── Partition global tensors ────────────────────────────
gA_mkl = cute.local_tile(
mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)
)
gB_nkl = cute.local_tile(
mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None)
)
gC_mnl = cute.local_tile(
mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)
)
k_tile_cnt = cute.size(gA_mkl, mode=[3])
mma_tile_coord_v = 0
thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
tCgA = thr_mma.partition_A(gA_mkl)
tCgB = thr_mma.partition_B(gB_nkl)
tCgC = thr_mma.partition_C(gC_mnl)
# TMA partition for A and B
a_cta_layout = cute.make_layout(1)
tAsA, tAgA = cpasync.tma_partition(
tma_atom_a, 0, a_cta_layout,
cute.group_modes(sA, 0, 3),
cute.group_modes(tCgA, 0, 3),
)
b_cta_layout = cute.make_layout(1)
tBsB, tBgB = cpasync.tma_partition(
tma_atom_b, 0, b_cta_layout,
cute.group_modes(sB, 0, 3),
cute.group_modes(tCgB, 0, 3),
)
# Slice to the single tile coordinate
# (mma_tile_coord_mnl = (0,0,0) for single-tile)
tAgA_slice = tAgA[(None, 0, None, 0)]
tBgB_slice = tBgB[(None, 0, None, 0)]
# MMA fragments
tCrA = tiled_mma.make_fragment_A(sA)
tCrB = tiled_mma.make_fragment_B(sB)
# TMEM accumulator
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
tCtAcc_fake = tiled_mma.make_fragment_C(
cute.append(acc_shape, self.num_acc_stage)
)
pipeline.pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn)
# ══════════════════════════════════════════════════════════
# TMA LOAD WARP (warp 5)
# ══════════════════════════════════════════════════════════
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_atom_a)
cpasync.prefetch_descriptor(tma_atom_b)
cpasync.prefetch_descriptor(tma_atom_c)
ab_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_ab_stage
)
for k_tile in cutlass.range(k_tile_cnt, unroll=1):
ab_pipeline.producer_acquire(ab_producer_state)
cute.copy(
tma_atom_a,
tAgA_slice[(None, ab_producer_state.count)],
tAsA[(None, ab_producer_state.index)],
tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
)
cute.copy(
tma_atom_b,
tBgB_slice[(None, ab_producer_state.count)],
tBsB[(None, ab_producer_state.index)],
tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
)
ab_producer_state.advance()
ab_pipeline.producer_tail(ab_producer_state)
# ══════════════════════════════════════════════════════════
# MMA WARP (warp 4)
# ══════════════════════════════════════════════════════════
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout)
tCtAcc = tCtAcc_base[(None, None, None, 0)]
ab_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_ab_stage
)
acc_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_acc_stage
)
acc_pipeline.producer_acquire(acc_producer_state)
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
for k_tile in range(k_tile_cnt):
if is_leader_cta:
ab_pipeline.consumer_wait(ab_consumer_state, cutlass.Boolean(1))
num_kblocks = cute.size(tCrA, mode=[2])
for kblock_idx in cutlass.range(num_kblocks, unroll_full=True):
kblock_coord = (None, None, kblock_idx, ab_consumer_state.index)
cute.gemm(
tiled_mma,
tCtAcc,
tCrA[kblock_coord],
tCrB[kblock_coord],
tCtAcc,
)
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
if is_leader_cta:
ab_pipeline.consumer_release(ab_consumer_state)
ab_consumer_state.advance()
if is_leader_cta:
acc_pipeline.producer_commit(acc_producer_state)
acc_producer_state.advance()
acc_pipeline.producer_tail(acc_producer_state)
# ══════════════════════════════════════════════════════════
# EPILOGUE WARPS (0..3)
# ══════════════════════════════════════════════════════════
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_accumulator_tmem_cols)
tmem.wait_for_alloc()
acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout)
epi_tidx = tidx
copy_atom_t2r = sm100_utils.get_tmem_load_op(
self.cta_tile_shape_mnk,
self.c_layout,
self.c_dtype,
self.acc_dtype,
epi_tile,
use_2cta_instrs,
)
tAcc_epi = cute.flat_divide(
tCtAcc_base[((None, None), 0, 0, None)],
epi_tile,
)
tiled_copy_t2r = tcgen05.make_tmem_copy(
copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)]
)
thr_copy_t2r = tiled_copy_t2r.get_slice(epi_tidx)
tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi)
tTR_rAcc = cute.make_rmem_tensor(
thr_copy_t2r.partition_D(
cute.flat_divide(
tCgC[((None, None), 0, 0, None, None, None)], epi_tile
)
)[(None, None, None, 0, 0, 0, 0, 0)].shape,
self.acc_dtype,
)
tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype)
copy_atom_r2s = sm100_utils.get_smem_store_op(
self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r
)
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
thr_copy_r2s = tiled_copy_r2s.get_slice(epi_tidx)
tRS_sC = thr_copy_r2s.partition_D(sC)
tRS_rC = tiled_copy_r2s.retile(tTR_rC)
gC_epi = cute.flat_divide(
tCgC[((None, None), 0, 0, None, None, None)], epi_tile
)
sC_for_tma = cute.group_modes(sC, 0, 2)
gC_for_tma = cute.group_modes(gC_epi, 0, 2)
bSG_sC, bSG_gC = cpasync.tma_partition(
tma_atom_c, 0, cute.make_layout(1),
sC_for_tma, gC_for_tma,
)
acc_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_acc_stage
)
c_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread,
self.threads_per_warp * len(self.epilog_warp_id),
)
c_pipeline = pipeline.PipelineTmaStore.create(
num_stages=self.num_c_stage,
producer_group=c_producer_group,
)
acc_pipeline.consumer_wait(acc_consumer_state)
tTR_tAcc_g = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
bSG_gC_g = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC))
subtile_cnt = cute.size(tTR_tAcc_g.shape, mode=[3])
for subtile_idx in cutlass.range(subtile_cnt):
tTR_tAcc_mn = tTR_tAcc_g[(None, None, None, subtile_idx)]
cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)
acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load()
tRS_rC.store(acc_vec.to(self.c_dtype))
c_buffer = subtile_idx % self.num_c_stage
cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)])
cute.arch.fence_proxy("async.shared", space="cta")
self.epilog_sync_barrier.arrive_and_wait()
if warp_idx == self.epilog_warp_id[0]:
cute.copy(
tma_atom_c,
bSG_sC[(None, c_buffer)],
bSG_gC_g[(None, subtile_idx)],
)
c_pipeline.producer_commit()
c_pipeline.producer_acquire()
self.epilog_sync_barrier.arrive_and_wait()
acc_pipeline.consumer_release(acc_consumer_state)
acc_consumer_state.advance()
tmem.relinquish_alloc_permit()
self.epilog_sync_barrier.arrive_and_wait()
tmem.free(acc_tmem_ptr)
c_pipeline.producer_tail()
def test_stage_a():
"""Test Stage A: Q @ K^T via tcgen05.mma → TMEM → GMEM."""
if not HAS_CUTEDSL:
print("CuTeDSL not available, skipping")
return
device = torch.device("cuda")
torch.manual_seed(42)
prob_m, prob_n, prob_k = HG, KT, HD
tQ = torch.randn(prob_m, prob_k, dtype=torch.bfloat16, device=device)
tK = torch.randn(prob_n, prob_k, dtype=torch.bfloat16, device=device)
ref = torch.matmul(tQ.to(torch.float32), tK.to(torch.float32).T)
tC = torch.zeros(prob_m, prob_n, dtype=torch.bfloat16, device=device)
# Compile using make_ptr pattern (like dense GEMM)
a_ptr = make_ptr(BFloat16, 0, cute.AddressSpace.gmem, assumed_align=16)
b_ptr = make_ptr(BFloat16, 0, cute.AddressSpace.gmem, assumed_align=16)
c_ptr = make_ptr(BFloat16, 0, cute.AddressSpace.gmem, assumed_align=16)
a_major_mode = tcgen05.OperandMajorMode.K
b_major_mode = tcgen05.OperandMajorMode.K
c_layout = utils.LayoutEnum.ROW_MAJOR
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
# Compile using make_ptr pattern (like dense GEMM)
a_ptr_fake = make_ptr(BFloat16, 0, cute.AddressSpace.gmem, assumed_align=16)
b_ptr_fake = make_ptr(BFloat16, 0, cute.AddressSpace.gmem, assumed_align=16)
c_ptr_fake = make_ptr(BFloat16, 0, cute.AddressSpace.gmem, assumed_align=16)
a_major_mode = tcgen05.OperandMajorMode.K
b_major_mode = tcgen05.OperandMajorMode.K
c_layout = utils.LayoutEnum.ROW_MAJOR
kernel = StageAQKTKernel()
kernel._layouts = (a_major_mode, b_major_mode, c_layout)
compiled = cute.compile(
kernel,
a_ptr_fake, b_ptr_fake, c_ptr_fake,
cutlass.Int32(prob_m), cutlass.Int32(prob_n), cutlass.Int32(prob_k),
stream,
)
# Create runtime pointers from torch tensor data
a_ptr = make_ptr(BFloat16, tQ.data_ptr(), cute.AddressSpace.gmem, assumed_align=16)
b_ptr = make_ptr(BFloat16, tK.data_ptr(), cute.AddressSpace.gmem, assumed_align=16)
c_ptr = make_ptr(BFloat16, tC.data_ptr(), cute.AddressSpace.gmem, assumed_align=16)
compiled(
a_ptr, b_ptr, c_ptr,
cutlass.Int32(prob_m), cutlass.Int32(prob_n), cutlass.Int32(prob_k),
stream,
)
torch.cuda.synchronize()
output = tC.to(torch.float32)
cos = torch.nn.functional.cosine_similarity(
output.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
).item()
max_err = (output - ref).abs().max().item()
mean_err = (output - ref).abs().mean().item()
print(f"Stage A: Q({prob_m},{prob_k}) @ K^T({prob_k},{prob_n}) → S({prob_m},{prob_n})")
print(f" Cosine similarity: {cos:.6f}")
print(f" Max absolute error: {max_err:.6f}")
print(f" Mean absolute error: {mean_err:.6f}")
if cos >= 0.99:
print(" ✅ PASS")
else:
print(" ❌ FAIL — cosine < 0.99")
return cos
if __name__ == "__main__":
test_stage_a()

372
tests/test_stage_a_v2.py Normal file
View File

@@ -0,0 +1,372 @@
"""
Stage A: Bare Q@K^T via tcgen05.mma → TMEM → GMEM
Follows the CUTLASS dense_gemm_persistent.py pattern EXACTLY.
BF16 inputs, FP32 accumulator, TMA load/store, warp specialization.
Single tile (no persistent scheduler), cluster (1,1).
"""
import torch
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.cute.runtime import make_ptr
import cuda.bindings.driver as cuda
class StageAQKTKernel:
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
self.acc_dtype = Float32
self.use_2cta_instrs = use_2cta_instrs
self.mma_tiler_mn = mma_tiler_mn
self.mma_tiler = (*mma_tiler_mn, 1)
self.use_tma_store = use_tma_store
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4
self.tma_warp_id = 5
self.threads_per_cta = 32 * 6 # 192
self.epilog_sync_bar_id = 1
self.tmem_alloc_sync_bar_id = 2
self.tmem_dealloc_sync_bar_id = 3
def _create_tiled_mma(self):
return utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.a_major_mode, self.b_major_mode,
self.acc_dtype, self.cta_group, self.mma_tiler_mn,
)
def _setup_attributes(self):
tiled_mma = self._create_tiled_mma()
mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2])
mma_inst_tile_k = 4
self.mma_tiler = (self.mma_tiler[0], self.mma_tiler[1], mma_inst_shape_k * mma_inst_tile_k)
self.cta_tile_shape_mnk = (
self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
self.mma_tiler[1],
self.mma_tiler[2],
)
self.cluster_layout_vmnk = cute.tiled_divide(
cute.make_layout((1, 1, 1)), (tiled_mma.thr_id.shape,))
self.num_mcast_ctas_a = 1
self.num_mcast_ctas_b = 1
self.is_a_mcast = False
self.is_b_mcast = False
# Epilogue tile
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.c_dtype)
# Stage counts: 1 AB stage (single tile, no double-buffer), 1 acc stage, 2 C stages
self.num_ab_stage = 1
self.num_acc_stage = 1
self.num_c_stage = 2
# SMEM layouts
self.a_smem_layout_staged = utils.sm100.make_smem_layout_a(
tiled_mma, self.mma_tiler, self.a_dtype, self.num_ab_stage)
self.b_smem_layout_staged = utils.sm100.make_smem_layout_b(
tiled_mma, self.mma_tiler, self.b_dtype, self.num_ab_stage)
self.c_smem_layout_staged = utils.sm100.make_smem_layout_epi(
self.c_dtype, self.c_layout, self.epi_tile, self.num_c_stage)
# TMEM alloc cols
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake, arch="sm_100")
# TMA load bytes
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.a_dtype, a_smem_layout) +
cute.size_in_bytes(self.b_dtype, b_smem_layout)
) * cute.size(tiled_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor,
stream: cuda.CUstream):
self.a_dtype = a.element_type
self.b_dtype = b.element_type
self.c_dtype = c.element_type
self.a_major_mode = LayoutEnum.from_tensor(a).mma_major_mode()
self.b_major_mode = LayoutEnum.from_tensor(b).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
tiled_mma = self._create_tiled_mma()
self._setup_attributes()
# TMA load A
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id),
a, a_smem_layout, self.mma_tiler, tiled_mma,
self.cluster_layout_vmnk.shape,
)
# TMA load B
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id),
b, b_smem_layout, self.mma_tiler, tiled_mma,
self.cluster_layout_vmnk.shape,
)
# TMA store C
epi_smem_layout = cute.select(self.c_smem_layout_staged, mode=[0, 1])
tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem_layout, self.epi_tile)
self._kernel(
tiled_mma, tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b,
tma_atom_c, tma_tensor_c, self.cluster_layout_vmnk,
self.a_smem_layout_staged, self.b_smem_layout_staged,
self.c_smem_layout_staged, self.epi_tile,
).launch(grid=(1, 1, 1), block=[self.threads_per_cta, 1, 1], stream=stream)
@cute.kernel
def _kernel(self, tiled_mma, tma_atom_a, mA_mkl, tma_atom_b, mB_nkl,
tma_atom_c, mC_mnl, cluster_layout_vmnk,
a_smem_layout_staged, b_smem_layout_staged, c_smem_layout_staged, epi_tile):
warp_idx = cute.arch.warp_idx()
warp_idx = cute.arch.make_warp_uniform(warp_idx)
tidx, _, _ = cute.arch.thread_idx()
use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2
is_leader_cta = True # single CTA, always leader
# Prefetch TMA descriptors
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_atom_a)
cpasync.prefetch_descriptor(tma_atom_b)
cpasync.prefetch_descriptor(tma_atom_c)
# ── Shared storage ───────────────────────────────────
@cute.struct
class SharedStorage:
ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc_mbar: cutlass.Int64
tmem_holding_buf: cutlass.Int32
smem = utils.SmemAllocator()
storage = smem.allocate(SharedStorage)
# AB pipeline
ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create(
barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes,
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
).make_participants()
# ACC pipeline
acc_pipeline = pipeline.PipelineUmmaAsync.create(
barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta_instrs else 1)),
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
)
# TMEM allocator
tmem_alloc_barrier = pipeline.NamedBarrier(
barrier_id=self.tmem_alloc_sync_bar_id,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)),
)
tmem = utils.TmemAllocator(
storage.tmem_holding_buf.ptr,
barrier_for_retrieve=tmem_alloc_barrier,
allocator_warp_id=self.epilogue_warp_id[0],
is_two_cta=use_2cta_instrs,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
)
pipeline.pipeline_init_arrive(cluster_shape_mn=cluster_layout_vmnk, is_relaxed=True)
# SMEM tensors
sA = smem.allocate_tensor(
element_type=self.a_dtype, layout=a_smem_layout_staged.outer,
byte_alignment=128, swizzle=a_smem_layout_staged.inner)
sB = smem.allocate_tensor(
element_type=self.b_dtype, layout=b_smem_layout_staged.outer,
byte_alignment=128, swizzle=b_smem_layout_staged.inner)
sC = smem.allocate_tensor(
element_type=self.c_dtype, layout=c_smem_layout_staged.outer,
byte_alignment=128, swizzle=c_smem_layout_staged.inner)
# Partition global tensors
gA_mkl = cute.local_tile(mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None))
gB_nkl = cute.local_tile(mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None))
gC_mnl = cute.local_tile(mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None))
k_tile_cnt = cute.size(gA_mkl, mode=[3])
# Partition for TiledMMA
thr_mma = tiled_mma.get_slice(0) # leader CTA
tCgA = thr_mma.partition_A(gA_mkl)
tCgB = thr_mma.partition_B(gB_nkl)
tCgC = thr_mma.partition_C(gC_mnl)
# TMA partition A/B
a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape)
tAsA, tAgA = cpasync.tma_partition(
tma_atom_a, 0, a_cta_layout,
cute.group_modes(sA, 0, 3), cute.group_modes(tCgA, 0, 3))
b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape)
tBsB, tBgB = cpasync.tma_partition(
tma_atom_b, 0, b_cta_layout,
cute.group_modes(sB, 0, 3), cute.group_modes(tCgB, 0, 3))
# Slice to tile coord (0, 0, 0)
tAgA_slice = tAgA[(None, 0, None, 0)]
tBgB_slice = tBgB[(None, 0, None, 0)]
# MMA fragments
tCrA = tiled_mma.make_fragment_A(sA)
tCrB = tiled_mma.make_fragment_B(sB)
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cluster_layout_vmnk)
# ══════════════════════════════════════════════════════════
# TMA LOAD WARP (warp 5)
# ══════════════════════════════════════════════════════════
if warp_idx == self.tma_warp_id:
ab_producer.reset()
peek_ab_empty_status = ab_producer.try_acquire()
for k_tile in cutlass.range(k_tile_cnt, unroll=1):
handle = ab_producer.acquire_and_advance(peek_ab_empty_status)
cute.copy(tma_atom_a, tAgA_slice[(None, handle.count)], tAsA[(None, handle.index)],
tma_bar_ptr=handle.barrier)
cute.copy(tma_atom_b, tBgB_slice[(None, handle.count)], tBsB[(None, handle.index)],
tma_bar_ptr=handle.barrier)
peek_ab_empty_status = cutlass.Boolean(1)
if handle.count + 1 < k_tile_cnt:
peek_ab_empty_status = ab_producer.try_acquire()
ab_producer.tail()
# ══════════════════════════════════════════════════════════
# MMA WARP (warp 4)
# ══════════════════════════════════════════════════════════
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
tCtAcc = tCtAcc_base[(None, None, None, 0)]
ab_consumer.reset()
peek_ab_full_status = cutlass.Boolean(1)
if is_leader_cta:
peek_ab_full_status = ab_consumer.try_wait()
acc_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_acc_stage)
if is_leader_cta:
acc_pipeline.producer_acquire(acc_producer_state)
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
for k_tile in range(k_tile_cnt):
if is_leader_cta:
handle = ab_consumer.wait_and_advance(peek_ab_full_status)
num_kblocks = cute.size(tCrA, mode=[2])
for kblk_idx in cutlass.range(num_kblocks, unroll_full=True):
kblk_crd = (None, None, kblk_idx, handle.index)
cute.gemm(tiled_mma, tCtAcc, tCrA[kblk_crd], tCrB[kblk_crd], tCtAcc)
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
handle.release()
peek_ab_full_status = cutlass.Boolean(1)
if handle.count + 1 < k_tile_cnt:
peek_ab_full_status = ab_consumer.try_wait()
if is_leader_cta:
acc_pipeline.producer_commit(acc_producer_state)
acc_producer_state.advance()
acc_pipeline.producer_tail(acc_producer_state)
# ══════════════════════════════════════════════════════════
# EPILOGUE WARPS (0..3)
# ══════════════════════════════════════════════════════════
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
acc_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipeline = pipeline.PipelineTmaStore.create(
num_stages=self.num_c_stage, producer_group=c_producer_group)
# Use the reference epilogue implementation
mma_tile_coord_mnl = (0, 0, 0)
epilogue_op = const_expr(lambda x: x)
num_tiles_executed = 0
acc_consumer_state = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_atom_c, tCtAcc_base, sC, tCgC,
epi_tile, num_tiles_executed, epilogue_op,
mma_tile_coord_mnl, acc_consumer_state, acc_pipeline, c_pipeline)
c_pipeline.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test_stage_a():
"""Test Stage A: Q @ K^T → TMEM → GMEM"""
device = torch.device("cuda")
torch.manual_seed(42)
m, n, k = 128, 128, 512
# Tensors must be 3D (M, K, L) for the CUTLASS pattern
a = torch.randn(m, k, 1, dtype=torch.bfloat16, device="cuda")
b = torch.randn(n, k, 1, dtype=torch.bfloat16, device="cuda")
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device="cuda")
ref = a[:, :, 0].float() @ b[:, :, 0].float().T
# Create cute tensors
import cutlass.torch as cutlass_torch
mA = cutlass_torch.from_dlpack(a).mark_layout_dynamic(
leading_dim=cutlass_torch.get_leading_dim(a))
mB = cutlass_torch.from_dlpack(b).mark_layout_dynamic(
leading_dim=cutlass_torch.get_leading_dim(b))
mC = cutlass_torch.from_dlpack(c).mark_layout_dynamic(
leading_dim=cutlass_torch.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageAQKTKernel(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
compiled = cute.compile(kernel, mA, mB, mC, stream)
# Run with the same tensors
compiled(mA, mB, mC, stream)
torch.cuda.synchronize()
output = c[:, :, 0].float()
cos = torch.nn.functional.cosine_similarity(
output.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
max_err = (output - ref).abs().max().item()
print("Stage A: Q({},{}) @ K^T({}, {}) -> S({}, {})".format(m, k, k, n, m, n))
print(" Cosine: {:.6f}, Max error: {:.6f}".format(cos, max_err))
print(" {}".format("PASS" if cos >= 0.99 else "FAIL"))
return cos
if __name__ == "__main__":
test_stage_a()

View File

@@ -0,0 +1,374 @@
"""
Stage A: Bare Q@K^T via tcgen05.mma → TMEM → GMEM
Follows the CUTLASS dense_gemm_persistent.py pattern EXACTLY.
BF16 inputs, FP32 accumulator, TMA load/store, warp specialization.
Single tile (no persistent scheduler), cluster (1,1).
"""
import torch
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.cute.runtime import make_ptr
import cuda.bindings.driver as cuda
class StageAQKTKernel:
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
self.acc_dtype = Float32
self.use_2cta_instrs = use_2cta_instrs
self.mma_tiler_mn = mma_tiler_mn
self.mma_tiler = (*mma_tiler_mn, 1)
self.use_tma_store = use_tma_store
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4
self.tma_warp_id = 5
self.threads_per_cta = 32 * 6 # 192
self.epilog_sync_bar_id = 1
self.tmem_alloc_sync_bar_id = 2
self.tmem_dealloc_sync_bar_id = 3
def _create_tiled_mma(self):
return utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.a_major_mode, self.b_major_mode,
self.acc_dtype, self.cta_group, self.mma_tiler_mn,
)
def _setup_attributes(self):
tiled_mma = self._create_tiled_mma()
pv_mma = utils.sm100.make_trivial_tiled_mma(self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major_mode, self.acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2])
mma_inst_tile_k = 4
self.mma_tiler = (self.mma_tiler[0], self.mma_tiler[1], mma_inst_shape_k * mma_inst_tile_k)
self.cta_tile_shape_mnk = (
self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
self.mma_tiler[1],
self.mma_tiler[2],
)
self.cluster_layout_vmnk = cute.tiled_divide(
cute.make_layout((1, 1, 1)), (tiled_mma.thr_id.shape,))
self.num_mcast_ctas_a = 1
self.num_mcast_ctas_b = 1
self.is_a_mcast = False
self.is_b_mcast = False
# Epilogue tile
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.c_dtype)
# Stage counts: 1 AB stage (single tile, no double-buffer), 1 acc stage, 2 C stages
self.num_ab_stage = 1
self.num_acc_stage = 1
self.num_c_stage = 2
# SMEM layouts
self.a_smem_layout_staged = utils.sm100.make_smem_layout_a(
tiled_mma, self.mma_tiler, self.a_dtype, self.num_ab_stage)
self.b_smem_layout_staged = utils.sm100.make_smem_layout_b(
tiled_mma, self.mma_tiler, self.b_dtype, self.num_ab_stage)
self.c_smem_layout_staged = utils.sm100.make_smem_layout_epi(
self.c_dtype, self.c_layout, self.epi_tile, self.num_c_stage)
# TMEM alloc cols
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake, arch="sm_100")
# TMA load bytes
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.a_dtype, a_smem_layout) +
cute.size_in_bytes(self.b_dtype, b_smem_layout)
) * cute.size(tiled_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor,
stream: cuda.CUstream):
self.a_dtype = a.element_type
self.b_dtype = b.element_type
self.c_dtype = c.element_type
self.a_major_mode = LayoutEnum.from_tensor(a).mma_major_mode()
self.b_major_mode = LayoutEnum.from_tensor(b).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
tiled_mma = self._create_tiled_mma()
pv_mma = utils.sm100.make_trivial_tiled_mma(self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major_mode, self.acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
self._setup_attributes()
# TMA load A
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id),
a, a_smem_layout, self.mma_tiler, tiled_mma,
self.cluster_layout_vmnk.shape,
)
# TMA load B
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id),
b, b_smem_layout, self.mma_tiler, tiled_mma,
self.cluster_layout_vmnk.shape,
)
# TMA store C
epi_smem_layout = cute.select(self.c_smem_layout_staged, mode=[0, 1])
tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem_layout, self.epi_tile)
self._kernel(
tiled_mma, tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b,
tma_atom_c, tma_tensor_c, self.cluster_layout_vmnk,
self.a_smem_layout_staged, self.b_smem_layout_staged,
self.c_smem_layout_staged, self.epi_tile,
).launch(grid=(1, 1, 1), block=[self.threads_per_cta, 1, 1], stream=stream)
@cute.kernel
def _kernel(self, tiled_mma, tma_atom_a, mA_mkl, tma_atom_b, mB_nkl,
tma_atom_c, mC_mnl, cluster_layout_vmnk,
a_smem_layout_staged, b_smem_layout_staged, c_smem_layout_staged, epi_tile):
warp_idx = cute.arch.warp_idx()
warp_idx = cute.arch.make_warp_uniform(warp_idx)
tidx, _, _ = cute.arch.thread_idx()
use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2
is_leader_cta = True # single CTA, always leader
# Prefetch TMA descriptors
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_atom_a)
cpasync.prefetch_descriptor(tma_atom_b)
cpasync.prefetch_descriptor(tma_atom_c)
# ── Shared storage ───────────────────────────────────
@cute.struct
class SharedStorage:
ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc_mbar: cutlass.Int64
tmem_holding_buf: cutlass.Int32
smem = utils.SmemAllocator()
storage = smem.allocate(SharedStorage)
# AB pipeline
ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create(
barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes,
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
).make_participants()
# ACC pipeline
acc_pipeline = pipeline.PipelineUmmaAsync.create(
barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta_instrs else 1)),
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
)
# TMEM allocator
tmem_alloc_barrier = pipeline.NamedBarrier(
barrier_id=self.tmem_alloc_sync_bar_id,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)),
)
tmem = utils.TmemAllocator(
storage.tmem_holding_buf.ptr,
barrier_for_retrieve=tmem_alloc_barrier,
allocator_warp_id=self.epilogue_warp_id[0],
is_two_cta=use_2cta_instrs,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
)
pipeline.pipeline_init_arrive(cluster_shape_mn=cluster_layout_vmnk, is_relaxed=True)
# SMEM tensors
sA = smem.allocate_tensor(
element_type=self.a_dtype, layout=a_smem_layout_staged.outer,
byte_alignment=128, swizzle=a_smem_layout_staged.inner)
sB = smem.allocate_tensor(
element_type=self.b_dtype, layout=b_smem_layout_staged.outer,
byte_alignment=128, swizzle=b_smem_layout_staged.inner)
sC = smem.allocate_tensor(
element_type=self.c_dtype, layout=c_smem_layout_staged.outer,
byte_alignment=128, swizzle=c_smem_layout_staged.inner)
# Partition global tensors
gA_mkl = cute.local_tile(mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None))
gB_nkl = cute.local_tile(mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None))
gC_mnl = cute.local_tile(mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None))
k_tile_cnt = cute.size(gA_mkl, mode=[3])
# Partition for TiledMMA
thr_mma = tiled_mma.get_slice(0) # leader CTA
tCgA = thr_mma.partition_A(gA_mkl)
tCgB = thr_mma.partition_B(gB_nkl)
tCgC = thr_mma.partition_C(gC_mnl)
# TMA partition A/B
a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape)
tAsA, tAgA = cpasync.tma_partition(
tma_atom_a, 0, a_cta_layout,
cute.group_modes(sA, 0, 3), cute.group_modes(tCgA, 0, 3))
b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape)
tBsB, tBgB = cpasync.tma_partition(
tma_atom_b, 0, b_cta_layout,
cute.group_modes(sB, 0, 3), cute.group_modes(tCgB, 0, 3))
# Slice to tile coord (0, 0, 0)
tAgA_slice = tAgA[(None, 0, None, 0)]
tBgB_slice = tBgB[(None, 0, None, 0)]
# MMA fragments
tCrA = tiled_mma.make_fragment_A(sA)
tCrB = tiled_mma.make_fragment_B(sB)
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cluster_layout_vmnk)
# ══════════════════════════════════════════════════════════
# TMA LOAD WARP (warp 5)
# ══════════════════════════════════════════════════════════
if warp_idx == self.tma_warp_id:
ab_producer.reset()
peek_ab_empty_status = ab_producer.try_acquire()
for k_tile in cutlass.range(k_tile_cnt, unroll=1):
handle = ab_producer.acquire_and_advance(peek_ab_empty_status)
cute.copy(tma_atom_a, tAgA_slice[(None, handle.count)], tAsA[(None, handle.index)],
tma_bar_ptr=handle.barrier)
cute.copy(tma_atom_b, tBgB_slice[(None, handle.count)], tBsB[(None, handle.index)],
tma_bar_ptr=handle.barrier)
peek_ab_empty_status = cutlass.Boolean(1)
if handle.count + 1 < k_tile_cnt:
peek_ab_empty_status = ab_producer.try_acquire()
ab_producer.tail()
# ══════════════════════════════════════════════════════════
# MMA WARP (warp 4)
# ══════════════════════════════════════════════════════════
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
tCtAcc = tCtAcc_base[(None, None, None, 0)]
ab_consumer.reset()
peek_ab_full_status = cutlass.Boolean(1)
if is_leader_cta:
peek_ab_full_status = ab_consumer.try_wait()
acc_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_acc_stage)
if is_leader_cta:
acc_pipeline.producer_acquire(acc_producer_state)
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
for k_tile in range(k_tile_cnt):
if is_leader_cta:
handle = ab_consumer.wait_and_advance(peek_ab_full_status)
num_kblocks = cute.size(tCrA, mode=[2])
for kblk_idx in cutlass.range(num_kblocks, unroll_full=True):
kblk_crd = (None, None, kblk_idx, handle.index)
cute.gemm(tiled_mma, tCtAcc, tCrA[kblk_crd], tCrB[kblk_crd], tCtAcc)
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
handle.release()
peek_ab_full_status = cutlass.Boolean(1)
if handle.count + 1 < k_tile_cnt:
peek_ab_full_status = ab_consumer.try_wait()
if is_leader_cta:
acc_pipeline.producer_commit(acc_producer_state)
acc_producer_state.advance()
acc_pipeline.producer_tail(acc_producer_state)
# ══════════════════════════════════════════════════════════
# EPILOGUE WARPS (0..3)
# ══════════════════════════════════════════════════════════
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
acc_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipeline = pipeline.PipelineTmaStore.create(
num_stages=self.num_c_stage, producer_group=c_producer_group)
# Use the reference epilogue implementation
mma_tile_coord_mnl = (0, 0, 0)
epilogue_op = const_expr(lambda x: x)
num_tiles_executed = 0
acc_consumer_state = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_atom_c, tCtAcc_base, sC, tCgC,
epi_tile, num_tiles_executed, epilogue_op,
mma_tile_coord_mnl, acc_consumer_state, acc_pipeline, c_pipeline)
c_pipeline.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test_stage_a():
"""Test Stage A: Q @ K^T → TMEM → GMEM"""
device = torch.device("cuda")
torch.manual_seed(42)
m, n, k = 128, 128, 512
# Tensors must be 3D (M, K, L) for the CUTLASS pattern
a = torch.randn(m, k, 1, dtype=torch.bfloat16, device="cuda")
b = torch.randn(n, k, 1, dtype=torch.bfloat16, device="cuda")
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device="cuda")
ref = a[:, :, 0].float() @ b[:, :, 0].float().T
# Create cute tensors
import cutlass.torch as cutlass_torch
mA = cutlass_torch.from_dlpack(a).mark_layout_dynamic(
leading_dim=cutlass_torch.get_leading_dim(a))
mB = cutlass_torch.from_dlpack(b).mark_layout_dynamic(
leading_dim=cutlass_torch.get_leading_dim(b))
mC = cutlass_torch.from_dlpack(c).mark_layout_dynamic(
leading_dim=cutlass_torch.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageAQKTKernel(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
compiled = cute.compile(kernel, mA, mB, mC, stream)
# Run with the same tensors
compiled(mA, mB, mC, stream)
torch.cuda.synchronize()
output = c[:, :, 0].float()
cos = torch.nn.functional.cosine_similarity(
output.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
max_err = (output - ref).abs().max().item()
print("Stage A: Q({},{}) @ K^T({}, {}) -> S({}, {})".format(m, k, k, n, m, n))
print(" Cosine: {:.6f}, Max error: {:.6f}".format(cos, max_err))
print(" {}".format("PASS" if cos >= 0.99 else "FAIL"))
return cos
if __name__ == "__main__":
test_stage_a()

252
tests/test_stage_b_debug.py Normal file
View File

@@ -0,0 +1,252 @@
"""Stage B debug: Two MMAs with PipelineUmmaAsync sync, NO softmax.
Just Q@K^T then P@V reading from same TMEM (will be garbage, but should not deadlock)."""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
import cuda.bindings.driver as cuda
class StageBDebug:
def __init__(self, mma_tiler_mn):
self.acc_dtype = Float32
self.qk_acc_dtype = Float32
self.q_dtype = BFloat16
self.o_dtype = BFloat16
self.mma_tiler_mn = mma_tiler_mn
self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.ONE
self.use_2cta_instrs = False
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4
self.tma_warp_id = 5
self.threads_per_cta = 192
self.epilog_sync_bar_id = 1
self.num_c_stage = 2
self.tmem_s0_offset = 0
self.tmem_o0_offset = 256
self.tmem_p0_offset = 32
self.tmem_alloc_cols = 512
def _setup(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = tuple(self.qk_mma_tiler)
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype)
self.num_ab_stage = 1; self.num_acc_stage = 1
self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.a_dtype, 1)
self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.b_dtype, 1)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
acc_shape = qk_mma.partition_shape_C(self.mma_tiler_mn)
tCtS_fake = qk_mma.make_fragment_C(cute.append(acc_shape, 1))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtS_fake, arch="sm_100")
q_smem = cute.slice_(self.q_smem_s, (None, None, None, 0))
k_smem = cute.slice_(self.k_smem_s, (None, None, None, 0))
self.num_tma_bytes = (cute.size_in_bytes(self.a_dtype, q_smem) + cute.size_in_bytes(self.b_dtype, k_smem)) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a, b, c, stream):
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, self.a_major, self.b_major, self.acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major, self.acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
q_smem = cute.slice_(self.q_smem_s, (None, None, None, 0))
k_smem = cute.slice_(self.k_smem_s, (None, None, None, 0))
tma_q, tma_tq = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, q_smem, self.qk_mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_k, tma_tk = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, k_smem, self.qk_mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_q, tma_tq, tma_k, tma_tk, tma_c, tma_tc,
self.cluster_layout_vmnk, self.q_smem_s, self.k_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[192,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_c, mC, cl_vmnk,
q_smem_s, k_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, 2]
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator()
st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 128),
cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 128),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=160)
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=0, is_two_cta=False,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sQ = smem.allocate_tensor(element_type=BFloat16, layout=q_smem_s.outer, byte_alignment=128, swizzle=q_smem_s.inner)
sK = smem.allocate_tensor(element_type=BFloat16, layout=k_smem_s.outer, byte_alignment=128, swizzle=k_smem_s.inner)
sC = smem.allocate_tensor(element_type=BFloat16, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gQ = cute.local_tile(mQ, cute.slice_(self.qk_mma_tiler, (None,0,None)), (None,None,None))
gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.qk_mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gQ, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsQ, tAgQ = cpasync.tma_partition(tma_q, 0, a_lay, cute.group_modes(sQ,0,3), cute.group_modes(tCgQ,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tAsK, tAgK = cpasync.tma_partition(tma_k, 0, b_lay, cute.group_modes(sK,0,3), cute.group_modes(tCgK,0,3))
tAgQ = tAgQ[(None,0,None,0)]; tAgK = tAgK[(None,0,None,0)]
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sK)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler_mn)
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_mma.partition_shape_C(self.mma_tiler_mn)
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_mma.make_fragment_A(tP)
tOrP = tOrP_base[(None, None, None, 0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# TMA
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_q, tAgQ[(None,h.count)], tAsQ[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_k, tAgK[(None,h.count)], tAsK[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# MMA
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
# QK MMA
s0_handle = mma_si_prod.acquire_and_advance()
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrQ, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,h.index)], tCrK[(None,None,kb,h.index)], tStS0)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
# Re-acquire (wait for softmax to release)
s0_handle = mma_si_prod.acquire_and_advance()
# PV MMA (no softmax, P = raw scores in C-layout, PV reads as A-layout → garbage, but should not hang)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
tCrV_s = tCrV[(None, None, None, 0)]
nblk_pv = cute.size(tOrP0, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, 1)
acc_pipe.producer_acquire(acc_prod_st)
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
# Softmax/Epilogue
if warp_idx < self.mma_warp_id:
tmem.allocate(self.tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
# NO softmax — just pass through the pipeline
si_handle = mma_si_cons.wait_and_advance()
# ... do nothing (identity) ...
si_handle.release()
# Epilogue
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, 1)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 128)
c_pipe = pipeline.PipelineTmaStore.create(num_stages=2, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m,n,k = 128,128,128
q = torch.randn(m,k,1,dtype=torch.bfloat16,device='cuda')
kv = torch.randn(n,k,1,dtype=torch.bfloat16,device='cuda')
c = torch.zeros(m,n,1,dtype=torch.bfloat16,device='cuda')
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBDebug((128,128))
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
print('Kernel completed without deadlock!')
out = c[:,:,0].float()
print('Output shape:', out.shape, 'nonzero:', (out != 0).sum().item())
if __name__ == '__main__':
test()

View File

@@ -0,0 +1,224 @@
"""Stage B debug v2: Two MMAs, no pipeline between QK and PV.
Both on MMA warp, sequential. NO mma_si pipeline."""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
import cuda.bindings.driver as cuda
class StageBDebug2:
def __init__(self, mma_tiler_mn):
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
self.use_2cta_instrs = False
self.epilogue_warp_id = (0, 1, 2, 3); self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192; self.epilog_sync_bar_id = 1; self.num_c_stage = 2
self.tmem_s0_offset = 0; self.tmem_o0_offset = 256; self.tmem_p0_offset = 32
self.tmem_alloc_cols = 512
def _setup(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = tuple(self.qk_mma_tiler)
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype)
self.num_ab_stage = 1; self.num_acc_stage = 1
self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.a_dtype, 1)
self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.b_dtype, 1)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
acc_shape = qk_mma.partition_shape_C(self.mma_tiler_mn)
tCtS_fake = qk_mma.make_fragment_C(cute.append(acc_shape, 1))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtS_fake, arch="sm_100")
q_smem = cute.slice_(self.q_smem_s, (None, None, None, 0))
k_smem = cute.slice_(self.k_smem_s, (None, None, None, 0))
self.num_tma_bytes = (cute.size_in_bytes(self.a_dtype, q_smem) + cute.size_in_bytes(self.b_dtype, k_smem)) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a, b, c, stream):
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, self.a_major, self.b_major, self.acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major, self.acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
q_smem = cute.slice_(self.q_smem_s, (None, None, None, 0))
k_smem = cute.slice_(self.k_smem_s, (None, None, None, 0))
tma_q, tma_tq = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, q_smem, self.qk_mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_k, tma_tk = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, k_smem, self.qk_mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_q, tma_tq, tma_k, tma_tk, tma_c, tma_tc,
self.cluster_layout_vmnk, self.q_smem_s, self.k_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[192,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_c, mC, cl_vmnk,
q_smem_s, k_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator()
st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 128),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=160)
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=0, is_two_cta=False,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sQ = smem.allocate_tensor(element_type=BFloat16, layout=q_smem_s.outer, byte_alignment=128, swizzle=q_smem_s.inner)
sK = smem.allocate_tensor(element_type=BFloat16, layout=k_smem_s.outer, byte_alignment=128, swizzle=k_smem_s.inner)
sC = smem.allocate_tensor(element_type=BFloat16, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gQ = cute.local_tile(mQ, cute.slice_(self.qk_mma_tiler, (None,0,None)), (None,None,None))
gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.qk_mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gQ, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsQ, tAgQ = cpasync.tma_partition(tma_q, 0, a_lay, cute.group_modes(sQ,0,3), cute.group_modes(tCgQ,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tAsK, tAgK = cpasync.tma_partition(tma_k, 0, b_lay, cute.group_modes(sK,0,3), cute.group_modes(tCgK,0,3))
tAgQ = tAgQ[(None,0,None,0)]; tAgK = tAgK[(None,0,None,0)]
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sK)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler_mn)
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_mma.partition_shape_C(self.mma_tiler_mn)
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_mma.make_fragment_A(tP)
tOrP = tOrP_base[(None, None, None, 0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# TMA
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_q, tAgQ[(None,h.count)], tAsQ[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_k, tAgK[(None,h.count)], tAsK[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# MMA — both QK and PV sequentially, no mma_si pipeline
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
# QK MMA
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrQ, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,h.index)], tCrK[(None,None,kb,h.index)], tStS0)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store()
# PV MMA — directly after QK, same warp
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
tCrV_s = tCrV[(None, None, None, 0)]
nblk_pv = cute.size(tOrP0, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, 1)
acc_pipe.producer_acquire(acc_prod_st)
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
# Epilogue only (no softmax)
if warp_idx < self.mma_warp_id:
tmem.allocate(self.tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, 1)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 128)
c_pipe = pipeline.PipelineTmaStore.create(num_stages=2, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m,n,k = 128,128,128
q = torch.randn(m,k,1,dtype=torch.bfloat16,device='cuda')
kv = torch.randn(n,k,1,dtype=torch.bfloat16,device='cuda')
c = torch.zeros(m,n,1,dtype=torch.bfloat16,device='cuda')
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBDebug2((128,128))
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
print('No deadlock!')
out = c[:,:,0].float()
print('Output nonzero:', (out != 0).sum().item())
if __name__ == '__main__':
test()

View File

@@ -0,0 +1,198 @@
"""Stage B debug v3: Same as Stage A but epilogue reads from PV MMA's C layout at offset 256.
Only does QK MMA (no PV). Tests whether the epilogue can handle a different accumulator layout."""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
import cuda.bindings.driver as cuda
class StageBDebug3:
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.a_dtype = BFloat16; self.b_dtype = BFloat16; self.c_dtype = BFloat16
self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3); self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192; self.epilog_sync_bar_id = 1
self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
self.num_c_stage = 2
def _setup(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = tuple(self.qk_mma_tiler)
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.c_dtype)
self.num_ab_stage = 1; self.num_acc_stage = 1; self.num_c_stage = 2
self.a_smem_layout_staged = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
self.b_smem_layout_staged = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
self.c_smem_layout_staged = utils.sm100.make_smem_layout_epi(self.c_dtype, self.c_layout, self.epi_tile, 2)
# Use PV MMA's C fragment for the TMEM allocation
pv_acc_shape = pv_mma.partition_shape_C(self.mma_tiler_mn)
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtO_fake, arch="sm_100")
a_smem = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
self.num_tma_load_bytes = (cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
self.a_major_mode = LayoutEnum.from_tensor(a).mma_major_mode()
self.b_major_mode = LayoutEnum.from_tensor(b).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, self.a_major_mode, self.b_major_mode,
self.acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major_mode,
self.acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
a_smem = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
tma_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_layout_staged, mode=[0, 1])
tma_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_a, tma_tensor_a, tma_b, tma_tensor_b,
tma_c, tma_tensor_c, self.cluster_layout_vmnk,
self.a_smem_layout_staged, self.b_smem_layout_staged, self.c_smem_layout_staged, self.epi_tile
).launch(grid=(1,1,1), block=[self.threads_per_cta, 1, 1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_atom_a, mA_mkl, tma_atom_b, mB_nkl,
tma_atom_c, mC_mnl, cluster_layout_vmnk,
a_smem_layout_staged, b_smem_layout_staged, c_smem_layout_staged, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
is_leader_cta = True
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_atom_a); cpasync.prefetch_descriptor(tma_atom_b); cpasync.prefetch_descriptor(tma_atom_c)
@cute.struct
class SharedStorage:
ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc_mbar: cutlass.Int64
tmem_holding_buf: cutlass.Int32
smem = utils.SmemAllocator(); storage = smem.allocate(SharedStorage)
ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create(
barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cluster_layout_vmnk, defer_sync=True
).make_participants()
acc_pipeline = pipeline.PipelineUmmaAsync.create(
barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
cta_layout_vmnk=cluster_layout_vmnk, defer_sync=True)
tmem_alloc_barrier = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id, num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(storage.tmem_holding_buf.ptr, barrier_for_retrieve=tmem_alloc_barrier,
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cluster_layout_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_layout_staged.outer, byte_alignment=128, swizzle=a_smem_layout_staged.inner)
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_layout_staged.outer, byte_alignment=128, swizzle=b_smem_layout_staged.inner)
sC = smem.allocate_tensor(element_type=self.c_dtype, layout=c_smem_layout_staged.outer, byte_alignment=128, swizzle=c_smem_layout_staged.inner)
gA_mkl = cute.local_tile(mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None))
gB_nkl = cute.local_tile(mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None))
gC_mnl = cute.local_tile(mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None))
k_tile_cnt = cute.size(gA_mkl, mode=[3])
thr_mma = qk_mma.get_slice(0)
tCgA = thr_mma.partition_A(gA_mkl); tCgB = thr_mma.partition_B(gB_nkl); tCgC = thr_mma.partition_C(gC_mnl)
a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_atom_a, 0, a_cta_layout, cute.group_modes(sA, 0, 3), cute.group_modes(tCgA, 0, 3))
b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_atom_b, 0, b_cta_layout, cute.group_modes(sB, 0, 3), cute.group_modes(tCgB, 0, 3))
tAgA_slice = tAgA[(None, 0, None, 0)]; tBgB_slice = tBgB[(None, 0, None, 0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
# Use PV MMA's C fragment for the accumulator (test: can epilogue handle PV's C layout?)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_mma.partition_shape_C(self.mma_tiler_mn)
tCtAcc_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cluster_layout_vmnk)
# TMA
if warp_idx == self.tma_warp_id:
ab_producer.reset(); peek = ab_producer.try_acquire()
for k_tile in cutlass.range(k_tile_cnt, unroll=1):
handle = ab_producer.acquire_and_advance(peek)
cute.copy(tma_atom_a, tAgA_slice[(None, handle.count)], tAsA[(None, handle.index)], tma_bar_ptr=handle.barrier)
cute.copy(tma_atom_b, tBgB_slice[(None, handle.count)], tBsB[(None, handle.index)], tma_bar_ptr=handle.barrier)
peek = cutlass.Boolean(1)
if handle.count + 1 < k_tile_cnt: peek = ab_producer.try_acquire()
ab_producer.tail()
# MMA
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc(); tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
tCtAcc = tCtAcc_base[(None, None, None, 0)]
ab_consumer.reset(); peek = ab_consumer.try_wait()
acc_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipeline.producer_acquire(acc_producer_state)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for k_tile in range(k_tile_cnt):
if is_leader_cta:
handle = ab_consumer.wait_and_advance(peek)
num_kblocks = cute.size(tCrA, mode=[2])
for kblk_idx in cutlass.range(num_kblocks, unroll_full=True):
cute.gemm(qk_mma, tCtAcc, tCrA[(None, None, kblk_idx, handle.index)], tCrB[(None, None, kblk_idx, handle.index)], tCtAcc)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
handle.release(); peek = cutlass.Boolean(1)
if handle.count + 1 < k_tile_cnt: peek = ab_consumer.try_wait()
acc_pipeline.producer_commit(acc_producer_state)
acc_producer_state.advance()
acc_pipeline.producer_tail(acc_producer_state)
# Epilogue
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols); tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
acc_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipeline = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_producer_group)
acc_consumer_state = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_atom_c, tCtAcc_base, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0, 0, 0), acc_consumer_state, acc_pipeline, c_pipeline)
c_pipeline.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m, n, k = 128, 128, 512
a = torch.randn(m, k, 1, dtype=torch.bfloat16, device="cuda")
b = torch.randn(n, k, 1, dtype=torch.bfloat16, device="cuda")
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device="cuda")
ref = a[:, :, 0].float() @ b[:, :, 0].float().T
import cutlass.torch as cutlass_torch
mA = cutlass_torch.from_dlpack(a).mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(a))
mB = cutlass_torch.from_dlpack(b).mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(b))
mC = cutlass_torch.from_dlpack(c).mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBDebug3(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
compiled = cute.compile(kernel, mA, mB, mC, stream)
compiled(mA, mB, mC, stream)
torch.cuda.synchronize()
output = c[:, :, 0].float()
cos = torch.nn.functional.cosine_similarity(output.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
print("Cosine: {:.6f}".format(cos))
if __name__ == '__main__':
test()

View File

@@ -0,0 +1,205 @@
"""Stage B debug v4: QK→PV sequential on MMA warp.
Uses QK MMA's C-fragment for both QK output and PV output.
PV writes to offset 0 (same as QK output) — garbage layout, but should not hang."""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
import cuda.bindings.driver as cuda
class StageBDebug4:
def __init__(self, mma_tiler_mn):
self.acc_dtype = Float32; self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE; self.use_2cta_instrs = False
self.epilogue_warp_id = (0, 1, 2, 3); self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192; self.epilog_sync_bar_id = 1; self.num_c_stage = 2
def _setup(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = tuple(self.qk_mma_tiler)
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, BFloat16)
self.num_ab_stage = 1; self.num_acc_stage = 1
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, BFloat16, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, BFloat16, 1)
self.c_smem_s = utils.sm100.make_smem_layout_epi(BFloat16, self.c_layout, self.epi_tile, 2)
# Use QK MMA's fragment for TMEM allocation
acc_shape = qk_mma.partition_shape_C(self.mma_tiler_mn)
tCtAcc_fake = qk_mma.make_fragment_C(cute.append(acc_shape, 1))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake, arch="sm_100")
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (cute.size_in_bytes(BFloat16, a_smem) + cute.size_in_bytes(BFloat16, b_smem)) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a, b, c, stream):
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, self.a_major, self.b_major, self.acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major, self.acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[192,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
a_smem_s, b_smem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 128),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=160)
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=0, is_two_cta=False,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=BFloat16, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=BFloat16, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
sC = smem.allocate_tensor(element_type=BFloat16, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
thr = qk_mma.get_slice(0)
tCgA = thr.partition_A(gA); tCgB = thr.partition_B(gB); tCgC = thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
tCrV = pv_mma.make_fragment_B(sB)
# Use QK MMA's C-fragment for EVERYTHING (like Stage A)
acc_shape = qk_mma.partition_shape_C(self.mma_tiler_mn)
tCtAcc_fake = qk_mma.make_fragment_C(cute.append(acc_shape, 1))
# Also need 2D fragment for MMA
tStS = thr.make_fragment_C(acc_shape)
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# TMA
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# MMA
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
tCtAcc = tCtAcc_base[(None,None,None,0)]
ab_c.reset(); peek = ab_c.try_wait()
# QK MMA (identical to Stage A)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tCtAcc, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tCtAcc)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
# NO PV MMA, NO extra operations
# Just signal output (same as Stage A)
acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, 1)
acc_pipe.producer_acquire(acc_st)
acc_pipe.producer_commit(acc_st)
acc_st.advance()
acc_pipe.producer_tail(acc_st)
# Epilogue (identical to Stage A)
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
cons = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, 1)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 128)
c_pipe = pipeline.PipelineTmaStore.create(num_stages=2, producer_group=c_grp)
cons = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtAcc_base, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), cons, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m,n,k = 128,128,512
a = torch.randn(m,k,1,dtype=torch.bfloat16,device='cuda')
b = torch.randn(n,k,1,dtype=torch.bfloat16,device='cuda')
c = torch.zeros(m,n,1,dtype=torch.bfloat16,device='cuda')
ref = a[:,:,0].float() @ b[:,:,0].float().T
import cutlass.torch as ct
mA = ct.from_dlpack(a).mark_layout_dynamic(leading_dim=ct.get_leading_dim(a))
mB = ct.from_dlpack(b).mark_layout_dynamic(leading_dim=ct.get_leading_dim(b))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBDebug4((128,128))
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mA, mB, mC, stream)
print('Running...', flush=True)
compiled(mA, mB, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
print('Cosine: {:.6f} ({})'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
if __name__ == '__main__':
test()

View File

@@ -0,0 +1,487 @@
"""
Stage B: Two MMAs + Identity Softmax with Layout Transform
Following NVIDIA's fmha.py softmax_step pattern exactly.
Architecture:
MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False)
Identity softmax: tcgen05.ld from C-layout → convert F32→BF16 → tcgen05.st to A-layout
MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True)
Reference: output = (Q @ K^T) @ V (no softmax, P = raw scores)
TMEM Layout (following fmha.py for 128x128 MMA tile):
tmem_s0 = 0 (scores, QK accumulator, C-layout)
tmem_p0 = 32 (P, PV A-operand, A-layout — written by identity softmax)
tmem_o0 = 256 (output, PV accumulator, C-layout)
The identity softmax performs the C-layout → A-layout transform via tcgen05.ld + tcgen05.st.
This is the critical bridge that makes the two-MMA pipeline work on Blackwell.
"""
import torch
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
import cuda.bindings.driver as cuda
class StageBIdentitySoftmaxKernel:
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
self.acc_dtype = Float32
self.qk_acc_dtype = Float32
self.q_dtype = BFloat16
self.o_dtype = BFloat16
self.use_2cta_instrs = use_2cta_instrs
self.mma_tiler_mn = mma_tiler_mn
self.mma_tiler = (*mma_tiler_mn, 1)
self.use_tma_store = use_tma_store
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
self.softmax_warp_ids = (0, 1, 2, 3)
self.epilogue_warp_id = self.softmax_warp_ids # same warps do softmax + epilogue
self.mma_warp_id = 4
self.tma_warp_id = 5
self.threads_per_cta = 32 * 6
# TMEM offsets (fmha.py pattern)
self.tmem_s0_offset = 0
self.tmem_o0_offset = 256
self.tmem_p0_offset = 32
self.tmem_alloc_cols = 512
self.epilog_sync_bar_id = 1
self.tmem_alloc_sync_bar_id = 2
self.tmem_dealloc_sync_bar_id = 3
self.scores_full_bar_id = 4
self.softmax_done_bar_id = 5
self.num_c_stage = 2
def _setup_attributes(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = (
self.qk_mma_tiler[0], self.qk_mma_tiler[1], self.qk_mma_tiler[2])
self.cluster_layout_vmnk = cute.tiled_divide(
cute.make_layout((1, 1, 1)), (qk_mma.thr_id.shape,))
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, self.use_2cta_instrs,
self.c_layout, self.o_dtype)
self.num_ab_stage = 1
self.num_acc_stage = 1
self.q_smem_layout_staged = utils.sm100.make_smem_layout_a(
qk_mma, self.qk_mma_tiler, self.a_dtype, self.num_ab_stage)
self.k_smem_layout_staged = utils.sm100.make_smem_layout_b(
qk_mma, self.qk_mma_tiler, self.b_dtype, self.num_ab_stage)
self.v_smem_layout_staged = utils.sm100.make_smem_layout_b(
pv_mma, self.pv_mma_tiler, self.b_dtype, self.num_ab_stage)
self.p_tmem_layout_staged = utils.sm100.make_smem_layout_a(
pv_mma, self.pv_mma_tiler, self.q_dtype, self.num_ab_stage)
self.c_smem_layout_staged = utils.sm100.make_smem_layout_epi(
self.o_dtype, self.c_layout, self.epi_tile, self.num_c_stage)
# For TMEM allocation
acc_shape_qk = qk_mma.partition_shape_C(self.mma_tiler_mn)
tCtS_fake = qk_mma.make_fragment_C(cute.append(acc_shape_qk, self.num_acc_stage))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtS_fake, arch="sm_100")
q_smem = cute.slice_(self.q_smem_layout_staged, (None, None, None, 0))
k_smem = cute.slice_(self.k_smem_layout_staged, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.a_dtype, q_smem) +
cute.size_in_bytes(self.b_dtype, k_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor,
stream: cuda.CUstream):
self.a_dtype = a.element_type
self.b_dtype = b.element_type
self.c_dtype = c.element_type
self.a_major_mode = LayoutEnum.from_tensor(a).mma_major_mode()
self.b_major_mode = LayoutEnum.from_tensor(b).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, self.a_major_mode, self.b_major_mode,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype,
cute.nvgpu.OperandMajorMode.K, self.b_major_mode,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.TMEM)
self._setup_attributes(qk_mma, pv_mma)
q_smem = cute.slice_(self.q_smem_layout_staged, (None, None, None, 0))
k_smem = cute.slice_(self.k_smem_layout_staged, (None, None, None, 0))
tma_q, tma_tq = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, q_smem, self.qk_mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_k, tma_tk = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, k_smem, self.qk_mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_layout_staged, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(
cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(
qk_mma, pv_mma,
tma_q, tma_tq, tma_k, tma_tk,
tma_c, tma_tc,
self.cluster_layout_vmnk,
self.q_smem_layout_staged, self.k_smem_layout_staged,
self.v_smem_layout_staged, self.p_tmem_layout_staged,
self.c_smem_layout_staged, self.epi_tile,
).launch(grid=(1, 1, 1), block=[self.threads_per_cta, 1, 1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma,
tma_q, mQ, tma_k, mK,
tma_c, mC, cl_vmnk,
q_smem_staged, k_smem_staged,
v_smem_staged, p_tmem_staged,
c_smem_staged, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
use_2cta_instrs = cute.size(qk_mma.thr_id.shape) == 2
is_leader_cta = True
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_q)
cpasync.prefetch_descriptor(tma_k)
cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SharedStorage:
ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc_mbar: cutlass.Int64
tmem_holding_buf: cutlass.Int32
scores_full_mbar: cutlass.Int64
softmax_done_mbar: cutlass.Int64
smem = utils.SmemAllocator()
storage = smem.allocate(SharedStorage)
ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create(
barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes,
cta_layout_vmnk=cl_vmnk,
defer_sync=True,
).make_participants()
acc_pipeline = pipeline.PipelineUmmaAsync.create(
barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread,
32 * len(self.softmax_warp_ids) * (2 if use_2cta_instrs else 1)),
cta_layout_vmnk=cl_vmnk,
defer_sync=True,
)
tmem_alloc_barrier = pipeline.NamedBarrier(
barrier_id=self.tmem_alloc_sync_bar_id,
num_threads=32 * len((self.mma_warp_id, *self.softmax_warp_ids)),
)
tmem = utils.TmemAllocator(
storage.tmem_holding_buf.ptr,
barrier_for_retrieve=tmem_alloc_barrier,
allocator_warp_id=self.softmax_warp_ids[0],
is_two_cta=use_2cta_instrs,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
)
scores_full_mbar = pipeline.NamedBarrier(
barrier_id=self.scores_full_bar_id,
num_threads=32 * (1 + len(self.softmax_warp_ids)),
)
softmax_done_mbar = pipeline.NamedBarrier(
barrier_id=self.softmax_done_bar_id,
num_threads=32 * (1 + len(self.softmax_warp_ids)),
)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sQ = smem.allocate_tensor(
element_type=self.a_dtype, layout=q_smem_staged.outer,
byte_alignment=128, swizzle=q_smem_staged.inner)
sK = smem.allocate_tensor(
element_type=self.b_dtype, layout=k_smem_staged.outer,
byte_alignment=128, swizzle=k_smem_staged.inner)
sC = smem.allocate_tensor(
element_type=self.o_dtype, layout=c_smem_staged.outer,
byte_alignment=128, swizzle=c_smem_staged.inner)
gQ = cute.local_tile(mQ, cute.slice_(self.qk_mma_tiler, (None, 0, None)), (None, None, None))
gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler, (0, None, None)), (None, None, None))
gC = cute.local_tile(mC, cute.slice_(self.qk_mma_tiler, (None, None, 0)), (None, None, None))
k_tile_cnt = cute.size(gQ, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgQ = qk_thr.partition_A(gQ)
tCgK = qk_thr.partition_B(gK)
tCgC = qk_thr.partition_C(gC)
a_cta_layout = cute.make_layout(cute.slice_(cl_vmnk, (0, 0, None, 0)).shape)
tAsQ, tAgQ = cpasync.tma_partition(
tma_q, 0, a_cta_layout,
cute.group_modes(sQ, 0, 3), cute.group_modes(tCgQ, 0, 3))
b_cta_layout = cute.make_layout(cute.slice_(cl_vmnk, (0, None, 0, 0)).shape)
tAsK, tAgK = cpasync.tma_partition(
tma_k, 0, b_cta_layout,
cute.group_modes(sK, 0, 3), cute.group_modes(tCgK, 0, 3))
tAgQ = tAgQ[(None, 0, None, 0)]
tAgK = tAgK[(None, 0, None, 0)]
tCrQ = qk_mma.make_fragment_A(sQ)
tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sK)
# ── TMEM tensor setup (following fmha.py) ──
# QK accumulator (scores) — 2D C-layout (fmha.py pattern)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler_mn)
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
# PV accumulator (output) — 2D C-layout
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_mma.partition_shape_C(self.mma_tiler_mn)
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
# P fragment for PV MMA (a_source=TMEM, A-layout)
tP = cute.make_tensor(tStS.iterator, p_tmem_staged.outer)
tOrP_base = pv_mma.make_fragment_A(tP)
tOrP = tOrP_base[(None, None, None, 0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout,
)
# Fake accumulators with stage dim (for epilogue_tma_store + TMEM allocation)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ══════════════════════════════════════════════════════════
# TMA LOAD WARP (warp 5)
# ══════════════════════════════════════════════════════════
if warp_idx == self.tma_warp_id:
ab_producer.reset()
peek_ab_empty_status = ab_producer.try_acquire()
for k_tile in cutlass.range(k_tile_cnt, unroll=1):
handle = ab_producer.acquire_and_advance(peek_ab_empty_status)
cute.copy(tma_q, tAgQ[(None, handle.count)], tAsQ[(None, handle.index)],
tma_bar_ptr=handle.barrier)
cute.copy(tma_k, tAgK[(None, handle.count)], tAsK[(None, handle.index)],
tma_bar_ptr=handle.barrier)
peek_ab_empty_status = cutlass.Boolean(1)
if handle.count + 1 < k_tile_cnt:
peek_ab_empty_status = ab_producer.try_acquire()
ab_producer.tail()
# ══════════════════════════════════════════════════════════
# MMA WARP (warp 4)
# ══════════════════════════════════════════════════════════
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_consumer.reset()
peek_ab_full_status = ab_consumer.try_wait()
# QK MMA: Q @ K^T → tmem_scores
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for k_tile in range(k_tile_cnt):
if is_leader_cta:
handle = ab_consumer.wait_and_advance(peek_ab_full_status)
num_kblocks = cute.size(tCrQ, mode=[2])
for kblk_idx in cutlass.range(num_kblocks, unroll_full=True):
kblk_crd = (None, None, kblk_idx, handle.index)
cute.gemm(qk_mma, tStS0, tCrQ[kblk_crd], tCrK[kblk_crd], tStS0)
handle.release()
peek_ab_full_status = cutlass.Boolean(1)
if handle.count + 1 < k_tile_cnt:
peek_ab_full_status = ab_consumer.try_wait()
cute.arch.fence_view_async_tmem_store()
scores_full_mbar.arrive()
softmax_done_mbar.wait()
# PV MMA: P @ V → tmem_output
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
tCrV_s = tCrV[(None, None, None, 0)]
num_pv_kblocks = cute.size(tOrP0, mode=[2])
for kblk_idx in cutlass.range(num_pv_kblocks, unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None, None, kblk_idx)],
tCrV_s[(None, None, kblk_idx)], tOtO0)
acc_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipeline.producer_acquire(acc_producer_state)
acc_pipeline.producer_commit(acc_producer_state)
acc_producer_state.advance()
acc_pipeline.producer_tail(acc_producer_state)
# ══════════════════════════════════════════════════════════
# SOFTMAX / EPILOGUE WARPS (0..3)
# ══════════════════════════════════════════════════════════
if warp_idx < self.mma_warp_id:
tmem.allocate(self.tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
# ── Identity softmax: C-layout → A-layout transform ──
# 1. LOAD pipeline (reads from QK C-layout)
tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)),
self.qk_acc_dtype,
)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
softmax_thread_idx = tidx % (32 * len(self.softmax_warp_ids))
thr_tmem_load = tiled_tmem_load.get_slice(softmax_thread_idx)
tTMEM_LOADtS = thr_tmem_load.partition_S(tStS0)
cS = cute.make_identity_tensor(
(self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_tmem_load.partition_D(tScS)
# 2. STORE pipeline (writes P in A-layout at tmem_p0_offset)
tilePlikeFP32 = self.qk_mma_tiler[1] // 32 * self.o_dtype.width
tStS_P_layout = cute.composition(
tStS.layout, cute.make_layout((128, tilePlikeFP32)))
tStS_P = cute.make_tensor(
tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
tmem_store_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)),
self.qk_acc_dtype,
)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
thr_tmem_store = tiled_tmem_store.get_slice(softmax_thread_idx)
tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P)
tScS_P_layout = cute.composition(
tScS.layout, cute.make_layout((128, tilePlikeFP32)))
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P)
# 3. Wait for scores
scores_full_mbar.wait()
# 4. Load scores from C-layout → registers
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
cute.arch.fence_view_async_tmem_load()
# 5. IDENTITY: convert F32 → Q dtype, no softmax math
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
tTMEM_STORErS_x4_e = cute.make_tensor(
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
tTMEM_LOADrS.layout,
)
s_vec = tTMEM_LOADrS.load()
tTMEM_STORErS_x4_e.store(s_vec.to(self.q_dtype))
# 6. Store into A-layout (P region)
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
cute.arch.fence_view_async_tmem_store()
# 7. Signal MMA warp
softmax_done_mbar.arrive()
# ── Epilogue: write output to GMEM ──
tCtO_base = cute.make_tensor(
tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, 32 * len(self.softmax_warp_ids))
c_pipeline = pipeline.PipelineTmaStore.create(
num_stages=self.num_c_stage, producer_group=c_producer_group)
epilogue_op = const_expr(lambda x: x)
acc_consumer_state = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
epi_tile, 0, epilogue_op, (0, 0, 0),
acc_consumer_state, acc_pipeline, c_pipeline)
c_pipeline.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test_stage_b_identity_softmax():
"""Test Stage B: (Q @ K^T) @ V with identity softmax layout transform"""
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device="cuda")
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device="cuda")
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device="cuda")
qf = q[:, :, 0].float()
kvf = kv[:, :, 0].float()
scores = qf @ kvf.T
ref = scores @ kvf
import cutlass.torch as cutlass_torch
mQ = cutlass_torch.from_dlpack(q).mark_layout_dynamic(
leading_dim=cutlass_torch.get_leading_dim(q))
mK = cutlass_torch.from_dlpack(kv).mark_layout_dynamic(
leading_dim=cutlass_torch.get_leading_dim(kv))
mC = cutlass_torch.from_dlpack(c).mark_layout_dynamic(
leading_dim=cutlass_torch.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBIdentitySoftmaxKernel(
mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
print("Compiling Stage B (identity softmax with layout transform)...", flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print("Running...", flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
output = c[:, :, 0].float()
cos = torch.nn.functional.cosine_similarity(
output.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
max_err = (output - ref).abs().max().item()
print("Stage B: (Q @ K^T) @ V with identity softmax layout transform")
print(" Shape: Q({},{}), K/V({},{}), output({},{})".format(m, k, n, k, m, n))
print(" Cosine: {:.6f}, Max error: {:.6f}".format(cos, max_err))
print(" {}".format("PASS" if cos >= 0.99 else "FAIL"))
return cos
if __name__ == "__main__":
test_stage_b_identity_softmax()

View File

@@ -0,0 +1,271 @@
"""
Stage B Minimal: Two MMAs chained, NO softmax, NO pipeline between them.
QK MMA: Q @ K^T → tmem_scores (SMEM source)
PV MMA: P @ V → tmem_output (TMEM source, P = tmem_scores)
This tests ONLY the PV MMA with a_source=TMEM.
If this crashes, the bug is in the TMEM A-operand path of PV MMA itself.
If this works with wrong output, the PV MMA works but the softmax pipeline is broken.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
import cuda.bindings.driver as cuda
class StageBMinimal:
def __init__(self, mma_tiler_mn):
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.mma_tiler_mn = mma_tiler_mn
self.cta_group = tcgen05.CtaGroup.ONE
self.use_2cta_instrs = False; self.use_tma_store = True
self.epilog_sync_bar_id = 1
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
self.tmem_alloc_sync_bar_id = 2
self.num_c_stage = 2
def _setup(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.c_layout = LayoutEnum.ROW_MAJOR
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype)
self.num_ab_stage = 1; self.num_acc_stage = 1
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
# TMEM offsets — same as fmha.py
self.tmem_s0_offset = 0
self.tmem_p0_offset = 32
self.tmem_o0_offset = 128
qk_acc_shape = qk_mma.get_slice(0).partition_shape_C(self.mma_tiler[:2])
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtS_fake, arch="sm_100")
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A((1,1), qk_mma.thr_id),
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B((1,1), qk_mma.thr_id),
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
a_smem_s, b_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, 1 * 2] # 1 AB stage
acc_bar: cute.struct.MemRange[cutlass.Int64, 1 * 2] # 1 acc stage
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
# AB pipeline (TMA load)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
# ACC pipeline (PV output → epilogue)
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
tCrV = pv_mma.make_fragment_B(sB) # V = same as K for our test
# TMEM tensors
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
# P fragment for PV MMA (TMEM A-operand)
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_thr.make_fragment_A(tP)
tOrP = tOrP_base[(None, None, None, 0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ── TMA WARP ──
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# ── MMA WARP: Two MMAs, NO softmax, NO pipeline between them ──
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, 1)
acc_pipe.producer_acquire(acc_prod_st)
# QK MMA: Q @ K^T → tmem_scores
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
# Fence TMEM writes from QK MMA
cute.arch.fence_view_async_tmem_store()
# PV MMA: P @ V → tmem_output (P is tmem_scores, same TMEM, different layout)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
tCrV_s = tCrV[(None, None, None, 0)]
nblk_pv = cute.size(tOrP0, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
# ── EPILOGUE WARPS ──
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, 1)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
# Reference: Q @ K^T @ V (identity softmax — just raw scores times V)
ref = qf @ kvf.T @ kvf
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBMinimal(mma_tiler_mn=(128, 128))
print('Compiling Stage B Minimal (two MMAs, no softmax)...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
max_err = (out - ref).abs().max().item()
print('Stage B Minimal: (Q @ K^T) @ V (no softmax, no pipeline between MMAs)')
print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL (may need softmax for correct P layout)'))
if __name__ == '__main__':
test()

View File

@@ -0,0 +1,281 @@
"""
Stage B Pipeline-Only: Two MMAs with PipelineUmmaAsync between them,
but IDENTITY transform (no tcgen05.ld/st — P is just tmem_scores).
Tests whether the PipelineUmmaAsync barrier ordering causes the crash.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
import cuda.bindings.driver as cuda
class StageBPipelineOnly:
def __init__(self, mma_tiler_mn):
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.mma_tiler_mn = mma_tiler_mn
self.cta_group = tcgen05.CtaGroup.ONE
self.use_2cta_instrs = False; self.use_tma_store = True
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2
self.num_c_stage = 2
def _setup(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.c_layout = LayoutEnum.ROW_MAJOR
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype)
self.num_ab_stage = 1; self.num_acc_stage = 1
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
self.tmem_s0_offset = 0
self.tmem_p0_offset = 32
self.tmem_o0_offset = 128
qk_acc_shape = qk_mma.get_slice(0).partition_shape_C(self.mma_tiler[:2])
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtS_fake, arch="sm_100")
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A((1,1), qk_mma.thr_id),
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B((1,1), qk_mma.thr_id),
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
a_smem_s, b_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, 2] # 1 stage
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] # 1 stage MMA↔softmax
acc_bar: cute.struct.MemRange[cutlass.Int64, 2] # 1 stage
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
# MMA↔softmax pipeline
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
tCrV = pv_mma.make_fragment_B(sB)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
# P fragment for PV MMA (TMEM A-operand)
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_thr.make_fragment_A(tP)
tOrP = tOrP_base[(None, None, None, 0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ── TMA WARP ──
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# ── MMA WARP ──
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
# 1. Acquire S0 (signal that we'll produce scores)
s0_handle = mma_si_prod.acquire_and_advance()
# 2. QK MMA
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, 1)
acc_pipe.producer_acquire(acc_prod_st)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
# 3. Fence + release scores to epilogue
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
# 4. Re-acquire (wait for softmax "done" — but softmax is identity, just releases)
s0_handle = mma_si_prod.acquire_and_advance()
# 5. PV MMA
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
tCrV_s = tCrV[(None, None, None, 0)]
nblk_pv = cute.size(tOrP0, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
# ── EPILOGUE WARPS: pipeline wait+release, NO tcgen05.ld/st ──
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
# Wait for scores, then immediately release (identity — no transform)
si_handle = mma_si_cons.wait_and_advance()
# NO tcgen05.ld, NO F32→BF16, NO tcgen05.st
si_handle.release()
# Epilogue
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, 1)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
ref = qf @ kvf.T @ kvf
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBPipelineOnly(mma_tiler_mn=(128, 128))
print('Compiling Stage B Pipeline-Only...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
has_nan = torch.isnan(out).any().item()
cos = torch.nn.functional.cosine_similarity(out.nan_to_num().flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
print(f' Has NaN: {has_nan}, Cosine (non-nan): {cos:.6f}')
print(' PASS (no crash) — pipeline works, NaN expected (no softmax transform)')
if __name__ == '__main__':
test()

420
tests/test_stage_b_v1.py Normal file
View File

@@ -0,0 +1,420 @@
"""
Stage B: Two MMAs + Identity Softmax via TMEM
Architecture:
MMA1: Q @ K^T -> tmem_scores (a_source=SMEM, accumulate=False)
Identity softmax: tcgen05.ld -> fill 1.0 -> tcgen05.st back to TMEM
MMA2: P @ V -> tmem_output (a_source=TMEM, accumulate=True)
Two barriers: scores_full (MMA->epi), softmax_done (epi->MMA)
Two TMEM regions: scores (offset 0), output (offset N)
With identity softmax (P = 1.0), output = ones(M,N) @ V = column sums of V tiled.
"""
import torch
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
import cuda.bindings.driver as cuda
class StageBKernel:
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
self.acc_dtype = Float32
self.use_2cta_instrs = use_2cta_instrs
self.mma_tiler_mn = mma_tiler_mn
self.mma_tiler = (*mma_tiler_mn, 1)
self.use_tma_store = use_tma_store
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4
self.tma_warp_id = 5
self.threads_per_cta = 32 * 6
self.epilog_sync_bar_id = 1
self.tmem_alloc_sync_bar_id = 2
self.tmem_dealloc_sync_bar_id = 3
self.scores_full_bar_id = 5
self.softmax_done_bar_id = 6
def _setup_attributes(self, tiled_mma1, tiled_mma2):
mma_inst_shape_k = cute.size(tiled_mma1.shape_mnk, mode=[2])
mma_inst_tile_k = 4
self.mma_tiler = (self.mma_tiler[0], self.mma_tiler[1],
mma_inst_shape_k * mma_inst_tile_k)
self.cta_tile_shape_mnk = (
self.mma_tiler[0] // cute.size(tiled_mma1.thr_id.shape),
self.mma_tiler[1],
self.mma_tiler[2],
)
self.cluster_layout_vmnk = cute.tiled_divide(
cute.make_layout((1, 1, 1)), (tiled_mma1.thr_id.shape,))
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.c_dtype)
self.num_ab_stage = 1
self.num_acc_stage = 1
self.num_c_stage = 2
self.a_smem_layout_staged = utils.sm100.make_smem_layout_a(
tiled_mma1, self.mma_tiler, self.a_dtype, self.num_ab_stage)
self.b_smem_layout_staged = utils.sm100.make_smem_layout_b(
tiled_mma1, self.mma_tiler, self.b_dtype, self.num_ab_stage)
self.c_smem_layout_staged = utils.sm100.make_smem_layout_epi(
self.c_dtype, self.c_layout, self.epi_tile, self.num_c_stage)
acc_shape = tiled_mma1.partition_shape_C(self.mma_tiler_mn)
tCtAcc_fake = tiled_mma1.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
self.num_tmem_cols_per_region = utils.get_num_tmem_alloc_cols(tCtAcc_fake, arch="sm_100")
total = self.num_tmem_cols_per_region * 2
self.total_tmem_cols = 256
if total > 256:
self.total_tmem_cols = 512
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.a_dtype, a_smem_layout) +
cute.size_in_bytes(self.b_dtype, b_smem_layout)
) * cute.size(tiled_mma1.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor,
stream: cuda.CUstream):
self.a_dtype = a.element_type
self.b_dtype = b.element_type
self.c_dtype = c.element_type
self.a_major_mode = LayoutEnum.from_tensor(a).mma_major_mode()
self.b_major_mode = LayoutEnum.from_tensor(b).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
tiled_mma1 = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, self.a_major_mode, self.b_major_mode,
self.acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.SMEM,
)
tiled_mma2 = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, self.a_major_mode, self.b_major_mode,
self.acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.TMEM,
)
self._setup_attributes(tiled_mma1, tiled_mma2)
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(
self.cluster_shape_mn, tiled_mma1.thr_id),
a, a_smem_layout, self.mma_tiler, tiled_mma1,
self.cluster_layout_vmnk.shape,
)
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(
self.cluster_shape_mn, tiled_mma1.thr_id),
b, b_smem_layout, self.mma_tiler, tiled_mma1,
self.cluster_layout_vmnk.shape,
)
epi_smem_layout = cute.select(self.c_smem_layout_staged, mode=[0, 1])
tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem_layout, self.epi_tile)
self._kernel(
tiled_mma1, tiled_mma2,
tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b,
tma_atom_c, tma_tensor_c, self.cluster_layout_vmnk,
self.a_smem_layout_staged, self.b_smem_layout_staged,
self.c_smem_layout_staged, self.epi_tile,
).launch(grid=(1, 1, 1), block=[self.threads_per_cta, 1, 1], stream=stream)
@cute.kernel
def _kernel(self, tiled_mma1, tiled_mma2,
tma_atom_a, mA_mkl, tma_atom_b, mB_nkl,
tma_atom_c, mC_mnl, cluster_layout_vmnk,
a_smem_layout_staged, b_smem_layout_staged,
c_smem_layout_staged, epi_tile):
warp_idx = cute.arch.warp_idx()
warp_idx = cute.arch.make_warp_uniform(warp_idx)
tidx, _, _ = cute.arch.thread_idx()
use_2cta_instrs = cute.size(tiled_mma1.thr_id.shape) == 2
is_leader_cta = True
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_atom_a)
cpasync.prefetch_descriptor(tma_atom_b)
cpasync.prefetch_descriptor(tma_atom_c)
@cute.struct
class SharedStorage:
ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc_mbar: cutlass.Int64
tmem_holding_buf: cutlass.Int32
smem = utils.SmemAllocator()
storage = smem.allocate(SharedStorage)
ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create(
barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes,
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
).make_participants()
acc_pipeline = pipeline.PipelineUmmaAsync.create(
barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta_instrs else 1)),
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
)
tmem_alloc_barrier = pipeline.NamedBarrier(
barrier_id=self.tmem_alloc_sync_bar_id,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)),
)
tmem = utils.TmemAllocator(
storage.tmem_holding_buf.ptr,
barrier_for_retrieve=tmem_alloc_barrier,
allocator_warp_id=self.epilogue_warp_id[0],
is_two_cta=use_2cta_instrs,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
)
# MMA <-> softmax handshake barriers
scores_full_barrier = pipeline.NamedBarrier(
barrier_id=self.scores_full_bar_id,
num_threads=32 * (1 + len(self.epilogue_warp_id)),
)
softmax_done_barrier = pipeline.NamedBarrier(
barrier_id=self.softmax_done_bar_id,
num_threads=32 * (len(self.epilogue_warp_id) + 1),
)
pipeline.pipeline_init_arrive(cluster_shape_mn=cluster_layout_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(
element_type=self.a_dtype, layout=a_smem_layout_staged.outer,
byte_alignment=128, swizzle=a_smem_layout_staged.inner)
sB = smem.allocate_tensor(
element_type=self.b_dtype, layout=b_smem_layout_staged.outer,
byte_alignment=128, swizzle=b_smem_layout_staged.inner)
sC = smem.allocate_tensor(
element_type=self.c_dtype, layout=c_smem_layout_staged.outer,
byte_alignment=128, swizzle=c_smem_layout_staged.inner)
gA_mkl = cute.local_tile(mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None))
gB_nkl = cute.local_tile(mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None))
gC_mnl = cute.local_tile(mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None))
k_tile_cnt = cute.size(gA_mkl, mode=[3])
thr_mma1 = tiled_mma1.get_slice(0)
tCgA = thr_mma1.partition_A(gA_mkl)
tCgB = thr_mma1.partition_B(gB_nkl)
tCgC = thr_mma1.partition_C(gC_mnl)
a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape)
tAsA, tAgA = cpasync.tma_partition(
tma_atom_a, 0, a_cta_layout,
cute.group_modes(sA, 0, 3), cute.group_modes(tCgA, 0, 3))
b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape)
tBsB, tBgB = cpasync.tma_partition(
tma_atom_b, 0, b_cta_layout,
cute.group_modes(sB, 0, 3), cute.group_modes(tCgB, 0, 3))
tAgA_slice = tAgA[(None, 0, None, 0)]
tBgB_slice = tBgB[(None, 0, None, 0)]
tCrA = tiled_mma1.make_fragment_A(sA)
tCrB = tiled_mma1.make_fragment_B(sB)
tCrA_mma2 = tiled_mma2.make_fragment_A(sA)
tCrB_mma2 = tiled_mma2.make_fragment_B(sB)
acc_shape = tiled_mma1.partition_shape_C(self.mma_tiler_mn)
tCtAcc_fake = tiled_mma1.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cluster_layout_vmnk)
# TMA LOAD WARP
if warp_idx == self.tma_warp_id:
ab_producer.reset()
peek_ab_empty_status = ab_producer.try_acquire()
for k_tile in cutlass.range(k_tile_cnt, unroll=1):
handle = ab_producer.acquire_and_advance(peek_ab_empty_status)
cute.copy(tma_atom_a, tAgA_slice[(None, handle.count)], tAsA[(None, handle.index)],
tma_bar_ptr=handle.barrier)
cute.copy(tma_atom_b, tBgB_slice[(None, handle.count)], tBsB[(None, handle.index)],
tma_bar_ptr=handle.barrier)
peek_ab_empty_status = cutlass.Boolean(1)
if handle.count + 1 < k_tile_cnt:
peek_ab_empty_status = ab_producer.try_acquire()
ab_producer.tail()
# MMA WARP
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
tCtScores_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
tCtScores = tCtScores_base[(None, None, None, 0)]
output_tmem_ptr = cute.recast_ptr(
tmem_ptr + self.num_tmem_cols_per_region, dtype=self.acc_dtype)
tCtOutput_base = cute.make_tensor(output_tmem_ptr, tCtAcc_fake.layout)
tCtOutput = tCtOutput_base[(None, None, None, 0)]
ab_consumer.reset()
peek_ab_full_status = cutlass.Boolean(1)
if is_leader_cta:
peek_ab_full_status = ab_consumer.try_wait()
# MMA1: Q @ K^T -> tmem_scores
tiled_mma1.set(tcgen05.Field.ACCUMULATE, False)
for k_tile in range(k_tile_cnt):
if is_leader_cta:
handle = ab_consumer.wait_and_advance(peek_ab_full_status)
num_kblocks = cute.size(tCrA, mode=[2])
for kblk_idx in cutlass.range(num_kblocks, unroll_full=True):
kblk_crd = (None, None, kblk_idx, handle.index)
cute.gemm(tiled_mma1, tCtScores, tCrA[kblk_crd], tCrB[kblk_crd], tCtScores)
handle.release()
peek_ab_full_status = cutlass.Boolean(1)
if handle.count + 1 < k_tile_cnt:
peek_ab_full_status = ab_consumer.try_wait()
# Signal scores ready
scores_full_barrier.arrive()
# Wait for softmax done
softmax_done_barrier.arrive_and_wait()
# MMA2: P @ V -> tmem_output
tiled_mma2.set(tcgen05.Field.ACCUMULATE, True)
num_kblocks_mma2 = cute.size(tCrB_mma2, mode=[2])
for kblk_idx in cutlass.range(num_kblocks_mma2, unroll_full=True):
kblk_crd = (None, None, kblk_idx, 0)
cute.gemm(tiled_mma2, tCtOutput, tCrA_mma2[kblk_crd], tCrB_mma2[kblk_crd], tCtOutput)
acc_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_acc_stage)
if is_leader_cta:
acc_pipeline.producer_acquire(acc_producer_state)
acc_pipeline.producer_commit(acc_producer_state)
acc_producer_state.advance()
acc_pipeline.producer_tail(acc_producer_state)
# EPILOGUE WARPS
if warp_idx < self.mma_warp_id:
tmem.allocate(self.total_tmem_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
tCtScores_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
tCtScores = tCtScores_base[(None, None, None, 0)]
output_tmem_ptr = cute.recast_ptr(
tmem_ptr + self.num_tmem_cols_per_region, dtype=self.acc_dtype)
tCtOutput_base = cute.make_tensor(output_tmem_ptr, tCtAcc_fake.layout)
# Wait for scores
scores_full_barrier.arrive_and_wait()
# Identity softmax: load, fill 1.0, store back
tiled_copy_t2r, tTR_tScores, tTR_rScores = utils.gemm.sm100.epilogue_tmem_copy_and_partition(
self, tidx, tCtScores, tCgC, epi_tile, self.use_2cta_instrs)
cute.copy(tiled_copy_t2r, tTR_tScores, tTR_rScores)
for idx in cutlass.range(cute.size(tTR_rScores), unroll_full=True):
tTR_rScores[idx] = self.acc_dtype(1.0)
copy_atom_r2t = cute.make_copy_atom(
tcgen05.St16x128bOp(tcgen05.Repetition.x32, tcgen05.Unpack.NONE),
self.acc_dtype,
)
tiled_copy_r2t = tcgen05.make_tmem_copy(copy_atom_r2t, tCtScores)
thr_copy_r2t = tiled_copy_r2t.get_slice(tidx)
tRT_tScores = thr_copy_r2t.partition_D(tCtScores)
tRT_rP = thr_copy_r2t.partition_S(tTR_rScores)
cute.copy(tiled_copy_r2t, tRT_rP, tRT_tScores)
cute.arch.fence_view_async_tmem_load()
# Signal softmax done
softmax_done_barrier.arrive()
# Store output
acc_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipeline = pipeline.PipelineTmaStore.create(
num_stages=self.num_c_stage, producer_group=c_producer_group)
mma_tile_coord_mnl = (0, 0, 0)
epilogue_op = const_expr(lambda x: x)
num_tiles_executed = 0
acc_consumer_state = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_atom_c, tCtOutput_base, sC, tCgC,
epi_tile, num_tiles_executed, epilogue_op,
mma_tile_coord_mnl, acc_consumer_state, acc_pipeline, c_pipeline)
c_pipeline.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test_stage_b():
device = torch.device("cuda")
torch.manual_seed(42)
m, n, k = 128, 128, 128
a = torch.randn(m, k, 1, dtype=torch.bfloat16, device="cuda")
b = torch.randn(n, k, 1, dtype=torch.bfloat16, device="cuda")
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device="cuda")
# Identity softmax: P = 1.0, output = ones(M,N) @ V(N,K)
v = b[:, :, 0].float() # (128, 128)
ref = torch.ones(m, n, dtype=torch.float32) @ v
import cutlass.torch as cutlass_torch
mA = cutlass_torch.from_dlpack(a).mark_layout_dynamic(
leading_dim=cutlass_torch.get_leading_dim(a))
mB = cutlass_torch.from_dlpack(b).mark_layout_dynamic(
leading_dim=cutlass_torch.get_leading_dim(b))
mC = cutlass_torch.from_dlpack(c).mark_layout_dynamic(
leading_dim=cutlass_torch.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBKernel(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
compiled = cute.compile(kernel, mA, mB, mC, stream)
compiled(mA, mB, mC, stream)
torch.cuda.synchronize()
output = c[:, :, 0].float()
cos = torch.nn.functional.cosine_similarity(
output.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
max_err = (output - ref).abs().max().item()
print("Stage B: Q @ K^T -> identity_softmax(P=1) @ V -> output")
print(" Cosine: {:.6f}, Max error: {:.6f}".format(cos, max_err))
print(" {}".format("PASS" if cos >= 0.99 else "FAIL"))
return cos
if __name__ == "__main__":
test_stage_b()

407
tests/test_stage_b_v2.py Normal file
View File

@@ -0,0 +1,407 @@
"""
Stage B v2: Two MMAs (Q@K^T then Scores@V) — no softmax, no identity P.
Tests:
- MMA1: Q @ K^T → tmem_scores (accumulate=False)
- MMA2: tmem_scores @ V → tmem_output (a_source=TMEM, accumulate=True)
- Two TMEM regions with pointer arithmetic
- The epilogue warps allocate TMEM and store output, but skip softmax for now
Reference: output = Q @ K^T @ V (with K=V for simplicity)
"""
import torch
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
import cuda.bindings.driver as cuda
class StageBKernel:
def __init__(self, mma_tiler_mn, use_2cta_instrs=False):
self.acc_dtype = Float32
self.use_2cta_instrs = use_2cta_instrs
self.mma_tiler_mn = mma_tiler_mn
self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
# Warp layout: 4 epilogue + 1 MMA + 1 TMA = 6 warps = 192 threads
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4
self.tma_warp_id = 5
self.threads_per_cta = 192
self.epilog_sync_bar_id = 1
self.tmem_alloc_sync_bar_id = 2
def _setup_attributes(self, tiled_mma1, tiled_mma2, a_dtype, b_dtype, c_dtype,
a_major, b_major, c_layout):
mma_inst_shape_k = cute.size(tiled_mma1.shape_mnk, mode=[2])
self.mma_tiler = (*self.mma_tiler_mn, mma_inst_shape_k * 4)
self.cta_tile_shape_mnk = (
self.mma_tiler[0] // cute.size(tiled_mma1.thr_id.shape),
self.mma_tiler[1],
self.mma_tiler[2],
)
self.cluster_layout_vmnk = cute.tiled_divide(
cute.make_layout((1, 1, 1)), (tiled_mma1.thr_id.shape,))
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, self.use_2cta_instrs, c_layout, c_dtype)
self.num_ab_stage = 1
self.num_acc_stage = 1
self.num_c_stage = 2
self.a_smem_layout_staged = utils.sm100.make_smem_layout_a(
tiled_mma1, self.mma_tiler, a_dtype, self.num_ab_stage)
self.b_smem_layout_staged = utils.sm100.make_smem_layout_b(
tiled_mma1, self.mma_tiler, b_dtype, self.num_ab_stage)
self.c_smem_layout_staged = utils.sm100.make_smem_layout_epi(
c_dtype, c_layout, self.epi_tile, self.num_c_stage)
# TMEM: two regions (scores + output), each partition_shape_C columns
acc_shape = tiled_mma1.partition_shape_C(self.mma_tiler_mn)
tCtAcc_fake = tiled_mma1.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
self.num_tmem_cols_per_region = utils.get_num_tmem_alloc_cols(tCtAcc_fake, arch="sm_100")
self.total_tmem_cols = max(self.num_tmem_cols_per_region * 2, 256)
a_smem = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(a_dtype, a_smem) +
cute.size_in_bytes(b_dtype, b_smem)
) * cute.size(tiled_mma1.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor,
stream: cuda.CUstream):
a_dtype = a.element_type
b_dtype = b.element_type
c_dtype = c.element_type
a_major = LayoutEnum.from_tensor(a).mma_major_mode()
b_major = LayoutEnum.from_tensor(b).mma_major_mode()
c_layout = LayoutEnum.from_tensor(c)
tiled_mma1 = utils.sm100.make_trivial_tiled_mma(
a_dtype, b_dtype, a_major, b_major,
self.acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.SMEM,
)
tiled_mma2 = utils.sm100.make_trivial_tiled_mma(
a_dtype, b_dtype, a_major, b_major,
self.acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.SMEM,
)
self._setup_attributes(tiled_mma1, tiled_mma2, a_dtype, b_dtype, c_dtype,
a_major, b_major, c_layout)
# These are needed by epilogue_tma_store which accesses them via self
self.a_dtype = a_dtype
self.b_dtype = b_dtype
self.c_dtype = c_dtype
self.a_major_mode = a_major
self.b_major_mode = b_major
self.c_layout = c_layout
a_smem = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
tma_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(
self.cluster_shape_mn, tiled_mma1.thr_id),
a, a_smem, self.mma_tiler, tiled_mma1,
self.cluster_layout_vmnk.shape,
)
tma_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(
self.cluster_shape_mn, tiled_mma1.thr_id),
b, b_smem, self.mma_tiler, tiled_mma1,
self.cluster_layout_vmnk.shape,
)
epi_smem = cute.select(self.c_smem_layout_staged, mode=[0, 1])
tma_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(
tiled_mma1, tiled_mma2,
tma_a, tma_tensor_a, tma_b, tma_tensor_b,
tma_c, tma_tensor_c, self.cluster_layout_vmnk,
self.a_smem_layout_staged, self.b_smem_layout_staged,
self.c_smem_layout_staged, self.epi_tile,
).launch(grid=(1, 1, 1), block=[self.threads_per_cta, 1, 1], stream=stream)
@cute.kernel
def _kernel(self, tiled_mma1, tiled_mma2,
tma_atom_a, mA_mkl, tma_atom_b, mB_nkl,
tma_atom_c, mC_mnl, cluster_layout_vmnk,
a_smem_layout_staged, b_smem_layout_staged,
c_smem_layout_staged, epi_tile):
warp_idx = cute.arch.warp_idx()
warp_idx = cute.arch.make_warp_uniform(warp_idx)
tidx, _, _ = cute.arch.thread_idx()
use_2cta_instrs = cute.size(tiled_mma1.thr_id.shape) == 2
is_leader_cta = True
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_atom_a)
cpasync.prefetch_descriptor(tma_atom_b)
cpasync.prefetch_descriptor(tma_atom_c)
@cute.struct
class SharedStorage:
ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc_mbar: cutlass.Int64
tmem_holding_buf: cutlass.Int32
smem = utils.SmemAllocator()
storage = smem.allocate(SharedStorage)
ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create(
barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes,
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
).make_participants()
acc_pipeline = pipeline.PipelineUmmaAsync.create(
barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta_instrs else 1)),
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
)
tmem_alloc_barrier = pipeline.NamedBarrier(
barrier_id=self.tmem_alloc_sync_bar_id,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)),
)
tmem = utils.TmemAllocator(
storage.tmem_holding_buf.ptr,
barrier_for_retrieve=tmem_alloc_barrier,
allocator_warp_id=self.epilogue_warp_id[0],
is_two_cta=use_2cta_instrs,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
)
pipeline.pipeline_init_arrive(cluster_shape_mn=cluster_layout_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(
element_type=BFloat16,
layout=a_smem_layout_staged.outer,
byte_alignment=128, swizzle=a_smem_layout_staged.inner)
sB = smem.allocate_tensor(
element_type=BFloat16,
layout=b_smem_layout_staged.outer,
byte_alignment=128, swizzle=b_smem_layout_staged.inner)
sC = smem.allocate_tensor(
element_type=BFloat16,
layout=c_smem_layout_staged.outer,
byte_alignment=128, swizzle=c_smem_layout_staged.inner)
gA_mkl = cute.local_tile(mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None))
gB_nkl = cute.local_tile(mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None))
gC_mnl = cute.local_tile(mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None))
k_tile_cnt = cute.size(gA_mkl, mode=[3])
thr_mma1 = tiled_mma1.get_slice(0)
tCgA = thr_mma1.partition_A(gA_mkl)
tCgB = thr_mma1.partition_B(gB_nkl)
tCgC = thr_mma1.partition_C(gC_mnl)
a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape)
tAsA, tAgA = cpasync.tma_partition(
tma_atom_a, 0, a_cta_layout,
cute.group_modes(sA, 0, 3), cute.group_modes(tCgA, 0, 3))
b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape)
tBsB, tBgB = cpasync.tma_partition(
tma_atom_b, 0, b_cta_layout,
cute.group_modes(sB, 0, 3), cute.group_modes(tCgB, 0, 3))
tAgA_slice = tAgA[(None, 0, None, 0)]
tBgB_slice = tBgB[(None, 0, None, 0)]
# MMA1 fragments
tCrA = tiled_mma1.make_fragment_A(sA)
tCrB = tiled_mma1.make_fragment_B(sB)
# MMA2 fragment for B (V) - same SMEM as K
tCrB_mma2 = tiled_mma2.make_fragment_B(sB)
# MMA2 fragment for A (a_source=SMEM, same as MMA1)
tCrA_mma2 = tiled_mma2.make_fragment_A(sA)
# TMEM accumulator layout (same for both MMAs)
acc_shape = tiled_mma1.partition_shape_C(self.mma_tiler_mn)
tCtAcc_fake = tiled_mma1.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cluster_layout_vmnk)
# ── TMA LOAD WARP (warp 5) ──
if warp_idx == self.tma_warp_id:
ab_producer.reset()
peek = ab_producer.try_acquire()
for k_tile in cutlass.range(k_tile_cnt, unroll=1):
handle = ab_producer.acquire_and_advance(peek)
cute.copy(tma_atom_a, tAgA_slice[(None, handle.count)], tAsA[(None, handle.index)],
tma_bar_ptr=handle.barrier)
cute.copy(tma_atom_b, tBgB_slice[(None, handle.count)], tBsB[(None, handle.index)],
tma_bar_ptr=handle.barrier)
peek = cutlass.Boolean(1)
if handle.count + 1 < k_tile_cnt:
peek = ab_producer.try_acquire()
ab_producer.tail()
# ── MMA WARP (warp 4) ──
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
# TMEM region 0: scores (Q @ K^T)
tCtScores_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
tCtScores = tCtScores_base[(None, None, None, 0)]
# MMA2 A fragment from SMEM (a_source=SMEM)
tCrA_mma2 = tiled_mma2.make_fragment_A(sA)
# TMEM region 1: output (Scores @ V)
output_ptr = cute.recast_ptr(
tmem_ptr + self.num_tmem_cols_per_region, dtype=self.acc_dtype)
tCtOutput_base = cute.make_tensor(output_ptr, tCtAcc_fake.layout)
tCtOutput = tCtOutput_base[(None, None, None, 0)]
ab_consumer.reset()
peek = cutlass.Boolean(1)
if is_leader_cta:
peek = ab_consumer.try_wait()
# ── MMA1: Q @ K^T → tmem_scores ──
tiled_mma1.set(tcgen05.Field.ACCUMULATE, False)
for k_tile in range(k_tile_cnt):
if is_leader_cta:
handle = ab_consumer.wait_and_advance(peek)
nblk = cute.size(tCrA, mode=[2])
for kblk in cutlass.range(nblk, unroll_full=True):
crd = (None, None, kblk, handle.index)
cute.gemm(tiled_mma1, tCtScores, tCrA[crd], tCrB[crd], tCtScores)
handle.release()
peek = cutlass.Boolean(1)
if handle.count + 1 < k_tile_cnt:
peek = ab_consumer.try_wait()
# MMA1 done, scores in tmem_scores
# Fence to ensure TMEM writes are visible
cute.arch.fence_view_async_tmem_store()
# ── MMA2: Scores @ V → tmem_output ──
# a_source=TMEM: the MMA instruction reads A from TMEM
# The A operand in TMEM is the scores from MMA1
# We need to set the TMEM pointer for A using tiled_mma2.set()
# The C operand's TMEM pointer tells MMA2 where to write the output
# The A operand's TMEM pointer needs to be set separately
tiled_mma2.set(tcgen05.Field.ACCUMULATE, True)
# Set the A TMEM pointer to the scores region
# When a_source=TMEM, the MMA reads A from the C operand's base
# So we need tCtOutput to have the same base as tCtScores for MMA2
# This means we can't use two separate TMEM regions for this approach
# Instead, let's try passing the TMEM A operand via the auxiliary list
# cute.gemm(tiled_mma2, D, [A_main, A_tmem], B, C)
# Or use tiled_mma2.set(Field.SFA, tCtScores.iterator) to set A ptr
nblk2 = cute.size(tCrB_mma2, mode=[2])
for kblk in cutlass.range(nblk2, unroll_full=True):
crd = (None, None, kblk, 0)
cute.gemm(tiled_mma2, tCtOutput, tCrA_mma2[crd], tCrB_mma2[crd], tCtOutput)
# Signal output ready
acc_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_acc_stage)
if is_leader_cta:
acc_pipeline.producer_acquire(acc_producer_state)
acc_pipeline.producer_commit(acc_producer_state)
acc_producer_state.advance()
acc_pipeline.producer_tail(acc_producer_state)
# ── EPILOGUE WARPS (0..3) ──
if warp_idx < self.mma_warp_id:
tmem.allocate(self.total_tmem_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
# TMEM region 1: output
output_ptr = cute.recast_ptr(
tmem_ptr + self.num_tmem_cols_per_region, dtype=self.acc_dtype)
tCtOutput_base = cute.make_tensor(output_ptr, tCtAcc_fake.layout)
# Wait for MMA2 to finish
acc_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipeline = pipeline.PipelineTmaStore.create(
num_stages=self.num_c_stage, producer_group=c_producer_group)
mma_tile_coord_mnl = (0, 0, 0)
epilogue_op = const_expr(lambda x: x)
num_tiles_executed = 0
acc_consumer_state = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_atom_c, tCtOutput_base, sC, tCgC,
epi_tile, num_tiles_executed, epilogue_op,
mma_tile_coord_mnl, acc_consumer_state, acc_pipeline, c_pipeline)
c_pipeline.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test_stage_b():
torch.manual_seed(42)
m, n, k = 128, 128, 128
a = torch.randn(m, k, 1, dtype=torch.bfloat16, device="cuda")
b = torch.randn(n, k, 1, dtype=torch.bfloat16, device="cuda")
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device="cuda")
# Reference: Q @ K^T @ V (no softmax, K=V=b)
q = a[:, :, 0].float()
kt = b[:, :, 0].float().T # K^T
v = b[:, :, 0].float()
ref = q @ kt @ v # (128, 128)
import cutlass.torch as cutlass_torch
mA = cutlass_torch.from_dlpack(a).mark_layout_dynamic(
leading_dim=cutlass_torch.get_leading_dim(a))
mB = cutlass_torch.from_dlpack(b).mark_layout_dynamic(
leading_dim=cutlass_torch.get_leading_dim(b))
mC = cutlass_torch.from_dlpack(c).mark_layout_dynamic(
leading_dim=cutlass_torch.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBKernel(mma_tiler_mn=(128, 128))
print("Compiling...", flush=True)
compiled = cute.compile(kernel, mA, mB, mC, stream)
print("Running...", flush=True)
compiled(mA, mB, mC, stream)
torch.cuda.synchronize()
output = c[:, :, 0].float()
cos = torch.nn.functional.cosine_similarity(
output.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
max_err = (output - ref).abs().max().item()
print("Stage B v2: Q @ K^T @ V (no softmax)")
print(" Cosine: {:.6f}, Max error: {:.6f}".format(cos, max_err))
print(" {}".format("PASS" if cos >= 0.99 else "FAIL"))
return cos
if __name__ == "__main__":
test_stage_b()

375
tests/test_stage_b_v3.py Normal file
View File

@@ -0,0 +1,375 @@
"""
Stage B: Two MMAs (Q@K^T then P@V) with a_source=TMEM for MMA2.
Architecture (following NVIDIA's fmha.py reference):
MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False)
MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True)
P fragment: constructed with P TMEM layout from make_smem_layout_A(pv_tiled_mma, ...)
P TMEM address: fragment.iterator + (acc_width/a_width) * tmem_p_offset
Reference: output = Q @ K^T @ V (no softmax, P = raw scores)
"""
import torch
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
import cuda.bindings.driver as cuda
class StageBKernel:
def __init__(self, mma_tiler_mn, use_2cta_instrs=False):
self.acc_dtype = Float32
self.use_2cta_instrs = use_2cta_instrs
self.mma_tiler_mn = mma_tiler_mn
self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4
self.tma_warp_id = 5
self.threads_per_cta = 192
self.epilog_sync_bar_id = 1
def _setup(self, qk_mma, pv_mma, a_dtype, b_dtype, c_dtype, a_major, b_major, c_layout):
self.a_dtype = a_dtype
self.b_dtype = b_dtype
self.c_dtype = c_dtype
self.c_layout = c_layout
self.use_2cta_instrs = False
# QK MMA tiler
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
# PV MMA tiler (same M,N but potentially different K)
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
# Use QK tiler for the overall tiler (A/B come from QK layout)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0], self.qk_mma_tiler[1], self.qk_mma_tiler[2])
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
# Epilogue tile from PV MMA (output is O, same shape as PV MMA's C)
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, False, c_layout, c_dtype)
self.num_ab_stage = 1
self.num_acc_stage = 1
self.num_c_stage = 2
# SMEM layouts for QK MMA (Q and K)
self.q_smem_layout = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, a_dtype, 1)
self.k_smem_layout = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, b_dtype, 1)
# SMEM layout for V (from PV MMA)
self.v_smem_layout = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, b_dtype, 1)
# TMEM layout for P (from PV MMA's A operand — this is the KEY)
self.p_tmem_layout = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, a_dtype, 1)
# C/Output SMEM layout
self.c_smem_layout = utils.sm100.make_smem_layout_epi(c_dtype, c_layout, self.epi_tile, 2)
# TMEM allocation: two regions
# Region 0: scores (Q@K^T accumulator, QK MMA's C layout)
acc_shape = qk_mma.partition_shape_C(self.mma_tiler_mn)
tCtAcc_fake = qk_mma.make_fragment_C(cute.append(acc_shape, 1))
self.num_tmem_cols_scores = utils.get_num_tmem_alloc_cols(tCtAcc_fake, arch="sm_100")
# Region 1: output (P@V accumulator, PV MMA's C layout)
acc_shape_pv = pv_mma.partition_shape_C(self.mma_tiler_mn)
tCtO_fake = pv_mma.make_fragment_C(cute.append(acc_shape_pv, 1))
self.num_tmem_cols_output = utils.get_num_tmem_alloc_cols(tCtO_fake, arch="sm_100")
# Total
self.total_tmem_cols = max(self.num_tmem_cols_scores + self.num_tmem_cols_output, 256)
# tmem_p_offset: P TMEM starts at 0 (same as scores, P replaces scores in TMEM)
self.tmem_p_offset = 0
# TMA load bytes
q_smem = cute.slice_(self.q_smem_layout, (None, None, None, 0))
k_smem = cute.slice_(self.k_smem_layout, (None, None, None, 0))
self.num_tma_bytes = (
cute.size_in_bytes(a_dtype, q_smem) +
cute.size_in_bytes(b_dtype, k_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor,
stream: cuda.CUstream):
a_dtype = a.element_type
b_dtype = b.element_type
c_dtype = c.element_type
a_major = LayoutEnum.from_tensor(a).mma_major_mode()
b_major = LayoutEnum.from_tensor(b).mma_major_mode()
c_layout = LayoutEnum.from_tensor(c)
# QK MMA: Q @ K^T, A from SMEM
qk_mma = utils.sm100.make_trivial_tiled_mma(
a_dtype, b_dtype, a_major, b_major,
self.acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.SMEM,
)
# PV MMA: P @ V, A from TMEM, P is K-major
# Following NVIDIA fmha.py: P (intermediate) is K-major and from TMEM
p_major = cute.nvgpu.OperandMajorMode.K
pv_mma = utils.sm100.make_trivial_tiled_mma(
a_dtype, b_dtype, p_major, b_major,
self.acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.TMEM,
)
self._setup(qk_mma, pv_mma, a_dtype, b_dtype, c_dtype, a_major, b_major, c_layout)
q_smem = cute.slice_(self.q_smem_layout, (None, None, None, 0))
k_smem = cute.slice_(self.k_smem_layout, (None, None, None, 0))
v_smem = cute.slice_(self.v_smem_layout, (None, None, None, 0))
tma_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, q_smem, self.qk_mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, k_smem, self.qk_mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_v, tma_tensor_v = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id),
b, v_smem, self.pv_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_layout, mode=[0, 1])
tma_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(
qk_mma, pv_mma,
tma_a, tma_tensor_a, tma_b, tma_tensor_b,
tma_v, tma_tensor_v,
tma_c, tma_tensor_c, self.cluster_layout_vmnk,
self.q_smem_layout, self.k_smem_layout,
self.v_smem_layout, self.p_tmem_layout,
self.c_smem_layout, self.epi_tile,
).launch(grid=(1,1,1), block=[self.threads_per_cta, 1, 1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma,
tma_q, mQ, tma_k, mK,
tma_v, mV,
tma_c, mC, cl_vmnk,
q_smem_layout, k_smem_layout,
v_smem_layout, p_tmem_layout,
c_smem_layout, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_q)
cpasync.prefetch_descriptor(tma_k)
cpasync.prefetch_descriptor(tma_v)
cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, 2]
dealloc_bar: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator()
storage = smem.allocate(SS)
ab_prod, ab_cons = pipeline.PipelineTmaUmma.create(
barrier_storage=storage.ab_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_bytes, cta_layout_vmnk=cl_vmnk,
defer_sync=True).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=storage.acc_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 128),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=160)
tmem = utils.TmemAllocator(storage.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=0, is_two_cta=False,
two_cta_tmem_dealloc_mbar_ptr=storage.dealloc_bar.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sQ = smem.allocate_tensor(element_type=BFloat16, layout=q_smem_layout.outer,
byte_alignment=128, swizzle=q_smem_layout.inner)
sK = smem.allocate_tensor(element_type=BFloat16, layout=k_smem_layout.outer,
byte_alignment=128, swizzle=k_smem_layout.inner)
sV = smem.allocate_tensor(element_type=BFloat16, layout=v_smem_layout.outer,
byte_alignment=128, swizzle=v_smem_layout.inner)
sC = smem.allocate_tensor(element_type=BFloat16, layout=c_smem_layout.outer,
byte_alignment=128, swizzle=c_smem_layout.inner)
# Q and K TMA partition
gQ = cute.local_tile(mQ, cute.slice_(self.qk_mma_tiler, (None,0,None)), (None,None,None))
gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.qk_mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gQ, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgQ = qk_thr.partition_A(gQ)
tCgK = qk_thr.partition_B(gK)
tCgC = qk_thr.partition_C(gC)
a_cta = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsQ, tAgQ = cpasync.tma_partition(tma_q, 0, a_cta, cute.group_modes(sQ,0,3), cute.group_modes(tCgQ,0,3))
b_cta = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tAsK, tAgK = cpasync.tma_partition(tma_k, 0, b_cta, cute.group_modes(sK,0,3), cute.group_modes(tCgK,0,3))
tAgQ = tAgQ[(None,0,None,0)]
tAgK = tAgK[(None,0,None,0)]
# QK MMA fragments
tCrQ = qk_mma.make_fragment_A(sQ)
tCrK = qk_mma.make_fragment_B(sK)
# PV MMA fragments
tCrV = pv_mma.make_fragment_B(sV)
# TMEM accumulator fake for QK (scores)
acc_shape_qk = qk_mma.partition_shape_C(self.mma_tiler_mn)
tCtS_fake = qk_mma.make_fragment_C(cute.append(acc_shape_qk, 1))
# TMEM accumulator fake for PV (output)
acc_shape_pv = pv_mma.partition_shape_C(self.mma_tiler_mn)
tCtO_fake = pv_mma.make_fragment_C(cute.append(acc_shape_pv, 1))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ── TMA LOAD WARP ──
if warp_idx == self.tma_warp_id:
ab_prod.reset()
peek = ab_prod.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_prod.acquire_and_advance(peek)
cute.copy(tma_q, tAgQ[(None,h.count)], tAsQ[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_k, tAgK[(None,h.count)], tAsK[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_prod.try_acquire()
ab_prod.tail()
# ── MMA WARP ──
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
# TMEM region 0: scores (QK MMA accumulator)
tCtS_base = cute.make_tensor(tmem_ptr, tCtS_fake.layout)
tCtS = tCtS_base[(None,None,None,0)]
# TMEM region 1: output (PV MMA accumulator)
out_ptr = cute.recast_ptr(
tmem_ptr + self.num_tmem_cols_scores, dtype=self.acc_dtype)
tCtO_base = cute.make_tensor(out_ptr, tCtO_fake.layout)
tCtO = tCtO_base[(None,None,None,0)]
# ── P fragment for PV MMA (a_source=TMEM) ──
# Following NVIDIA fmha.py pattern:
# 1. Create tP tensor with P TMEM layout (from make_smem_layout_A for PV MMA)
# 2. make_fragment_A(tP) creates the A fragment with the right layout
# 3. Shift the fragment's iterator to the P TMEM offset
tP = cute.make_tensor(tmem_ptr, p_tmem_layout.outer)
tOrP_base = pv_mma.make_fragment_A(tP)
tOrP = tOrP_base[(None,None,None,0)]
# Shift iterator to the P offset (scores region = offset 0, so no shift needed)
# In general: tOrP_shifted = cute.make_tensor(
# tOrP.iterator + (self.acc_dtype.width // self.a_dtype.width) * self.tmem_p_offset,
# tOrP.layout)
ab_cons.reset()
peek = cutlass.Boolean(1)
peek = ab_cons.try_wait()
# ── QK MMA: Q @ K^T → tmem_scores ──
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_cons.wait_and_advance(peek)
nblk = cute.size(tCrQ, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
crd = (None, None, kb, h.index)
cute.gemm(qk_mma, tCtS, tCrQ[crd], tCrK[crd], tCtS)
h.release()
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_cons.try_wait()
# Fence TMEM writes (scores are ready)
cute.arch.fence_view_async_tmem_store()
# ── PV MMA: P @ V → tmem_output ──
# A = P from TMEM (a_source=TMEM), B = V from SMEM
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
# V fragment: slice to remove stage dimension
tCrV_s = tCrV[(None, None, None, 0)]
nblk_pv = cute.size(tOrP, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
cute.gemm(pv_mma, tCtO, tOrP[(None, None, kb)], tCrV_s[(None, None, kb)], tCtO)
# Signal output ready
acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, 1)
acc_pipe.producer_acquire(acc_st)
acc_pipe.producer_commit(acc_st)
acc_st.advance()
acc_pipe.producer_tail(acc_st)
# ── EPILOGUE WARPS ──
if warp_idx < self.mma_warp_id:
tmem.allocate(self.total_tmem_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
out_ptr = cute.recast_ptr(
tmem_ptr + self.num_tmem_cols_scores, dtype=self.acc_dtype)
tCtO_base = cute.make_tensor(out_ptr, tCtO_fake.layout)
cons = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, 1)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 128)
c_pipe = pipeline.PipelineTmaStore.create(num_stages=2, producer_group=c_grp)
cons = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), cons, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test_stage_b():
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device="cuda")
k = torch.randn(n, k, 1, dtype=torch.bfloat16, device="cuda")
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device="cuda")
# Reference: Q @ K^T @ V (K=V for now)
qf = q[:,:,0].float()
kf = k[:,:,0].float()
ref = qf @ kf.T @ kf # Q @ K^T @ V (with V=K)
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBKernel(mma_tiler_mn=(128, 128))
print("Compiling...", flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print("Running...", flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(
out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
max_err = (out - ref).abs().max().item()
print("Stage B: Q @ K^T @ V (a_source=TMEM, P TMEM layout from PV MMA)")
print(" Cosine: {:.6f}, Max error: {:.6f}".format(cos, max_err))
print(" {}".format("PASS" if cos >= 0.99 else "FAIL"))
return cos
if __name__ == "__main__":
test_stage_b()

259
tests/test_stage_b_v4.py Normal file
View File

@@ -0,0 +1,259 @@
"""Stage B v4: Two MMAs with a_source=TMEM, minimal pipeline.
No softmax. P = raw scores. Reference: Q @ K^T @ V."""
import torch
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
import cuda.bindings.driver as cuda
class StageBKernel:
def __init__(self, mma_tiler_mn):
self.acc_dtype = Float32
self.mma_tiler_mn = mma_tiler_mn
self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.ONE
self.use_2cta_instrs = False
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4
self.tma_warp_id = 5
self.threads_per_cta = 192
self.epilog_sync_bar_id = 1
self.num_c_stage = 2
@cute.jit
def __call__(self, a, b, c, stream):
a_dtype = a.element_type
b_dtype = b.element_type
c_dtype = c.element_type
a_major = LayoutEnum.from_tensor(a).mma_major_mode()
b_major = LayoutEnum.from_tensor(b).mma_major_mode()
c_layout = LayoutEnum.from_tensor(c)
self.a_dtype = a_dtype
self.b_dtype = b_dtype
self.c_dtype = c_dtype
self.c_layout = c_layout
qk_mma = utils.sm100.make_trivial_tiled_mma(
a_dtype, b_dtype, a_major, b_major,
self.acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(
a_dtype, b_dtype, cute.nvgpu.OperandMajorMode.K, b_major,
self.acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.TMEM)
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0], self.qk_mma_tiler[1], self.qk_mma_tiler[2])
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, c_layout, c_dtype)
q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, a_dtype, 1)
k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, b_dtype, 1)
p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, a_dtype, 1)
c_smem_s = utils.sm100.make_smem_layout_epi(c_dtype, c_layout, self.epi_tile, 2)
acc_shape = qk_mma.partition_shape_C(self.mma_tiler_mn)
tCtS_fake = qk_mma.make_fragment_C(cute.append(acc_shape, 1))
self.num_tmem_cols_scores = utils.get_num_tmem_alloc_cols(tCtS_fake, arch="sm_100")
acc_shape_pv = pv_mma.partition_shape_C(self.mma_tiler_mn)
tCtO_fake = pv_mma.make_fragment_C(cute.append(acc_shape_pv, 1))
self.num_tmem_cols_output = utils.get_num_tmem_alloc_cols(tCtO_fake, arch="sm_100")
self.total_tmem_cols = max(self.num_tmem_cols_scores + self.num_tmem_cols_output, 256)
q_smem = cute.slice_(q_smem_s, (None, None, None, 0))
k_smem = cute.slice_(k_smem_s, (None, None, None, 0))
self.num_tma_bytes = (cute.size_in_bytes(a_dtype, q_smem) + cute.size_in_bytes(b_dtype, k_smem)) * cute.size(qk_mma.thr_id.shape)
tma_q, tma_tq = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, q_smem, self.qk_mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_k, tma_tk = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, k_smem, self.qk_mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(
cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_q, tma_tq, tma_k, tma_tk, tma_c, tma_tc,
self.cluster_layout_vmnk, q_smem_s, k_smem_s, p_tmem_s, c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[192,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_c, mC, cl_vmnk,
q_smem_s, k_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_q)
cpasync.prefetch_descriptor(tma_k)
cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, 2]
dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator()
st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 128),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=160)
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=0, is_two_cta=False,
two_cta_tmem_dealloc_mbar_ptr=st.dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sQ = smem.allocate_tensor(element_type=BFloat16, layout=q_smem_s.outer, byte_alignment=128, swizzle=q_smem_s.inner)
sK = smem.allocate_tensor(element_type=BFloat16, layout=k_smem_s.outer, byte_alignment=128, swizzle=k_smem_s.inner)
sC = smem.allocate_tensor(element_type=BFloat16, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gQ = cute.local_tile(mQ, cute.slice_(self.qk_mma_tiler, (None,0,None)), (None,None,None))
gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.qk_mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gQ, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgQ = qk_thr.partition_A(gQ)
tCgK = qk_thr.partition_B(gK)
tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsQ, tAgQ = cpasync.tma_partition(tma_q, 0, a_lay, cute.group_modes(sQ,0,3), cute.group_modes(tCgQ,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tAsK, tAgK = cpasync.tma_partition(tma_k, 0, b_lay, cute.group_modes(sK,0,3), cute.group_modes(tCgK,0,3))
tAgQ = tAgQ[(None,0,None,0)]
tAgK = tAgK[(None,0,None,0)]
tCrQ = qk_mma.make_fragment_A(sQ)
tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sK)
acc_shape = qk_mma.partition_shape_C(self.mma_tiler_mn)
tCtS_fake = qk_mma.make_fragment_C(cute.append(acc_shape, 1))
acc_shape_pv = pv_mma.partition_shape_C(self.mma_tiler_mn)
tCtO_fake = pv_mma.make_fragment_C(cute.append(acc_shape_pv, 1))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# TMA warp
if warp_idx == self.tma_warp_id:
ab_p.reset()
peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_q, tAgQ[(None,h.count)], tAsQ[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_k, tAgK[(None,h.count)], tAsK[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# MMA warp
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
tCtS_base = cute.make_tensor(tmem_ptr, tCtS_fake.layout)
tCtS = tCtS_base[(None,None,None,0)]
out_ptr = cute.recast_ptr(tmem_ptr + self.num_tmem_cols_scores, dtype=self.acc_dtype)
tCtO_base = cute.make_tensor(out_ptr, tCtO_fake.layout)
tCtO = tCtO_base[(None,None,None,0)]
# P fragment for PV MMA (fmha.py pattern)
tP = cute.make_tensor(tmem_ptr, p_tmem_s.outer)
tOrP = pv_mma.make_fragment_A(tP)[(None,None,None,0)]
ab_c.reset()
peek = ab_c.try_wait()
# QK MMA
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrQ, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tCtS, tCrQ[(None,None,kb,h.index)], tCrK[(None,None,kb,h.index)], tCtS)
h.release()
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
# fence removed for debug
# PV MMA
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
tCrV_s = tCrV[(None,None,None,0)]
nblk_pv = cute.size(tOrP, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
cute.gemm(pv_mma, tCtO, tOrP[(None,None,kb)], tCrV_s[(None,None,kb)], tCtO)
acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, 1)
acc_pipe.producer_acquire(acc_st)
acc_pipe.producer_commit(acc_st)
acc_st.advance()
acc_pipe.producer_tail(acc_st)
# Epilogue
if warp_idx < self.mma_warp_id:
tmem.allocate(self.total_tmem_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
out_ptr = cute.recast_ptr(tmem_ptr + self.num_tmem_cols_scores, dtype=self.acc_dtype)
tCtO_base = cute.make_tensor(out_ptr, tCtO_fake.layout)
cons = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, 1)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 128)
c_pipe = pipeline.PipelineTmaStore.create(num_stages=2, producer_group=c_grp)
cons = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), cons, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
ref = q[:,:,0].float() @ kv[:,:,0].float().T @ kv[:,:,0].float()
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBKernel((128, 128))
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
print('Cosine: {:.6f}'.format(cos))
print('{}'.format('PASS' if cos >= 0.99 else 'FAIL'))
if __name__ == '__main__':
test()

341
tests/test_stage_b_v5.py Normal file
View File

@@ -0,0 +1,341 @@
"""
Stage B: Two MMAs + Identity Softmax with Layout Transform
Following fmha.py's synchronization pattern:
- MMA↔softmax sync via PipelineUmmaAsync (mma_si pipeline)
- MMA produces scores (after QK), softmax consumes
- Softmax produces P, MMA re-acquires (before PV)
- Identity softmax: tcgen05.ld from C-layout → F32→BF16 → tcgen05.st to A-layout
Reference: output = (Q @ K^T) @ V
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
import cuda.bindings.driver as cuda
class StageBIdentitySoftmax:
def __init__(self, mma_tiler_mn):
self.acc_dtype = Float32
self.qk_acc_dtype = Float32
self.q_dtype = BFloat16
self.o_dtype = BFloat16
self.mma_tiler_mn = mma_tiler_mn
self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.ONE
self.use_2cta_instrs = False
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4
self.tma_warp_id = 5
self.threads_per_cta = 192
self.num_c_stage = 2
# TMEM offsets (fmha.py for 128x128)
self.tmem_s0_offset = 0
self.tmem_o0_offset = 256
self.tmem_p0_offset = 32
self.tmem_alloc_cols = 512
self.epilog_sync_bar_id = 1
def _setup(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = tuple(self.qk_mma_tiler)
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype)
self.num_ab_stage = 1; self.num_acc_stage = 1
self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.a_dtype, 1)
self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.b_dtype, 1)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
acc_shape = qk_mma.partition_shape_C(self.mma_tiler_mn)
tCtS_fake = qk_mma.make_fragment_C(cute.append(acc_shape, 1))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtS_fake, arch="sm_100")
q_smem = cute.slice_(self.q_smem_s, (None, None, None, 0))
k_smem = cute.slice_(self.k_smem_s, (None, None, None, 0))
self.num_tma_bytes = (cute.size_in_bytes(self.a_dtype, q_smem) + cute.size_in_bytes(self.b_dtype, k_smem)) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a, b, c, stream):
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, self.a_major, self.b_major, self.acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major, self.acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
q_smem = cute.slice_(self.q_smem_s, (None, None, None, 0))
k_smem = cute.slice_(self.k_smem_s, (None, None, None, 0))
tma_q, tma_tq = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, q_smem, self.qk_mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_k, tma_tk = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, k_smem, self.qk_mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_q, tma_tq, tma_k, tma_tk, tma_c, tma_tc,
self.cluster_layout_vmnk, self.q_smem_s, self.k_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[192,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_c, mC, cl_vmnk,
q_smem_s, k_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_q)
cpasync.prefetch_descriptor(tma_k)
cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, 2] # AB pipeline (1 stage)
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] # MMA↔softmax pipeline (1 stage)
acc_bar: cute.struct.MemRange[cutlass.Int64, 2] # ACC pipeline (1 stage)
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator()
st = smem.allocate(SS)
# AB pipeline (TMA load)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
# MMA↔softmax pipeline (following fmha.py's mma_s0 pattern)
# Producer = MMA warp (after QK: commit scores; before PV: re-acquire P)
# Consumer = softmax warps (wait for scores, process, release P)
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 128),
cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
# ACC pipeline (PV output → epilogue)
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 128),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
# TMEM allocator
tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=160)
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=0, is_two_cta=False,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sQ = smem.allocate_tensor(element_type=BFloat16, layout=q_smem_s.outer, byte_alignment=128, swizzle=q_smem_s.inner)
sK = smem.allocate_tensor(element_type=BFloat16, layout=k_smem_s.outer, byte_alignment=128, swizzle=k_smem_s.inner)
sC = smem.allocate_tensor(element_type=BFloat16, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gQ = cute.local_tile(mQ, cute.slice_(self.qk_mma_tiler, (None,0,None)), (None,None,None))
gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.qk_mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gQ, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsQ, tAgQ = cpasync.tma_partition(tma_q, 0, a_lay, cute.group_modes(sQ,0,3), cute.group_modes(tCgQ,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tAsK, tAgK = cpasync.tma_partition(tma_k, 0, b_lay, cute.group_modes(sK,0,3), cute.group_modes(tCgK,0,3))
tAgQ = tAgQ[(None,0,None,0)]; tAgK = tAgK[(None,0,None,0)]
tCrQ = qk_mma.make_fragment_A(sQ)
tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sK)
# TMEM tensors
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler_mn)
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_mma.partition_shape_C(self.mma_tiler_mn)
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
# P fragment for PV MMA (A-layout from TMEM)
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_mma.make_fragment_A(tP)
tOrP = tOrP_base[(None, None, None, 0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
# Fake acc for epilogue
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ── TMA WARP ──
if warp_idx == self.tma_warp_id:
ab_p.reset()
peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_q, tAgQ[(None,h.count)], tAsQ[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_k, tAgK[(None,h.count)], tAsK[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# ── MMA WARP ──
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
# 1. Acquire S0 buffer (before QK)
s0_handle = mma_si_prod.acquire_and_advance()
# 2. QK MMA: Q @ K^T → tmem_scores
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrQ, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,h.index)], tCrK[(None,None,kb,h.index)], tStS0)
h.release()
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
# 3. Fence TMEM, then release S0 for softmax
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
# 4. Re-acquire S0 (wait for softmax to finish and release P)
s0_handle = mma_si_prod.acquire_and_advance()
# 5. PV MMA: P @ V → tmem_output
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
tCrV_s = tCrV[(None, None, None, 0)]
nblk_pv = cute.size(tOrP0, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
# 6. Release output for epilogue
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, 1)
acc_pipe.producer_acquire(acc_prod_st)
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
# ── SOFTMAX / EPILOGUE WARPS ──
if warp_idx < self.mma_warp_id:
tmem.allocate(self.tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
# Identity softmax setup (following fmha.py)
tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
tilePlikeFP32 = self.qk_mma_tiler[1] // 32 * self.o_dtype.width
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
tmem_store_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtS_x4 = thr_store.partition_D(tStS_P)
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, tilePlikeFP32)))
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
tTMEM_STOREcS = thr_store.partition_S(tScS_P)
# Wait for scores
si_handle = mma_si_cons.wait_and_advance()
# Load from C-layout
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
cute.arch.fence_view_async_tmem_load()
# Identity: F32 → BF16
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
tTMEM_STORErS_x4_e = cute.make_tensor(
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
tTMEM_LOADrS.layout)
s_vec = tTMEM_LOADrS.load()
tTMEM_STORErS_x4_e.store(s_vec.to(self.q_dtype))
# Store to A-layout
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
cute.arch.fence_view_async_tmem_store()
# Release back to MMA warp
si_handle.release()
# Epilogue
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, 1)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 128)
c_pipe = pipeline.PipelineTmaStore.create(num_stages=2, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m,n,k = 128,128,128
q = torch.randn(m,k,1,dtype=torch.bfloat16,device='cuda')
kv = torch.randn(n,k,1,dtype=torch.bfloat16,device='cuda')
c = torch.zeros(m,n,1,dtype=torch.bfloat16,device='cuda')
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
ref = qf @ kvf.T @ kvf
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBIdentitySoftmax((128,128))
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
max_err = (out - ref).abs().max().item()
print('Stage B (identity softmax, PipelineUmmaAsync sync)')
print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
print(' {}'.format('PASS' if cos>=0.99 else 'FAIL'))
if __name__ == '__main__':
test()

342
tests/test_stage_b_v6.py Normal file
View File

@@ -0,0 +1,342 @@
"""
Stage B: Two MMAs + Identity Softmax with Layout Transform
Following fmha.py's softmax_step pattern exactly.
Architecture:
MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False)
Identity softmax: tcgen05.ld from C-layout → convert F32→BF16 → tcgen05.st to A-layout
MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True)
Reference: output = (Q @ K^T) @ V (no softmax, P = raw scores)
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
import cuda.bindings.driver as cuda
class StageBIdentitySoftmax:
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
self.num_c_stage = 2
# TMEM offsets (fmha.py for 128x128)
self.tmem_s0_offset = 0; self.tmem_o0_offset = 256; self.tmem_p0_offset = 32
self.tmem_alloc_cols = 512
def _setup(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1],
self.qk_mma_tiler[2],
)
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
self.num_ab_stage = 1; self.num_acc_stage = 1
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
# TMEM alloc cols — use the LARGER of QK and PV fragment sizes
qk_acc_shape = qk_mma.partition_shape_C(self.mma_tiler[:2])
qk_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(qk_fake, arch="sm_100")
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
a_smem_s, b_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] # MMA↔softmax pipeline (1 stage)
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
# AB pipeline
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
# MMA↔softmax pipeline
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
# ACC pipeline (PV output → epilogue) — KEY FIX: use warp count, NOT thread count
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
# TMEM allocator
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
tCrV = pv_mma.make_fragment_B(sB)
# TMEM tensors
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
# P fragment for PV MMA (A-layout from TMEM)
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_mma.make_fragment_A(tP)
tOrP = tOrP_base[(None, None, None, 0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ── TMA WARP ──
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# ── MMA WARP ──
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
# 1. Acquire S0 buffer
s0_handle = mma_si_prod.acquire_and_advance()
# 2. QK MMA: Q @ K^T → tmem_scores
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
# 3. Fence TMEM, release S0 for softmax
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
# 4. Re-acquire S0 (wait for softmax to finish)
s0_handle = mma_si_prod.acquire_and_advance()
# 5. PV MMA: P @ V → tmem_output
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
tCrV_s = tCrV[(None, None, None, 0)]
nblk_pv = cute.size(tOrP0, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
# 6. Release output for epilogue
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
# ── SOFTMAX / EPILOGUE WARPS ──
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
# ── Identity softmax: C-layout → A-layout transform ──
# 1. LOAD pipeline (reads scores from QK C-layout)
tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
# 2. STORE pipeline (writes P in A-layout at tmem_p0_offset)
tilePlikeFP32 = self.qk_mma_tiler[1] // 32 * self.o_dtype.width
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
tmem_store_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtS_x4 = thr_store.partition_D(tStS_P)
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, tilePlikeFP32)))
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
tTMEM_STOREcS = thr_store.partition_S(tScS_P)
# 3. Wait for scores (from MMA warp via mma_si pipeline)
si_handle = mma_si_cons.wait_and_advance()
# 4. Load from C-layout
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
cute.arch.fence_view_async_tmem_load()
# 5. IDENTITY: F32 → BF16
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
tTMEM_STORErS_x4_e = cute.make_tensor(
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
tTMEM_LOADrS.layout)
s_vec = tTMEM_LOADrS.load()
tTMEM_STORErS_x4_e.store(s_vec.to(self.q_dtype))
# 6. Store into A-layout
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
cute.arch.fence_view_async_tmem_store()
# 7. Release S0 back to MMA warp
si_handle.release()
# ── Epilogue ──
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
ref = qf @ kvf.T @ kvf
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
max_err = (out - ref).abs().max().item()
print('Stage B: (Q @ K^T) @ V with identity softmax layout transform')
print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))
if __name__ == '__main__':
test()

374
tests/test_stage_b_v7.py Normal file
View File

@@ -0,0 +1,374 @@
"""
Stage B v7: Two MMAs + Identity Softmax with COMPUTED TMEM offsets.
Key fixes over v6:
- TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols)
- P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly)
- tilePlikeFP32 computed from qk_mma_tiler and dtype widths
- tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments
- JIT-time diagnostic prints of all TMEM sizes
Architecture (matches fmha.py exactly):
MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False)
Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout
MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True)
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class StageBIdentitySoftmax:
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
self.num_c_stage = 2
def _setup(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1],
self.qk_mma_tiler[2],
)
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
self.num_ab_stage = 1; self.num_acc_stage = 1
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
# ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ──
qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
s_cols = find_tmem_tensor_col_offset(tStS)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
o_cols = find_tmem_tensor_col_offset(tOtO)
# tilePlikeFP32 for the store-side composition
self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
# ── TMEM LAYOUT (matching fmha.py) ──
# P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout)
# in the same TMEM region. The A-layout view starts partway through the S region.
# fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered)
# The P offset of 32 means the A-layout starts at column 32 within the S region.
# This is because the C-layout and A-layout partition TMEM differently per-thread;
# the first 32 C-layout columns are "dead space" in the A-layout mapping.
#
self.tmem_s0_offset = 0
self.tmem_p0_offset = 32 # Same as fmha.py
self.tmem_o0_offset = s_cols # 128
self.tmem_alloc_cols = s_cols + o_cols # 256
# Also compute via get_num_tmem_alloc_cols for the full allocation
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
print(f"[StageB] s_cols (QK accumulator) = {s_cols}")
print(f"[StageB] o_cols (PV accumulator) = {o_cols}")
print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}")
print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}")
print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}")
print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}")
print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}")
print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}")
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
# V shares the same SMEM as B (same data, different layout for PV MMA)
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout
# ── TMEM tensors with computed offsets (matching fmha.py pattern) ──
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
# P fragment: construct from p_tmem_s layout (matching fmha.py exactly)
# fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer)
# tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0]
# tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout)
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_thr.make_fragment_A(tP)
tOrP = tOrP_base[(None, None, None, 0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ── TMA WARP ──
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# ── MMA WARP ──
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
s0_handle = mma_si_prod.acquire_and_advance()
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
tCrV_s = tCrV[(None, None, None, 0)]
nblk_pv = cute.size(tOrP0, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
# ── SOFTMAX / EPILOGUE WARPS ──
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
# ── Identity softmax: C-layout → A-layout transform ──
# 1. LOAD pipeline (reads scores from QK C-layout)
tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
# 2. STORE pipeline (writes P in A-layout — same as fmha.py softmax_step)
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32)))
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
tmem_store_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtS_x4 = thr_store.partition_D(tStS_P)
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32)))
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
tTMEM_STOREcS = thr_store.partition_S(tScS_P)
# 3. Wait for scores
si_handle = mma_si_cons.wait_and_advance()
# 4. Load from C-layout
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
cute.arch.fence_view_async_tmem_load()
# 5. IDENTITY: F32 → BF16
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
tTMEM_STORErS_x4_e = cute.make_tensor(
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
tTMEM_LOADrS.layout)
s_vec = tTMEM_LOADrS.load()
tTMEM_STORErS_x4_e.store(s_vec.to(self.q_dtype))
# 6. Store into A-layout
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
cute.arch.fence_view_async_tmem_store()
# 7. Release back to MMA warp
si_handle.release()
# ── Epilogue ──
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
ref = qf @ kvf.T @ kvf
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
max_err = (out - ref).abs().max().item()
print('Stage B v7: (Q @ K^T) @ V with identity softmax (computed TMEM offsets)')
print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))
if __name__ == '__main__':
test()

View File

@@ -0,0 +1,263 @@
"""
TMEM Addressing Test: verify offset computation from layouts.
Allocates TMEM, computes offsets from QK accumulator and PV fragment sizes,
writes known values via tcgen05.st at each offset region, reads them back
via tcgen05.ld, and verifies correctness. No MMA, no softmax, no V load.
This validates that our offset arithmetic is correct before wiring it into Stage B.
"""
import torch
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
import cuda.bindings.driver as cuda
class TmemAddressingTest:
def __init__(self, mma_tiler_mn):
self.acc_dtype = Float32
self.qk_acc_dtype = Float32
self.q_dtype = BFloat16
self.o_dtype = BFloat16
self.mma_tiler_mn = mma_tiler_mn
self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4
self.tma_warp_id = 5
self.threads_per_cta = 192
self.tmem_alloc_sync_bar_id = 2
self.tmem_dealloc_sync_bar_id = 3
self.num_c_stage = 2
@cute.jit
def __call__(self, debug_buf: cute.Tensor, stream: cuda.CUstream):
self.a_dtype = BFloat16
self.b_dtype = BFloat16
self.a_major = cute.nvgpu.OperandMajorMode.K
self.b_major = cute.nvgpu.OperandMajorMode.K
self.c_layout = LayoutEnum.RowMajor
# Create the same MMAs as Stage B to get the same fragment layouts
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.TMEM)
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1],
self.qk_mma_tiler[2],
)
self.cluster_layout_vmnk = cute.tiled_divide(
cute.make_layout((1, 1, 1)), (qk_mma.thr_id.shape,))
# Compute TMEM fragment sizes from layouts
qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
qk_acc_cols = cute.size(tStS.layout, mode=[1])
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
pv_acc_cols = cute.size(tOtO.layout, mode=[1])
# P operand size: tilePlikeFP32 = qk_mma_tiler[1] * q_dtype.width // 32
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
# Compute offsets
tmem_s_offset = 0
tmem_p_offset = qk_acc_cols # P right after QK accumulator
tmem_o_offset = qk_acc_cols + tilePlikeFP32 # O right after P
# Total allocation
tmem_alloc_cols = tmem_o_offset + pv_acc_cols
# JIT-time prints — these appear during compilation
print(f"[TMEM] qk_acc_cols = {qk_acc_cols}")
print(f"[TMEM] tilePlikeFP32 = {tilePlikeFP32}")
print(f"[TMEM] pv_acc_cols = {pv_acc_cols}")
print(f"[TMEM] tmem_s_offset = {tmem_s_offset}")
print(f"[TMEM] tmem_p_offset = {tmem_p_offset}")
print(f"[TMEM] tmem_o_offset = {tmem_o_offset}")
print(f"[TMEM] tmem_alloc_cols = {tmem_alloc_cols}")
self._kernel(
qk_mma, pv_mma, tStS, tOtO, tmem_alloc_cols,
tmem_s_offset, tmem_p_offset, tmem_o_offset, tilePlikeFP32,
debug_buf, self.cluster_layout_vmnk
).launch(grid=(1, 1, 1), block=[self.threads_per_cta, 1, 1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tStS, tOtO, tmem_alloc_cols,
tmem_s_offset, tmem_p_offset, tmem_o_offset, tilePlikeFP32,
debug_buf, cl_vmnk):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
@cute.struct
class SS:
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator()
st = smem.allocate(SS)
tmem_bar = pipeline.NamedBarrier(
barrier_id=self.tmem_alloc_sync_bar_id,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(
st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.epilogue_warp_id[0],
is_two_cta=use_2cta,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
pipeline.pipeline_init_wait(cluster_shape_mvnk=cl_vmnk)
# ── MMA WARP: allocate TMEM, write test values ──
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
# Create TMEM tensors at computed offsets
# Scores region: write 1.0
tStS0 = cute.make_tensor(tStS.iterator + tmem_s_offset, tStS.layout)
# P region: write 2.0
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
tStS_P = cute.make_tensor(tStS.iterator + tmem_p_offset, tStS_P_layout)
# Output region: write 3.0
tOtO0 = cute.make_tensor(tOtO.iterator + tmem_o_offset, tOtO.layout)
# Use tcgen05.st to write known values into each region
# We'll use the store copy atom
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
# Store to scores region (value = 1.0)
tmem_store_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype)
tiled_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS0)
thr_store = tiled_store.get_slice(sfw_idx)
tTMEM_STOREtS = thr_store.partition_D(tStS0)
# We need a source tensor with the same shape
tTMEM_STOREcS = thr_store.partition_S(
cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])))
# Load from scores region (verify readback)
tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype)
tiled_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
thr_load = tiled_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS0)
# The MMA warp doesn't do the ld/st — the epilogue warps do.
# For this test, just signal that TMEM is ready, epilogue will verify.
# But actually, MMA warp CAN write to TMEM via cute.fill or direct MMA.
# The simplest test: MMA warp issues a QK MMA with accumulate=False (known result),
# then epilogue warps tcgen05.ld from the scores region and dump to debug_buf.
# For now: the MMA warp just signals and the epilogue does the verification.
# We'll write test values using tcgen05.st from epilogue warps (they have the copy atoms).
pass
# ── EPILOGUE WARPS: allocate TMEM, write test values, read back ──
if warp_idx < self.mma_warp_id:
tmem.allocate(tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
# ── Write 1.0 to scores region ──
tStS0 = cute.make_tensor(tStS.iterator + tmem_s_offset, tStS.layout)
tmem_store_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype)
tiled_store_s = tcgen05.make_tmem_copy(tmem_store_atom, tStS0)
thr_store_s = tiled_store_s.get_slice(sfw_idx)
tTMEM_STOREtS = thr_store_s.partition_D(tStS0)
tScS_s = qk_mma.get_slice(0).partition_C(
cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])))
tTMEM_STOREcS = thr_store_s.partition_S(tScS_s)
# Create register tensor filled with 1.0
tTMEM_STORErS = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.acc_dtype)
# Fill with 1.0
for i in cutlass.range(cute.size(tTMEM_STORErS), unroll_full=True):
tTMEM_STORErS.store(i, cutlass.Float32(1.0))
cute.copy(tiled_store_s, tTMEM_STORErS, tTMEM_STOREtS)
cute.arch.fence_view_async_tmem_store()
# ── Read back from scores region ──
tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype)
tiled_load_s = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
thr_load_s = tiled_load_s.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load_s.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_mma.get_slice(0).partition_C(cS)
tTMEM_LOADcS = thr_load_s.partition_D(tScS)
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.acc_dtype)
cute.copy(tiled_load_s, tTMEM_LOADtS, tTMEM_LOADrS)
cute.arch.fence_view_async_tmem_load()
# Dump one value per thread to debug_buf for verification
# debug_buf shape: (threads_per_cta,) Float32
# Only epilogue warps (0..3, 128 threads) write
if tidx < 128:
val = tTMEM_LOADrS.load()
# Store first element of the loaded vector
debug_buf[tidx] = val # type: ignore
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test_tmem_addressing():
device = torch.device("cuda")
debug_buf = torch.zeros(128, dtype=torch.float32, device=device)
import cutlass.torch as cutlass_torch
mD = cutlass_torch.from_dlpack(debug_buf).mark_layout_dynamic(
leading_dim=cutlass_torch.get_leading_dim(debug_buf))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = TmemAddressingTest(mma_tiler_mn=(128, 128))
print("Compiling TMEM addressing test...", flush=True)
compiled = cute.compile(kernel, mD, stream)
print("Running...", flush=True)
compiled(mD, stream)
torch.cuda.synchronize()
print("Debug buffer (first 16 values):", debug_buf[:16].tolist())
# All values should be 1.0 if addressing is correct
nonzero = (debug_buf[:128] != 0).sum().item()
ones = (debug_buf[:128] == 1.0).sum().item()
print(f"Non-zero: {nonzero}/128, Ones: {ones}/128")
if nonzero > 0:
print("PASS: TMEM addressing works — read back non-zero values")
else:
print("FAIL: All zeros — TMEM addressing broken")
if __name__ == "__main__":
test_tmem_addressing()