Commit Graph

2095 Commits

Author SHA1 Message Date
e0f60b9f05 Fix fused router: plain ints for mma_tiler + @cute.jit pattern
Root cause of previous crash: cutlass.Int32(128) wrapping of mma_inst_shape_mn
caused _unpack_x_tuple to fail in cute.size(tiled_mma.shape_mnk, mode=[2]).

The fused_swiglu kernel uses plain Python ints for mma_tiler_mnk and
mma_inst_shape_mn — NOT cutlass.Int32. Inside @cute.jit, CuTeDSL
auto-converts plain ints to MLIR values. The Int32 wrapping was unnecessary
and actually harmful.

Pattern: same as fused_swiglu.py __call__:
- @cute.jit compiled_fn takes CuTe tensors
- _setup_attributes called inside JIT (needs MLIR context)
- cute.compile at the end
2026-06-01 10:37:15 +00:00
057ae2101e CRITICAL FIX: Move tiled_mma creation and _setup_attributes OUTSIDE @cute.jit
The _setup_attributes() calls cute.size(tiled_mma.shape_mnk, mode=[2])
which requires host-side execution. Inside @cute.jit, tiled_mma.shape_mnk
returns MLIR values that can't be unpacked by cute.size().

This follows the fused_swiglu.py pattern exactly: setup on host side,
then pass everything to the kernel. Removed @cute.jit wrapper entirely
in favor of direct kernel launch (same as fused_swiglu).
2026-06-01 10:28:01 +00:00
71deeb91a9 Quantize BF16 gate weight to NVFP4 for fused router + add global scales to GEMM
CRITICAL: Checkpoint stores gate weights as BF16, not NVFP4.
Previous code fell back to BF16 cuBLAS because weight_scale was missing.
Now we quantize the BF16 gate weight to NVFP4 at load time using
quantize_to_nvfp4() and pass the result to the fused router kernel.

Also added global scale (gsa, gsb) parameters to the kernel:
- gsa (activation global scale) applied during activation quantization
- gsb (weight global scale) applied in epilogue before sqrt(softplus)
- The MMA output is (A * SFA) @ (B * SFB), missing gsa*gsb
- Epilogue now computes sqrt(softplus(logit * gsa * gsb))
  instead of sqrt(softplus(logit))
2026-06-01 10:14:29 +00:00
24fed15ed6 Fix: convert PyTorch tensors to CuTe tensors for fused router kernel
- Added cutlass_torch.from_dlpack() + mark_layout_dynamic() conversions
- quantize_activation_nvfp4 returns (fp4_packed, fp8_scales) which are
  converted to CuTe tensors before passing to the kernel
- Same pattern as gemm_runner.py
2026-06-01 10:02:40 +00:00
bab748763e Rewrite NVFP4 fused router kernel: MoE-style epilogue replaces broken SMEM merge
CRITICAL REWRITE of nvfp4_fused_router_kernel.py:
- REMOVED: Raw pointer SMEM merge (storage.merge_scores.data_ptr()[idx] = val)
  This crashed the CuTeDSL MLIR optimizer. Never use raw pointer indexing
  inside CuTeDSL kernels.
- REMOVED: Per-thread top-k accumulation + 128-thread SMEM merge. Too complex
  for MLIR, caused SIGABRT during compilation.
- ADDED: MoE-style epilogue (TMEM→regs→activation→SMEM→TMA store→GMEM)
  using paired copy atoms from CUTLASS (epilogue_tmem_copy_and_partition +
  epilogue_smem_copy_and_partition). Structurally identical to the proven
  FusedSwiGLUScaledGroupedGemmKernel epilogue. This SHOULD compile.
- Activation: sqrt(softplus(logit)) in registers (replaces SwiGLU)
- Output: FP32 activated scores written to GMEM via TMA store
- Top-k handled by activation_topk CUDA kernel in Python wrapper

Other changes:
- _activation_topk.py: Added run_fused_activation_topk_pre_activated() for
  top-k + renorm on pre-activated scores (PyTorch reference, not CUDA kernel)
- dense_router_dispatch_nvfp4_fused: Updated to match new kernel API
- Kernel now uses standard _compute_stages() for SMEM budget calculation
- Kernel now uses compute_epilogue_tile_shape() for epi_tile (not hardcoded)
- C pipeline (PipelineTmaStore) added for SMEM→GMEM overlap
2026-06-01 09:59:34 +00:00
31ebe4f2db Wire NVFP4 fused router kernel into e2e single-shot pipeline
- Add dense_router_dispatch_nvfp4_fused() in dense_router_decode.py:
  single-kernel NVFP4 blockscaled GEMM + fused router epilogue
- Router.load_nvfp4_fused_gate(): stores raw NVFP4 tensors for fused path
- Router._run_dense_impl() dispatch priority: fused > 2-kernel > BF16
- single_shot_inference.py: loads raw NVFP4 gate weights for fused kernel
  instead of building Nvfp4Linear (which was the 2-kernel path)
- Fix selection sort bug in nvfp4_fused_router_kernel.py: pass 0 was
  missing t_s/t_i/t_a temp save before swap, causing undefined vars
- Export dense_router_dispatch_nvfp4_fused from __init__.py
2026-06-01 09:47:48 +00:00
d9d3ca42b0 Fix: mma_tiler and cluster_layout must use MLIR values for cute.slice_
cute.slice_ on Python int tuples fails. All values in mma_tiler and
cluster_layout need to be cutlass.Int32() since they flow into
cute.slice_ and cute.local_tile inside @cute.kernel.

Now consistent: mma_inst_shape_mn, mma_tiler, cluster_layout_vmnk all
use MLIR-typed values created inside @cute.jit context.
2026-06-01 09:42:17 +00:00
ec79f30709 Fix: PersistentTileSchedulerParams cluster_shape must be Python ints not MLIR values 2026-06-01 09:38:08 +00:00
28d0cb4f41 Revert cutlass.Int32 wrapping — now inside @cute.jit, cute.round_up works
All CuTe DSL calls now happen inside @cute.jit context, so
cute.round_up and all layout operations have proper MLIR context.
No need for manual Int32 wrapping or Python math workarounds.
2026-06-01 09:35:03 +00:00
b536f99192 CRITICAL FIX: move ALL CuTe DSL setup inside @cute.jit context
The root cause of ALL the MLIR crashes: _create_tiled_mma and
_setup_attributes call cute.make_tiled_mma, sm100_utils.make_smem_layout_a,
etc. These are MLIR operations that REQUIRE an active MLIR context.

Previously they ran in run() OUTSIDE @cute.jit, so there was no MLIR
context — causing 'Expected an MLIR object (got None)' in _pack_shape.

Now ALL CuTe DSL calls happen INSIDE the @cute.jit function, matching
fused_swiglu's pattern where __call__ is called from JIT context.

Grid computation uses plain Python math (no MLIR needed).
2026-06-01 09:32:05 +00:00
65669596d4 Fix: all CuTe shape values must be cutlass.Int32 for MLIR compatibility
Python ints cause 'Expected an MLIR object (got None)' in _pack_shape.
This is the same fix we applied to the FMHA kernel mma_tiler.
All mma_inst_shape, mma_tiler, cluster_shape values now use cutlass.Int32().
2026-06-01 09:30:15 +00:00
df48dacc2b Fix: set mma_inst_shape_mn in __init__ before _create_tiled_mma call 2026-06-01 09:22:24 +00:00
28f78420c2 Fix: quantize_activation_nvfp4 API - correct signature and return values 2026-06-01 09:21:04 +00:00
7b3f6cb13c Fix fused router: use run_nvfp4_fused_router wrapper, correct CuTe tensor API
- kernel wrapper converts torch tensors to CuTe tensors with mark_layout_dynamic
- test uses the wrapper instead of calling kernel.run() directly
- mat_b/scale_b are now torch tensors (converted inside wrapper)
2026-06-01 09:19:48 +00:00
483e759d53 Fix: use tensor.mark_layout_dynamic() method (not cute.mark_layout_dynamic) 2026-06-01 09:16:33 +00:00
2412745b21 Test fix: slice NVFP4 logits to actual expert count (GEMM padding) 2026-06-01 09:15:06 +00:00
f33ca41c2a Fused router: replace nested if/else top-k with flat find-min-replace approach
The 5-level nested if/else for sorted insertion created O(2^5) MLIR
regions that crashed the CuTeDSL MLIR optimizer (SIGABRT).

New approach:
- Find-min-replace: scan 6 entries to find minimum (sequential, 1-level nesting)
- Replace the minimum if new score > min (flat conditionals by index)
- Selection sort the final 6 entries after SMEM merge (descending order)
- All conditionals are FLAT (at most 1 level of nesting)

This should avoid the MLIR optimizer explosion while producing
identical results.
2026-06-01 09:13:53 +00:00
4f4ae8febd Test: enumerate CuTeDSL math API to check available operations 2026-06-01 09:11:29 +00:00
9b86b2b414 Test: fix fused router test - proper NVFP4 quantization and CuTe tensor setup
- Use quantize_to_nvfp4 for weight quantization
- Use quantize_activation_nvfp4 with computed global_scale
- Get mat_b and scale_b from Nvfp4Linear after finalize_weights
- Compare against both BF16 reference and NVFP4 GEMM reference
2026-06-01 08:56:20 +00:00
b94f8d4ed8 Test: fused router kernel vs BF16 reference path
- BF16 GEMM + activation_topk as reference
- NVFP4 GEMM + fused router epilogue as test target
- Proper NVFP4 quantization and CuTe tensor creation
- Cosine similarity and topk_ids matching validation
2026-06-01 08:54:24 +00:00
2433700a69 Fused router kernel: rewrite epilogue with proper CuTeDSL constructs
- Replace Python lists with individual scalar variables (s0..s5, i0..i5, a0..a5)
- Replace min-heap sift-down with fully unrolled sorted insertion
  (descending order, no dynamic indexing, no while loops)
- Replace raw SMEM pointer arithmetic with CuTeDSL SMEM tensors
  (s_merge_s, s_merge_i, s_merge_a)
- Replace cute.where with cute.math.fmax
- Fix expert index calculation: col + tile_n_offset + subtile_idx * epi_n
- Top-6 accumulates across all N-tiles (for E=384 with 3 tiles of 128)
- Add iter_acc_early_release for overlapping accumulator
- Rewrite test to compare fused kernel vs 2-kernel reference path
- Remove stale memory doc
2026-06-01 08:49:39 +00:00
d01b4b02de Complete NVFP4 fused router kernel: full MMA + router epilogue
- TMA warp: persistent tile scheduling + TMA loads for A/B/SFA/SFB
- MMA warp: blockscaled GEMM (tcgen05.mma.block_scale) with S2T copy
  for SFA/SFB, proper pipeline synchronization (AB + Acc pipelines)
- Epilogue warps: TMEM->register via epilogue_tmem_copy_and_partition,
  sqrt(softplus) + e_bias + min-heap top-k + renormalization
- Python wrapper: run_nvfp4_fused_router() with proper CuTe tensor
  creation via from_dlpack + mark_layout_dynamic
- Single-kernel path, no BF16 fallback, no intermediate GMEM buffer
- Following exact patterns from MoE fused_swiglu.py kernel
2026-06-01 08:37:10 +00:00
25b9a5f32d Fix test: use from_dlpack for c_tensor 2026-06-01 07:55:29 +00:00
d2819fc39c Fix test: use as_tensor instead of make_tensor 2026-06-01 07:54:36 +00:00
5ea71ebd78 Add NVFP4 CuTeDSL compilation test (verify MmaMXF4NVF4Op compiles) 2026-06-01 07:53:43 +00:00
fa6dbd4aa2 WIP: Rewrite NVFP4 fused router in CuTeDSL with MmaMXF4NVF4Op (sf_vec_size=16)
Uses kind::mxf4nvf4 — native NVF4 with E2M1 microscales, 16-elem blocks.
NO MXFP4, NO CONVERSIONS.

Kernel incomplete — GEMM mainloop mirrors dense.py but epilogue is TODO.
Need to verify CuTeDSL compilation works with proper PipelineTmaUmma/
PipelineUmmaAsync abstractions before adding top-k epilogue.
2026-06-01 07:53:21 +00:00
4f706b55d7 Remove raw CUDA C++ fused router and DeepGEMM (MXFP4, wrong instruction)
DeepGEMM uses kind::mxf4.block_scale.block32 (MXFP4, UE8M0 scales, 32-elem blocks).
DSV4 uses NVF4: kind::mxf4nvf4 (E2M1 microscales, 16-elem blocks).
Using MXFP4 would require E2M1->UE8M0 conversion. NO CONVERSIONS.

Rewriting fused router in CuTeDSL with MmaMXF4NVF4Op (sf_vec_size=16).
2026-06-01 07:51:31 +00:00
424fe6bf2c Fix: use SM100_MMA_MXF8F6F4_SS (not MXF4) to match Nvfp4Linear path
MXF4 has .block32 hardcoded. MXF8F6F4 matches what CuTeDSL generates
via make_instr_desc_block_scaled. Both use E2M1 data + UE8M0 scales
at hardware level. NVFP4 E2M1 microscales are combined into UE8M0
during quantization — no MXFP4 conversion.
2026-06-01 07:44:53 +00:00
2e2caadf7d WIP: NVFP4 fused router kernel in raw CUDA C++ using DeepGEMM primitives
- nvfp4_fused_router_kernel.cuh: 1-CTA NVFP4 GEMM + sqrt(softplus) + top-k epilogue
- Uses DeepGEMM SM100 primitives: SM100_MMA_MXF4_SS, UTCCP, UMMA descriptors
- 4 warp roles: TMA load, UTCCP transpose, MMA issue, epilogue
- nvfp4_fused_router_cuda.py: Python wrapper (TMA descriptor setup TBD)

NOT YET COMPILING - needs:
1. SMEM layout fix (single extern __shared__)
2. TMA descriptor creation (cuTensorMapEncodeTiled)
3. Top-k cross-warp merge completion
4. FP4 tensor format alignment with DeepGEMM
2026-06-01 07:41:42 +00:00
e3ea609ddd Embed DeepGEMM source (not submodule) for SM100 raw CUDA GEMM primitives 2026-06-01 07:39:40 +00:00
dae83723a3 Add DeepGEMM as third-party dependency for SM100 raw CUDA GEMM primitives 2026-06-01 07:39:38 +00:00
ef4c0ad489 Fix BF16 router mma_tiler: use cutlass.Int32 for CuTe DSL compatibility 2026-06-01 07:29:30 +00:00
79be9cb8da Fix: hardcode mma_inst_shape_k=32 for NVFP4 (avoids MLIR unpack error in JIT) 2026-06-01 07:20:23 +00:00
c3a64ceed7 Fix: mma_tiler must use CuTe Ints for static layout construction 2026-06-01 07:19:15 +00:00
39b481e52b Ensure mma_tiler contains CuTe Ints for cute.slice_ compatibility 2026-06-01 07:16:47 +00:00
57cc20d5ad Fix SFA/SFB SMEM: blockscaled layouts are plain Layout (no .outer/.inner swizzle) 2026-06-01 07:14:45 +00:00
fcd7680583 Fix CuTe tensor creation: use from_dlpack + mark_layout_dynamic 2026-06-01 07:12:52 +00:00
3a8c6daeb3 Fix: cutlass_torch.make_tensor -> as_tensor 2026-06-01 07:11:43 +00:00
0553117af6 Simplify fused router test: compare fused vs 2-kernel NVFP4 path 2026-06-01 07:10:55 +00:00
44a0e59808 Fix fused router test: use quantize_weight_to_nvfp4 (correct function name) 2026-06-01 07:08:56 +00:00
940f37fb6c NVFP4 fused router kernel: full rewrite with proper block-scaled GEMM setup
Major fixes:
- Added tiled_mma_sfb creation (always CtaGroup.ONE, rounded N)
- Added mma_tiler_sfb, cta_tile_shape_mnk_sfb, cluster_layout_sfb_vmnk
- Use blockscaled_utils.make_smem_layout_sfa/sfb (with sf_vec_size)
  instead of sm100_utils (which doesn't support block-scaled SF layouts)
- Proper TMEM column accounting for SFA + SFB + accumulator
- Fixed make_blockscaled_trivial_tiled_mma argument order
  (a_dtype, b_dtype, a_major, b_major, sf_dtype, sf_vec_size, cta_group, mma_inst_shape)
- Fixed SFB TMA atom to use tiled_mma_sfb and cluster_layout_sfb_vmnk
- Fixed SFB partition_SFB to use tiled_mma_sfb.get_slice
- Fixed SFB global tile partitioning to use mma_tiler_sfb
- Fixed mainloop_s2t_copy_and_partition to use TMEM fragments
  (make_fragment_SFA/SFB) as the tSF parameter
- Updated run_nvfp4_fused_router wrapper to accept processed weight
  tensors from Nvfp4Linear._mat_b and _scale_b
- Updated test to properly build Nvfp4Linear and use processed weights

The old code was a rough sketch that never worked — it was missing
the entire tiled_mma_sfb infrastructure, used wrong SMEM layout
functions, and had broken TMA atom setup for scale factors.
v-nvfp4-fused-router-rewrite-20260601-0715
2026-06-01 07:08:12 +00:00
8658c8eca5 fix: add sf_vec_size parameter back to Nvfp4FusedRouterKernel __init__ 2026-06-01 07:01:02 +00:00
b97f30e289 fix: store sf_vec_size as instance variable 2026-06-01 06:56:33 +00:00
c225d195ea fix: remove tcgen05.mma.Kind (doesn't exist), use make_blockscaled_trivial_tiled_mma 2026-06-01 06:54:49 +00:00
e6803b450d rewrite: simplified fused router test (reference + import check) 2026-06-01 06:53:17 +00:00
262cec262d fix: add shape assertions to fused router test 2026-06-01 06:51:47 +00:00
db07d17a62 fix: set activation global scale in fused router test 2026-06-01 06:50:41 +00:00
2abb4a19d9 fix: set gs and ws2 fields for Nvfp4Linear in fused router test 2026-06-01 06:49:43 +00:00
61c04f7152 fix: Nvfp4Linear field is sf not scale_b 2026-06-01 06:48:39 +00:00
982f245c67 fix: use correct Nvfp4Linear field names (fp4, scale_b, gsb) 2026-06-01 06:47:15 +00:00