Root cause of previous crash: cutlass.Int32(128) wrapping of mma_inst_shape_mn
caused _unpack_x_tuple to fail in cute.size(tiled_mma.shape_mnk, mode=[2]).
The fused_swiglu kernel uses plain Python ints for mma_tiler_mnk and
mma_inst_shape_mn — NOT cutlass.Int32. Inside @cute.jit, CuTeDSL
auto-converts plain ints to MLIR values. The Int32 wrapping was unnecessary
and actually harmful.
Pattern: same as fused_swiglu.py __call__:
- @cute.jit compiled_fn takes CuTe tensors
- _setup_attributes called inside JIT (needs MLIR context)
- cute.compile at the end
The _setup_attributes() calls cute.size(tiled_mma.shape_mnk, mode=[2])
which requires host-side execution. Inside @cute.jit, tiled_mma.shape_mnk
returns MLIR values that can't be unpacked by cute.size().
This follows the fused_swiglu.py pattern exactly: setup on host side,
then pass everything to the kernel. Removed @cute.jit wrapper entirely
in favor of direct kernel launch (same as fused_swiglu).
CRITICAL: Checkpoint stores gate weights as BF16, not NVFP4.
Previous code fell back to BF16 cuBLAS because weight_scale was missing.
Now we quantize the BF16 gate weight to NVFP4 at load time using
quantize_to_nvfp4() and pass the result to the fused router kernel.
Also added global scale (gsa, gsb) parameters to the kernel:
- gsa (activation global scale) applied during activation quantization
- gsb (weight global scale) applied in epilogue before sqrt(softplus)
- The MMA output is (A * SFA) @ (B * SFB), missing gsa*gsb
- Epilogue now computes sqrt(softplus(logit * gsa * gsb))
instead of sqrt(softplus(logit))
- Added cutlass_torch.from_dlpack() + mark_layout_dynamic() conversions
- quantize_activation_nvfp4 returns (fp4_packed, fp8_scales) which are
converted to CuTe tensors before passing to the kernel
- Same pattern as gemm_runner.py
CRITICAL REWRITE of nvfp4_fused_router_kernel.py:
- REMOVED: Raw pointer SMEM merge (storage.merge_scores.data_ptr()[idx] = val)
This crashed the CuTeDSL MLIR optimizer. Never use raw pointer indexing
inside CuTeDSL kernels.
- REMOVED: Per-thread top-k accumulation + 128-thread SMEM merge. Too complex
for MLIR, caused SIGABRT during compilation.
- ADDED: MoE-style epilogue (TMEM→regs→activation→SMEM→TMA store→GMEM)
using paired copy atoms from CUTLASS (epilogue_tmem_copy_and_partition +
epilogue_smem_copy_and_partition). Structurally identical to the proven
FusedSwiGLUScaledGroupedGemmKernel epilogue. This SHOULD compile.
- Activation: sqrt(softplus(logit)) in registers (replaces SwiGLU)
- Output: FP32 activated scores written to GMEM via TMA store
- Top-k handled by activation_topk CUDA kernel in Python wrapper
Other changes:
- _activation_topk.py: Added run_fused_activation_topk_pre_activated() for
top-k + renorm on pre-activated scores (PyTorch reference, not CUDA kernel)
- dense_router_dispatch_nvfp4_fused: Updated to match new kernel API
- Kernel now uses standard _compute_stages() for SMEM budget calculation
- Kernel now uses compute_epilogue_tile_shape() for epi_tile (not hardcoded)
- C pipeline (PipelineTmaStore) added for SMEM→GMEM overlap
- Add dense_router_dispatch_nvfp4_fused() in dense_router_decode.py:
single-kernel NVFP4 blockscaled GEMM + fused router epilogue
- Router.load_nvfp4_fused_gate(): stores raw NVFP4 tensors for fused path
- Router._run_dense_impl() dispatch priority: fused > 2-kernel > BF16
- single_shot_inference.py: loads raw NVFP4 gate weights for fused kernel
instead of building Nvfp4Linear (which was the 2-kernel path)
- Fix selection sort bug in nvfp4_fused_router_kernel.py: pass 0 was
missing t_s/t_i/t_a temp save before swap, causing undefined vars
- Export dense_router_dispatch_nvfp4_fused from __init__.py
cute.slice_ on Python int tuples fails. All values in mma_tiler and
cluster_layout need to be cutlass.Int32() since they flow into
cute.slice_ and cute.local_tile inside @cute.kernel.
Now consistent: mma_inst_shape_mn, mma_tiler, cluster_layout_vmnk all
use MLIR-typed values created inside @cute.jit context.
All CuTe DSL calls now happen inside @cute.jit context, so
cute.round_up and all layout operations have proper MLIR context.
No need for manual Int32 wrapping or Python math workarounds.
The root cause of ALL the MLIR crashes: _create_tiled_mma and
_setup_attributes call cute.make_tiled_mma, sm100_utils.make_smem_layout_a,
etc. These are MLIR operations that REQUIRE an active MLIR context.
Previously they ran in run() OUTSIDE @cute.jit, so there was no MLIR
context — causing 'Expected an MLIR object (got None)' in _pack_shape.
Now ALL CuTe DSL calls happen INSIDE the @cute.jit function, matching
fused_swiglu's pattern where __call__ is called from JIT context.
Grid computation uses plain Python math (no MLIR needed).
Python ints cause 'Expected an MLIR object (got None)' in _pack_shape.
This is the same fix we applied to the FMHA kernel mma_tiler.
All mma_inst_shape, mma_tiler, cluster_shape values now use cutlass.Int32().
- kernel wrapper converts torch tensors to CuTe tensors with mark_layout_dynamic
- test uses the wrapper instead of calling kernel.run() directly
- mat_b/scale_b are now torch tensors (converted inside wrapper)
The 5-level nested if/else for sorted insertion created O(2^5) MLIR
regions that crashed the CuTeDSL MLIR optimizer (SIGABRT).
New approach:
- Find-min-replace: scan 6 entries to find minimum (sequential, 1-level nesting)
- Replace the minimum if new score > min (flat conditionals by index)
- Selection sort the final 6 entries after SMEM merge (descending order)
- All conditionals are FLAT (at most 1 level of nesting)
This should avoid the MLIR optimizer explosion while producing
identical results.
- Replace Python lists with individual scalar variables (s0..s5, i0..i5, a0..a5)
- Replace min-heap sift-down with fully unrolled sorted insertion
(descending order, no dynamic indexing, no while loops)
- Replace raw SMEM pointer arithmetic with CuTeDSL SMEM tensors
(s_merge_s, s_merge_i, s_merge_a)
- Replace cute.where with cute.math.fmax
- Fix expert index calculation: col + tile_n_offset + subtile_idx * epi_n
- Top-6 accumulates across all N-tiles (for E=384 with 3 tiles of 128)
- Add iter_acc_early_release for overlapping accumulator
- Rewrite test to compare fused kernel vs 2-kernel reference path
- Remove stale memory doc
Uses kind::mxf4nvf4 — native NVF4 with E2M1 microscales, 16-elem blocks.
NO MXFP4, NO CONVERSIONS.
Kernel incomplete — GEMM mainloop mirrors dense.py but epilogue is TODO.
Need to verify CuTeDSL compilation works with proper PipelineTmaUmma/
PipelineUmmaAsync abstractions before adding top-k epilogue.
MXF4 has .block32 hardcoded. MXF8F6F4 matches what CuTeDSL generates
via make_instr_desc_block_scaled. Both use E2M1 data + UE8M0 scales
at hardware level. NVFP4 E2M1 microscales are combined into UE8M0
during quantization — no MXFP4 conversion.
Major fixes:
- Added tiled_mma_sfb creation (always CtaGroup.ONE, rounded N)
- Added mma_tiler_sfb, cta_tile_shape_mnk_sfb, cluster_layout_sfb_vmnk
- Use blockscaled_utils.make_smem_layout_sfa/sfb (with sf_vec_size)
instead of sm100_utils (which doesn't support block-scaled SF layouts)
- Proper TMEM column accounting for SFA + SFB + accumulator
- Fixed make_blockscaled_trivial_tiled_mma argument order
(a_dtype, b_dtype, a_major, b_major, sf_dtype, sf_vec_size, cta_group, mma_inst_shape)
- Fixed SFB TMA atom to use tiled_mma_sfb and cluster_layout_sfb_vmnk
- Fixed SFB partition_SFB to use tiled_mma_sfb.get_slice
- Fixed SFB global tile partitioning to use mma_tiler_sfb
- Fixed mainloop_s2t_copy_and_partition to use TMEM fragments
(make_fragment_SFA/SFB) as the tSF parameter
- Updated run_nvfp4_fused_router wrapper to accept processed weight
tensors from Nvfp4Linear._mat_b and _scale_b
- Updated test to properly build Nvfp4Linear and use processed weights
The old code was a rough sketch that never worked — it was missing
the entire tiled_mma_sfb infrastructure, used wrong SMEM layout
functions, and had broken TMA atom setup for scale factors.
Single-kernel NVFP4 block-scaled GEMM + fused sqrt(softplus) + top-k
epilogue. Avoids materializing intermediate FP32 logits to GMEM.
Architecture: 6-warp specialization
- Warp 5 (TMA): Load A, B, SFA, SFB from GMEM → SMEM
- Warp 4 (MMA): NVFP4 block-scaled GEMM → FP32 accumulator in TMEM
- Warps 0-3 (EPI): TMEM → registers → sqrt(softplus) + bias + top-k → GMEM
Epilogue maintains per-thread min-heap across N subtiles, then
merges all 128 threads' heaps in SMEM for final top-k selection.
Mirrors Sm100BlockScaledPersistentDenseGemmKernel structure for
TMA/MMA/SFA/SFB handling, with custom top-k epilogue replacing
the standard SwiGLU + TMA store path.
NOTE: This is WIP — needs compilation testing on B200. Several
API details (tiled_mma_sfb, cluster_layout_sfb_vmnk) need to
be passed through the kernel parameters properly.
The dense router now uses NVFP4 GEMM via Nvfp4Linear for the gate
projection when NVFP4 scales are available in the checkpoint. This
replaces the BF16 cuBLAS GEMM with Blackwell SM100 tensor-core
NVFP4 acceleration.
Changes:
- dsv4/layers/router.py: add gate_lin (Nvfp4Linear) alongside W_gate
fallback. New load_nvfp4_gate() method.
- dsv4/kernels/router/dense_router_decode.py: add
dense_router_dispatch_nvfp4() using Nvfp4Linear + activation_topk
- dsv4/kernels/router/__init__.py: export new function
- single_shot_inference.py: load NVFP4 gate weights when available,
fall back to BF16 when not
The compressor_reduce.cu kernel now adds position_bias to BOTH kv and
gate values, matching the PyTorch reference. Previously the kernel only
added it to gate, and a Python workaround loop was adding it to both
before the kernel call (then passing None to the kernel).
Changes:
- compressor_reduce.cu: add position_bias to kv_val in pass 2 (CSA + HCA)
- single_shot_inference.py: remove Python position_bias loop, pass
self.ape directly to csa/hca_compress_production
- production_compress.py: already supports position_bias passthrough
- New compressor_reduce.cu: CSA/HCA token-level softmax + weighted sum + kv_norm
One block per compressed entry, 128 threads, FP32 accumulation
CSA: overlapping Ca/Cb streams (2m tokens per block)
HCA: single stream (m tokens per block)
Includes apply_kv_norm kernel (unweighted RMSNorm + weight)
- New production_compress.py: Python wrapper for CUDA kernels
- single_shot_inference.py: Compressor/Indexer now use production Nvfp4Linear
for kv_proj, gate_proj, q_b_proj, weights_proj projections
Then CUDA reduce kernel for softmax + weighted sum
No more PyTorch reference nvfp4_linear_ref in compressor/indexer path
Critical bug: checkpoint weights are (N_packed, K_packed) N-major format,
but make_b_k_major expects (E, K_packed, N_packed) input. Without the
permute, the K and N dimensions are swapped, producing garbage output
with wrong dimensions (e.g., q_a output was 3584 instead of 1536).
Also fix scale assembly: checkpoint scales are (N, K_sf) which should
use assemble_raw_scales_2d3d_3d_side (no transpose), not
assemble_scales_3d_side (which incorrectly transposes K_sf↔N).
The CuTeDSL kernel expects float4_e2m1fn_x2 dtype for FP4 weight tensors,
but checkpoint weights from safetensors are loaded as uint8. The uint8 and
float4_e2m1fn_x2 have the same byte representation, so .view() is safe.
Fixed in:
- Nvfp4Linear.finalize_weights()
- Nvfp4SharedExpert.finalize_weights()
- Nvfp4MoE._ensure_stacked() (both stacked and legacy paths)
Critical bug fix: weight_scale_2 (the second-level NVFP4 scale) was
being dropped entirely in the production pipeline. The dequant formula
is lut[w] * weight_scale * weight_scale_2, so weight_scale_2 must be
folded into the GEMM's global_scale_b parameter.
Fixes in:
- Nvfp4Linear: ws2 field, folded in finalize_weights()
- Nvfp4MoE: l1_ws2/l2_ws2 lists, folded in _ensure_stacked()
- Nvfp4SharedExpert: l1_ws2/l2_ws2 lists, folded in finalize_weights()
- single_shot_inference.py: pass weight_scale_2 through all loading paths
- Also fix missing o_a_prod key fallback in attention output