Root cause of previous crash: cutlass.Int32(128) wrapping of mma_inst_shape_mn
caused _unpack_x_tuple to fail in cute.size(tiled_mma.shape_mnk, mode=[2]).
The fused_swiglu kernel uses plain Python ints for mma_tiler_mnk and
mma_inst_shape_mn — NOT cutlass.Int32. Inside @cute.jit, CuTeDSL
auto-converts plain ints to MLIR values. The Int32 wrapping was unnecessary
and actually harmful.
Pattern: same as fused_swiglu.py __call__:
- @cute.jit compiled_fn takes CuTe tensors
- _setup_attributes called inside JIT (needs MLIR context)
- cute.compile at the end
The _setup_attributes() calls cute.size(tiled_mma.shape_mnk, mode=[2])
which requires host-side execution. Inside @cute.jit, tiled_mma.shape_mnk
returns MLIR values that can't be unpacked by cute.size().
This follows the fused_swiglu.py pattern exactly: setup on host side,
then pass everything to the kernel. Removed @cute.jit wrapper entirely
in favor of direct kernel launch (same as fused_swiglu).
CRITICAL: Checkpoint stores gate weights as BF16, not NVFP4.
Previous code fell back to BF16 cuBLAS because weight_scale was missing.
Now we quantize the BF16 gate weight to NVFP4 at load time using
quantize_to_nvfp4() and pass the result to the fused router kernel.
Also added global scale (gsa, gsb) parameters to the kernel:
- gsa (activation global scale) applied during activation quantization
- gsb (weight global scale) applied in epilogue before sqrt(softplus)
- The MMA output is (A * SFA) @ (B * SFB), missing gsa*gsb
- Epilogue now computes sqrt(softplus(logit * gsa * gsb))
instead of sqrt(softplus(logit))
- Added cutlass_torch.from_dlpack() + mark_layout_dynamic() conversions
- quantize_activation_nvfp4 returns (fp4_packed, fp8_scales) which are
converted to CuTe tensors before passing to the kernel
- Same pattern as gemm_runner.py
CRITICAL REWRITE of nvfp4_fused_router_kernel.py:
- REMOVED: Raw pointer SMEM merge (storage.merge_scores.data_ptr()[idx] = val)
This crashed the CuTeDSL MLIR optimizer. Never use raw pointer indexing
inside CuTeDSL kernels.
- REMOVED: Per-thread top-k accumulation + 128-thread SMEM merge. Too complex
for MLIR, caused SIGABRT during compilation.
- ADDED: MoE-style epilogue (TMEM→regs→activation→SMEM→TMA store→GMEM)
using paired copy atoms from CUTLASS (epilogue_tmem_copy_and_partition +
epilogue_smem_copy_and_partition). Structurally identical to the proven
FusedSwiGLUScaledGroupedGemmKernel epilogue. This SHOULD compile.
- Activation: sqrt(softplus(logit)) in registers (replaces SwiGLU)
- Output: FP32 activated scores written to GMEM via TMA store
- Top-k handled by activation_topk CUDA kernel in Python wrapper
Other changes:
- _activation_topk.py: Added run_fused_activation_topk_pre_activated() for
top-k + renorm on pre-activated scores (PyTorch reference, not CUDA kernel)
- dense_router_dispatch_nvfp4_fused: Updated to match new kernel API
- Kernel now uses standard _compute_stages() for SMEM budget calculation
- Kernel now uses compute_epilogue_tile_shape() for epi_tile (not hardcoded)
- C pipeline (PipelineTmaStore) added for SMEM→GMEM overlap
- Add dense_router_dispatch_nvfp4_fused() in dense_router_decode.py:
single-kernel NVFP4 blockscaled GEMM + fused router epilogue
- Router.load_nvfp4_fused_gate(): stores raw NVFP4 tensors for fused path
- Router._run_dense_impl() dispatch priority: fused > 2-kernel > BF16
- single_shot_inference.py: loads raw NVFP4 gate weights for fused kernel
instead of building Nvfp4Linear (which was the 2-kernel path)
- Fix selection sort bug in nvfp4_fused_router_kernel.py: pass 0 was
missing t_s/t_i/t_a temp save before swap, causing undefined vars
- Export dense_router_dispatch_nvfp4_fused from __init__.py
cute.slice_ on Python int tuples fails. All values in mma_tiler and
cluster_layout need to be cutlass.Int32() since they flow into
cute.slice_ and cute.local_tile inside @cute.kernel.
Now consistent: mma_inst_shape_mn, mma_tiler, cluster_layout_vmnk all
use MLIR-typed values created inside @cute.jit context.
All CuTe DSL calls now happen inside @cute.jit context, so
cute.round_up and all layout operations have proper MLIR context.
No need for manual Int32 wrapping or Python math workarounds.
The root cause of ALL the MLIR crashes: _create_tiled_mma and
_setup_attributes call cute.make_tiled_mma, sm100_utils.make_smem_layout_a,
etc. These are MLIR operations that REQUIRE an active MLIR context.
Previously they ran in run() OUTSIDE @cute.jit, so there was no MLIR
context — causing 'Expected an MLIR object (got None)' in _pack_shape.
Now ALL CuTe DSL calls happen INSIDE the @cute.jit function, matching
fused_swiglu's pattern where __call__ is called from JIT context.
Grid computation uses plain Python math (no MLIR needed).
Python ints cause 'Expected an MLIR object (got None)' in _pack_shape.
This is the same fix we applied to the FMHA kernel mma_tiler.
All mma_inst_shape, mma_tiler, cluster_shape values now use cutlass.Int32().
- kernel wrapper converts torch tensors to CuTe tensors with mark_layout_dynamic
- test uses the wrapper instead of calling kernel.run() directly
- mat_b/scale_b are now torch tensors (converted inside wrapper)
The 5-level nested if/else for sorted insertion created O(2^5) MLIR
regions that crashed the CuTeDSL MLIR optimizer (SIGABRT).
New approach:
- Find-min-replace: scan 6 entries to find minimum (sequential, 1-level nesting)
- Replace the minimum if new score > min (flat conditionals by index)
- Selection sort the final 6 entries after SMEM merge (descending order)
- All conditionals are FLAT (at most 1 level of nesting)
This should avoid the MLIR optimizer explosion while producing
identical results.
- Use quantize_to_nvfp4 for weight quantization
- Use quantize_activation_nvfp4 with computed global_scale
- Get mat_b and scale_b from Nvfp4Linear after finalize_weights
- Compare against both BF16 reference and NVFP4 GEMM reference
- Replace Python lists with individual scalar variables (s0..s5, i0..i5, a0..a5)
- Replace min-heap sift-down with fully unrolled sorted insertion
(descending order, no dynamic indexing, no while loops)
- Replace raw SMEM pointer arithmetic with CuTeDSL SMEM tensors
(s_merge_s, s_merge_i, s_merge_a)
- Replace cute.where with cute.math.fmax
- Fix expert index calculation: col + tile_n_offset + subtile_idx * epi_n
- Top-6 accumulates across all N-tiles (for E=384 with 3 tiles of 128)
- Add iter_acc_early_release for overlapping accumulator
- Rewrite test to compare fused kernel vs 2-kernel reference path
- Remove stale memory doc
Uses kind::mxf4nvf4 — native NVF4 with E2M1 microscales, 16-elem blocks.
NO MXFP4, NO CONVERSIONS.
Kernel incomplete — GEMM mainloop mirrors dense.py but epilogue is TODO.
Need to verify CuTeDSL compilation works with proper PipelineTmaUmma/
PipelineUmmaAsync abstractions before adding top-k epilogue.
MXF4 has .block32 hardcoded. MXF8F6F4 matches what CuTeDSL generates
via make_instr_desc_block_scaled. Both use E2M1 data + UE8M0 scales
at hardware level. NVFP4 E2M1 microscales are combined into UE8M0
during quantization — no MXFP4 conversion.
Major fixes:
- Added tiled_mma_sfb creation (always CtaGroup.ONE, rounded N)
- Added mma_tiler_sfb, cta_tile_shape_mnk_sfb, cluster_layout_sfb_vmnk
- Use blockscaled_utils.make_smem_layout_sfa/sfb (with sf_vec_size)
instead of sm100_utils (which doesn't support block-scaled SF layouts)
- Proper TMEM column accounting for SFA + SFB + accumulator
- Fixed make_blockscaled_trivial_tiled_mma argument order
(a_dtype, b_dtype, a_major, b_major, sf_dtype, sf_vec_size, cta_group, mma_inst_shape)
- Fixed SFB TMA atom to use tiled_mma_sfb and cluster_layout_sfb_vmnk
- Fixed SFB partition_SFB to use tiled_mma_sfb.get_slice
- Fixed SFB global tile partitioning to use mma_tiler_sfb
- Fixed mainloop_s2t_copy_and_partition to use TMEM fragments
(make_fragment_SFA/SFB) as the tSF parameter
- Updated run_nvfp4_fused_router wrapper to accept processed weight
tensors from Nvfp4Linear._mat_b and _scale_b
- Updated test to properly build Nvfp4Linear and use processed weights
The old code was a rough sketch that never worked — it was missing
the entire tiled_mma_sfb infrastructure, used wrong SMEM layout
functions, and had broken TMA atom setup for scale factors.