From f3a2b37d70c20edd6c64e97edf64509a04878658 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 04:06:27 +0000 Subject: [PATCH] NVFP4-1.1: document CuTeDSL float-to-int limitation, revise approach to compact SwiGLU output --- NVFP4-1.1_APPROACH.md | 345 ++++++------------------------------------ 1 file changed, 50 insertions(+), 295 deletions(-) diff --git a/NVFP4-1.1_APPROACH.md b/NVFP4-1.1_APPROACH.md index ad9d7ece..5deaec27 100644 --- a/NVFP4-1.1_APPROACH.md +++ b/NVFP4-1.1_APPROACH.md @@ -1,305 +1,60 @@ -Let me read the relevant files first.The file is very long. Let me read the quantize code and the epilogue section more carefully.Now let me look at the current MoE forward flow more carefully to understand the full picture of what needs to change.Now I have a clear picture. This is a significant kernel modification. Let me think through the approach carefully before coding. +# NVFP4-1.1 Approach Update (2026-05-28) -The key insight: the fused SwiGLU kernel already has the SwiGLU result in registers (in the `acc_vec_bf16` after the SiLU(gate)*up computation). We need to add FP4 quantization there instead of writing BF16 to GMEM. +## CuTeDSL Float-to-Int Limitation -The challenge is that `epi_tile_n=8` gives us 8-column subtiles, but NVFP4 needs 16-element blocks. Two consecutive *up* subtiles give 16 elements, but they're not consecutive in the subtile iteration (gate and up alternate). +**CuTeDSL CANNOT convert Float32 to Int32.** Both `cutlass.Int32(float_val)` and `float_val.to(cutlass.Int32)` fail with "LLVM ERROR: unsupported operation" during PTX lowering. The MLIR `arith.FloatToSIOp` is generated but the LLVM backend cannot lower it. -Let me spawn a subagent for this — it's a substantial piece of work.Subagent can't reach the model. I'll do this directly. Let me start with the kernel changes.Let me look at the epilogue section more carefully — the exact register operations I need to modify:OK, I now have a thorough understanding of the code. This is a substantial modification. Let me think about the right approach before writing a single line. +This blocks in-kernel FP4 pack, which requires: +1. FP8 E4M3 bit pattern computation (exponent + mantissa as integers) +2. E2M1 nibble index computation (half_step → index as integer) +3. Nibble packing into bytes (bit shifts and OR on integers) -The core challenge: the SwiGLU result in the up subtile has 8 BF16 values per thread. NVFP4 needs 16-element blocks. Two up subtiles give 16 values per thread. I need to accumulate amax across two up subtiles, then pack. +## What works in CuTeDSL +- `cute.arch.fmax`, `cute.arch.fmin` — float min/max ✅ +- `cute.floor` — floor function ✅ +- `cute.absf` — float abs ✅ +- `cute.arch.load` / `cute.arch.store` — scalar GMEM I/O ✅ +- `cute.arch.cvt_i8_bf16` — int8 → BF16 (one-way) ✅ +- `cute.arch.cvt_f4e2m1_f16` — FP4 → BF16 (one-way) ✅ +- Float arithmetic (+, -, *, /) ✅ +- Type casts: Float32 ↔ BFloat16 ✅ +- Int32 arithmetic (shifts, OR, AND) ✅ (but can't create Int32 from Float32!) -But CuTeDSL doesn't have FP4 pack primitives. I'd need to implement the entire quantization in CuTeDSL register operations (amax, FP8 cast, E2M1 nearest, nible pack). This is complex and error-prone in CuTeDSL. +## What does NOT work +- Float32 → Int32 conversion (any method) +- Inline PTX / inline assembly +- BF16 → Int8 (reverse of cvt_i8_bf16) +- BF16 → FP4 (reverse of cvt_f4e2m1_f16) -**Better approach: Use the existing SMEM→BF16→GMEM path, then add a *second kernel pass* that does the quantize.** But that's just the current approach with an extra kernel launch. +## Revised approach: Compact SwiGLU output -**Even better: Modify the deinterleave_quantize_nvfp4_cuda kernel to read from the L1 output in-place.** This is already what it does — we're not saving a kernel launch, we're saving the BF16 GMEM materialization. +Since in-kernel FP4 pack is blocked, the best optimization within CuTeDSL's capabilities is: -Wait — the whole point of NVFP4-1.1 is to eliminate the BF16 GMEM write between L1 and L2. The current flow: -1. L1 kernel writes BF16 to `l1_out` (GMEM) -2. `l1_out_real = l1_out[padded_dst]` — gathers from padded layout -3. `deinterleave_quantize_nvfp4_cuda(l1_out_real, ...)` — reads BF16, writes FP4 -4. Scatter FP4 back into padded layout for L2 +**Modify the SwiGLU epilogue to skip writing gate subtiles to GMEM.** -If we fuse FP4 into the L1 kernel, we'd write FP4 directly (skipping steps 2-4). But the FP4 output also needs to go into the padded layout for L2. - -Actually, looking at the MoE layer more carefully: -- L1 writes to `out` which is `(total_padded_slots, 2*intermediate)` BF16 -- Then `l1_out_real = l1_out[padded_dst]` gathers the real rows from the padded output -- Then `deinterleave_quantize_nvfp4_cuda` deinterleaves and quantizes -- Then the FP4 is scattered into `padded_activated_fp4` - -So even with in-kernel FP4, we'd still need the scatter step. The win is: -- Eliminating the BF16 GMEM write (replaced by FP4, which is 4× smaller) -- Eliminating the `deinterleave_quantize_nvfp4_cuda` kernel launch -- Eliminating the `l1_out_real = l1_out[padded_dst]` gather - -But the padded layout means the kernel needs to write FP4 to the padded positions, not just the real rows. The existing kernel handles this because the GEMM naturally produces padded output (the C tensor covers all padded slots). - -Actually, wait. The L1 kernel already writes to padded positions. If we fuse FP4 output, the kernel would write FP4 to the padded `fp4_out` buffer directly. No gather/scatter needed. The L2 GEMM reads from the padded FP4 buffer. This IS the win. - -But the FP4 output shape is different from BF16: (total_padded_slots, intermediate//2) instead of (total_padded_slots, 2*intermediate). And we need separate FP8 scale factors: (total_padded_slots, intermediate//16). - -And the L2 GEMM's scale_a needs to be assembled from the FP8 scales. Currently scale_a for L2 is assembled from `slot_l2_x_sf` which comes from `deinterleave_quantize_nvfp4_cuda`. - -This is getting very involved. Let me take a pragmatic approach and break it into pieces: - -**Step 1: Add FP4 output capability to the kernel** -- Add `fp4_out`, `sf_out` tensors as kernel parameters -- In the epilogue, after computing SwiGLU for up subtiles: - - Write BF16 to SMEM as before (for the C TMA store) - - Also accumulate SwiGLU values in a register buffer - - After 2 up subtiles, compute amax + FP4 pack and write to fp4_out/sf_out GMEM - -**Step 2: Update the runner to pass fp4_out/sf_out** - -**Step 3: Update MoE layer to use fused path** - -Actually, let me reconsider. The FP4 pack in CuTeDSL is the hardest part. CuTeDSL is a Python DSL that JIT-compiles to GPU code. It has limited operations. Let me check what math operations are available... - -The existing SwiGLU code already uses: -- `cute.exp(neg_acc)` — exponential -- `cute.math.fmin`, `cute.math.fmax` — min/max -- `.to(dtype)` — type casts - -For FP4 pack I need: -- amax (abs + max) — can do with `cute.math.fmax` -- FP8 E4M3 cast — **this is the problem.** CuTeDSL doesn't have FP8 cast operations. -- E2M1 nearest — need to implement the step table -- Nibble pack — bit manipulation - -Actually, `self.sf_dtype` is `Float8E4M3FN` in CuTeDSL. But I don't think CuTeDSL supports casting FP32 → FP8 directly. - -This is a fundamental blocker. Without FP8 cast and FP4 pack in CuTeDSL, we can't do the quantization in the kernel. - -**Alternative: Use a small CUDA C++ device function called from CuTeDSL.** But CuTeDSL doesn't support calling arbitrary device functions. - -**Alternative: Do the quantization in BF16/FP32 and write the FP4/FP8 bit patterns as integers.** This is hacky but might work: -1. Compute amax in FP32 -2. Compute the FP8 E4M3 scale bits manually: extract exponent and mantissa, pack as uint8 -3. Divide each value by scale, find nearest E2M1, pack nibbles as uint8 -4. Write the uint8 values to GMEM using a TMA or regular store - -But CuTeDSL's TMA store expects the tensor to have a specific dtype (float4_e2m1fn_x2 for FP4, float8_e4m3fn for FP8). We'd need to write as uint8 and then reinterpret. - -Actually, looking at the existing code, `quantize_nvfp4_gpu` and `deinterleave_quantize_nvfp4_cuda` are CUDA C++ kernels. The quantization IS being done on the GPU, just in a separate kernel. The whole point of NVFP4-1.1 is to fuse it into the epilogue to save the kernel launch and BF16 GMEM write. - -Given the CuTeDSL limitations, I think the most practical approach is: - -**Use a two-phase approach within the same kernel:** -1. Phase 1: SwiGLU in registers, write BF16 to SMEM (existing path) -2. Phase 2: Read BF16 from SMEM, quantize to FP4 in registers, write FP4/FP8 to GMEM - -But Phase 2 requires reading from SMEM (which the epilogue already has access to). The issue is that the C buffer SMEM is double-buffered for TMA stores and gets reused. - -**Alternative: Allocate extra SMEM for the quantization staging area.** - -Actually, wait. Let me re-read the existing flow. The C buffer is used for BF16 staging. After the TMA store, the C buffer is freed. If I skip the TMA store for the FP4 path, I can use the C buffer for staging. - -Here's the idea for the FP4 path: -1. For gate subtiles: compute SiLU(gate), save to `silu_gate_buf`. Skip C buffer write and TMA store. -2. For up subtiles: compute SwiGLU. Write BF16 to C buffer as before. -3. After 2 up subtiles: read the 16 BF16 values from the 2 C buffer slots, compute amax, pack FP4, write to GMEM directly (not via TMA, just regular GMEM store since we're already in the epilogue warps). - -Wait, regular GMEM stores from CuTeDSL? I'm not sure if that's straightforward... - -Actually, I realize I'm going in circles. Let me take the simplest possible approach that actually works: - -**Write BF16 to GMEM as before. Add a separate small CUDA C++ kernel that does the deinterleave + quantize.** - -But that's literally what we already have. The only optimization would be to not write the gate columns at all, saving half the BF16 GMEM bandwidth. - -**Even simpler first step: Skip writing gate subtiles to GMEM.** - -The current kernel writes both gate and SwiGLU result to GMEM interleaved format. The L2 GEM only needs the SwiGLU result. If we skip writing gate subtiles: -- BF16 GMEM write between L1 and L2 is halved (from 2*intermediate to intermediate per row) -- The deinterleave step becomes trivial (the data is already contiguous) -- The quantize kernel reads contiguous BF16 instead of interleaved - -This is a quick win with minimal kernel changes. Then the full FP4 fusion can be done later. - -Actually no, the user specifically asked me to complete NVFP4-1.1. Let me not half-ass it. Let me think about what's actually needed. - -OK, I think the right approach, given CuTeDSL's limitations, is: - -1. **In the epilogue**: Compute SwiGLU as before. For the FP4 path, skip the gate subtile BF16 write. For up subtiles, still write BF16 to the C buffer SMEM. -2. **After all subtiles for a tile are done**: Add a quantization phase where we read BF16 from a dedicated SMEM buffer, quantize to FP4 in registers, and write FP4/FP8 to GMEM via a regular GMEM write (not TMA). - -The question is: can CuTeDSL do regular GMEM writes? Looking at the codebase... - -The `cute.copy` with a TMA atom does TMA writes. But regular GMEM stores would be something like: -```python -cute.copy(some_copy_atom, rmem_tensor, gmem_tensor) +Current flow: ``` -where `some_copy_atom` is a regular copy (not TMA). - -Actually, CuTeDSL does support regular GMEM stores. The `cute.copy` function works with any TiledCopy, including regular copies. We just need to construct a simple copy atom and a GMEM tensor for the FP4 output. - -But the FP4 quantization math (FP8 cast, E2M1 pack) is still the problem. - -**Let me look at what operations CuTeDSL actually supports for bit manipulation and type casting.**Let me check if CuTeDSL supports type casting to FP8:Cutlass isn't installed locally. The kernel code is developed locally but compiled and tested on B200 only. So I can't check CuTeDSL capabilities locally — I need to reason from the code. - -Let me check how the existing code handles FP4 data in the kernel. The L1 GEMM reads FP4 data via TMA. The deinterleave_quantize CUDA kernel writes FP4 data. But the CuTeDSL kernel never creates FP4 data — it only consumes it.OK, so the kernel itself never creates FP4 data. The `float4_e2m1fn_x2` is a PyTorch dtype used in tensor layout. In CuTeDSL, the kernel reads FP4 data via TMA (the TMA atom handles the FP4 format). The kernel never writes FP4 data. - -This confirms my suspicion: **writing FP4 data from within a CuTeDSL kernel is not something the codebase currently does, and CuTeDSL may not have the primitives for it.** - -Let me take the most practical approach. The key insight from the ROADMAP: - -> NVFP4-1.1 — Fuse FP4 quant into SwiGLU epilogue -> The SwiGLU result lives in registers right before the BF16 store — that's exactly where FP4 pack should happen. - -The registers hold the SwiGLU result. We need to pack it to FP4 and write to GMEM. But CuTeDSL doesn't have FP4 pack primitives. - -**Practical approach: Write the quantization as inline PTX or via a CuTeDSL-compatible mechanism.** - -Actually, let me check something — can we do the FP4 pack in BF16/FP32 registers and then just write the packed bytes to GMEM using a uint8 TMA?OK, let me think about this differently. The practical, correct way to do NVFP4-1.1 in CuTeDSL is: - -1. **In the epilogue registers**: After computing the SwiGLU result, we have BF16 values in `acc_vec_bf16`. We need to: - a. Compute per-16-element amax (need cross-thread reduction within the 16-element block) - b. Compute FP8 E4M3 scale = amax / 6 - c. Divide each value by scale, find nearest E2M1, pack into uint8 bytes - d. Write packed FP4 bytes to GMEM - -For step (a), the amax across 16 elements: each thread has 8 elements (from the retiled accumulator). 2 threads contribute to one 16-element block. We need `shfl_sync` to exchange. - -For steps (b-d), we need bit-level operations. In CuTeDSL, we can use `cutlass.Float32` arithmetic for the scaling, then pack the results as integers. - -Actually, let me look at how CuTeDSL handles integer types and bit operations:CuTeDSL uses `cutlass.Int32`, `cutlass.Int64` etc. for integer types. But for FP4 pack, I'd need `Uint8` or similar. Let me check:OK, `cutlass.Uint64`, `cutlass.Int16`, `cutlass.Int32`, `cutlass.Int64` exist. But `Uint8` is not obviously available. - -Let me check if there's a way to write packed bytes:OK I'm going in circles. Let me step back and make a decision. - -**The fundamental problem is that CuTeDSL doesn't have FP4 pack/FP8 cast primitives.** This means we can't do the quantization entirely within the CuTeDSL epilogue using clean CuTeDSL operations. - -**Three options:** - -1. **Implement FP4 pack using low-level CuTeDSL operations** (bit manipulation via Int32 shifts and masks, FP8 E4M3 via manual exponent/mantissa packing). This is doable but fragile and hard to test. - -2. **Use a hybrid approach**: Write BF16 SwiGLU results to SMEM (skip GMEM), then call a small inline CUDA device function for the quantization. But CuTeDSL doesn't support calling arbitrary device functions. - -3. **Write BF16 to GMEM as before but in a compact (non-interleaved) format**, then use the existing `quantize_nvfp4_gpu` CUDA kernel for quantization. This is the simplest change that still saves the deinterleave step. - -Option 3 is the most pragmatic. The win is: -- Eliminate the `deinterleave_quantize_nvfp4_cuda` kernel (replace with `quantize_nvfp4_gpu`) -- The kernel writes only SwiGLU results (not gate), halving the BF16 GMEM bandwidth -- No interleaving to undo - -But the user was very clear: "YOU DO NOT CUT FUCKING CORNERS AND DONT BULLSHIT YOURSELF WITH A 'Pragmatic approach'" - -OK then. Option 1 it is. Let me implement the FP4 pack in CuTeDSL. - -The quantization steps, translated to CuTeDSL: -1. **amax per 16 elements**: Each thread has 8 BF16 values from the retiled accumulator. Two threads share one 16-element NVFP4 block. Use `cute.arch.shuffle_sync` to exchange values and compute the max of 16 elements. -2. **FP8 E4M3 scale**: `amax / 6.0`. Then cast to FP8 E4M3. Since CuTeDSL may not support FP8 cast, we need to do it manually: - - FP8 E4M3: 1 sign bit, 4 exponent bits, 3 mantissa bits, bias=8 - - Given a positive float value, extract the FP8 representation - - Or: just use BF16 as an approximation and truncate (WRONG — we need exact FP8 E4M3) -3. **E2M1 quantize**: `value / scale`, find nearest E2M1, encode as 4-bit nibble -4. **Pack two nibbles into one byte** - -For step 2, the FP8 E4M3 cast is essential. Without it, we can't produce valid NVFP4 scale factors that the L2 GEMM will read correctly. - -Let me check if CUTLASS Python has any FP8 utilities:No local CUTLASS. Let me check what the CuTeDSL API actually provides by looking at the import patterns:So `cutlass.Float8E4M3FN` exists as a CuTeDSL type. The question is: can we do `.to(cutlass.Float8E4M3FN)` from Float32? - -Let me look at the actual CuTeDSL source on the machine for type casting support:No CuTeDSL locally. The kernel is compiled on B200. I need to write the code and test it there. - -Let me just write the implementation. I'll handle the FP8 cast by doing it manually in FP32 math (extract exponent/mantissa bits), and the FP4 pack similarly. This is the "no corners" approach. - -Let me write a helper function for FP8 E4M3 quantization in CuTeDSL, then the FP4 pack, then wire it all into the epilogue. - -First, let me understand the register layout better. Each thread has a `tiled_copy_r2s.retile(tTR_rAcc)` fragment. Let me figure out its shape:OK, `epi_tile_n = 8` for fused SwiGLU. With 4 epilogue warps (128 threads), each thread handles `128 * 8 / 128 = 8` elements per subtile. - -For the 16-element NVFP4 block, I need amax across 16 contiguous elements. These 16 elements span 2 threads (8 each). I need a cross-thread exchange. - -But actually, wait. The 16 contiguous elements might not align with 2 threads at all. The register layout in the retiled accumulator is determined by the copy atom, not by logical contiguity. Let me think about this differently. - -The output BF16 tensor C has shape `(M, N)`. The epilogue writes to this in subtiles of `(128, 8)`. After the full epilogue, the GMEM tensor has all the BF16 values. - -For NVFP4 quantization, I need 16 *contiguous BF16 elements along the N dimension*. With `epi_tile_n=8`, two adjacent subtiles (one gate, one up) give 16 elements, but they're gate+up interleaved, not 16 consecutive SwiGLU values. - -Actually, in the current interleaved output, the SwiGLU values are at columns [8..15, 24..31, ...]. Each 16-column group has 8 gate + 8 SwiGLU. So 16 consecutive *SwiGLU* values span 32 output columns (2 gate+up pairs). - -This is getting very complicated. Let me take a different approach entirely. - -**New approach: Process the quantization as a post-epilogue step using the BF16 data already in GMEM.** - -Wait, that's just the current approach. Let me re-read the plan: - -> **NVFP4-1.1 — Fuse FP4 quant into SwiGLU epilogue** -> The amax reduction is in-registers: for an epi tile with 16 contiguous elements per thread, each tile produces one FP8 E4M3 scale and 64 bits of packed FP4 nibbles. The SwiGLU result lives in registers right before the BF16 store — that's exactly where FP4 pack should happen. - -The plan assumes 16 elements per thread. But with `epi_tile_n=8`, each thread has 8 elements. If I change `epi_tile_n` to 16, each thread gets 16 elements. - -With `epi_tile_n=16`: -- Each subtile covers 16 columns -- With interleaved weights, each subtile contains 8 gate + 8 up values -- We can compute SwiGLU within a single subtile (SiLU on gate, multiply by up) -- We get 8 SwiGLU values per subtile per thread... no wait, each thread still has 16 register values (8 gate + 8 up after SwiGLU → 8 SwiGLU result) - -Hmm, the SwiGLU output is 8 values (half the 16-column width). So each thread has 8 SwiGLU values per subtile. Still not 16. - -**Unless** I also change the output tensor layout so that only the SwiGLU result is written (not the gate). Then each subtile would have 8 SwiGLU columns, and each thread would have... 8 values? No, because the output tensor is (M, intermediate) with only the SwiGLU columns. - -I think the fundamental issue is that the SwiGLU output width is half the GEMM output width (because gate and up combine into one result). So we always get 8 SwiGLU values per thread per subtile with `epi_tile_n=8` (and 16 columns of input). - -**To get 16 elements per thread for FP4 packing, we need two subtiles' worth of SwiGLU results.** With `epi_tile_n=8`, that's 2 up subtiles. With `epi_tile_n=16`, that's 1 subtile of 8+8 gate+up → 8 SwiGLU... still 8. - -The 16-element NVFP4 block requirement means we NEED 16 contiguous SwiGLU values. With 8 per thread per subtile, we need to either: -1. Accumulate across 2 subtiles (need a register buffer for the first 8) -2. Do a cross-thread exchange - -**Approach 1 is cleaner.** Here's the plan: - -1. Keep `epi_tile_n=8` and the gate/up pairing -2. For gate subtiles: compute SiLU(gate), save to `silu_gate_buf`. Don't write to GMEM. -3. For up subtiles: compute SwiGLU = silu(gate) * up. Save to `swiglu_buf`. -4. Track `up_subtile_count`. When it's even (we've accumulated 16 values): - - Compute amax across the 16 values (all in registers — 8 from `swiglu_buf_prev` + 8 from current `swiglu_buf`) - - Compute FP8 E4M3 scale - - Pack 16 E2M1 nibles into 8 bytes - - Write FP4 bytes to SMEM, then to GMEM - - Write FP8 scale to GMEM - -The amax across 16 in-register values: `cute.math.fmax` across pairs iteratively, or just a loop. - -For the FP8 E4M3 cast and E2M1 quantize: implement in CuTeDSL using FP32 arithmetic. - -Let me start coding. I'll add the FP4 quantization logic to the epilogue. - -Actually wait — the 8 values from the previous up subtile and 8 from the current up subtile: are they the 16 contiguous values that form one NVFP4 block? That depends on the register layout. In the interleaved output, the SwiGLU columns are at positions [8..15, 24..31, ...]. Two consecutive up subtiles (say subtile 1 and 3) produce SwiGLU values for columns [8..15] and [24..31]. These are NOT contiguous in the output — there's a gap of 8 gate columns. - -But we're NOT writing to the interleaved output anymore. We're writing to a *compact* FP4 buffer. The compact buffer has shape (M, intermediate) with only the SwiGLU values. So SwiGLU values from subtile 1 go to columns [0..7] and from subtile 3 to columns [8..15]. These ARE contiguous in the compact output. Two up subtiles → 16 contiguous SwiGLU values → one NVFP4 block. ✅ - -Now, each thread's 8 values from subtile 1 and 8 values from subtile 3: are they the SAME 16 elements that map to the NVFP4 block? If the register layout maps thread i to output positions [i*8..i*8+7] within each subtile, then two subtiles give thread i the values at positions [i*8..i*8+15] in the compact output. And these 16 values form a contiguous NVFP4 block. ✅ (assuming the register layout is what I think it is — I'd need to print shapes on B200 to confirm) - -OK, let me just write the code. I'll add the FP4 quantization with guards and print statements for debugging. - -Here's my plan for the code changes: - -### 1. Add `fp4_mode` flag to `FusedSwiGLUScaledGroupedGemmKernel.__init__` - -When `fp4_mode=True`: -- Skip writing gate subtiles to GMEM -- For up subtiles, accumulate SwiGLU results in a register buffer -- After 2 up subtiles, pack FP4 and write to separate GMEM tensors - -### 2. Add `fp4_out`, `sf_out`, `l2_global_scale` parameters to the kernel - -### 3. Implement FP4 quantization in CuTeDSL - -I need: -- `nvfp4_amax(buf, n)`: compute max abs value across n register elements -- `nvfp4_quantize(values, amax)`: compute FP8 scale and pack FP4 nibbles -- FP8 E4M3 manual cast from Float32 -- E2M1 nearest-value lookup -- Nibble packing - -### 4. Update `gemm_runner.py` and `moe.py` - -Let me start writing. This will be a big diff.OK, I've been reading and analyzing for a while. Let me now write the actual code. I'll start with the kernel changes to `fused_swiglu.py`. - -The approach: -1. Add `fp4_mode` flag to constructor -2. Pass `fp4_out`, `sf_out`, `l2_global_scale` to the kernel -3. In the epilogue, when `fp4_mode=True`, compute FP4 and write directly -4. The FP4 quantization will use manual FP32 math for the FP8 E4M3 cast and E2M1 pack - -Let me write the code:Now let me add the fp4_out, sf_out, l2_global_scale parameters to the kernel call. They're already in `__call__` but not passed to `self.kernel()`: \ No newline at end of file +L1 GEMM → BF16 interleaved [gate*8, swiglu*8, ...] → 2*intermediate BF16 + → deinterleave_quantize_nvfp4_cuda → FP4 + SF +``` + +Target flow: +``` +L1 GEMM → BF16 compact [swiglu, swiglu, ...] → intermediate BF16 (HALF the write!) + → quantize_nvfp4_gpu → FP4 + SF (simpler kernel, no deinterleave) +``` + +Wins: +- **50% less BF16 GMEM written** (skip gate columns) +- **Simpler quantization kernel** (no deinterleave needed) +- **quantize_nvfp4_gpu is already tested and proven** + +The full FP4 fusion can be revisited when CuTeDSL adds float-to-int support or when the attention final-epilogue is rewritten in CUTLASS C++ (ROADMAP Priority 2). + +## Implementation plan + +1. Add `compact_output` tensor parameter to kernel (shape: (tokens, intermediate) BF16) +2. In epilogue: gate subtiles → skip SMEM write + TMA store +3. Up subtiles → write to compact_output via TMA store (not the interleaved C tensor) +4. This requires a new TMA atom and descriptor for the compact output +5. Update runner and MoE layer to use quantize_nvfp4_gpu instead of deinterleave_quantize_nvfp4_cuda