Commit Graph

90 Commits

Author SHA1 Message Date
b7c7e9fb50 refactor: clean up slot_token handling in cutlass_grouped_nvfp4_gemm
- Split provided_slot_token vs slot_token_out (returned to caller)
- No gather when slot_token=None (L2 path), no unnecessary alloc
- .contiguous() on gathered tensors for CUTLASS alignment
- Return slot_token_out consistently
2026-05-15 10:11:40 +00:00
7a1538d0c8 fix: gather on slot_token presence, add shape asserts L1→L2
- Remove torch.equal heuristic — just gather when slot_token is provided
- Add asserts for slot mapping shapes (ndim, numel == num_slots)
- Add post-L1 and pre-L2 shape asserts (l1_slots, activated, l1_fp4, l1_sf_out)
2026-05-15 10:06:07 +00:00
3cc00b12df fix: prepack cache key includes data_ptr, shape, dtype, device, N, K
Old cache used only tag ('l1'/'l2'), so layer 1 would reuse layer 0's
packed scales if the function object persisted. Now keyed by
(tag, data_ptr, shape, dtype, device, N, K) — safe across layers.
2026-05-15 10:03:37 +00:00
3ba41b9322 fix: use slot_token identity check instead of shape heuristic for gather
Shape-based check (x_fp4.shape[0] != num_slots) silently fails when
num_tokens == num_slots in L1 (topk=1). Now checks if slot_token is
the identity mapping — only gathers when slot ordering differs from
token ordering.
2026-05-15 10:00:41 +00:00
ded80be133 refactor: unify L1/L2 to use 1D slot_expert_ids consistently
Both L1 and L2 now pass pre-built 1D slot_expert_ids and slot_token to
cutlass_grouped_nvfp4_gemm instead of the 2D topk_ids.

The 2D path was broken for expert parallelism — local_mask matched ALL
local experts, producing mismatched slot_token/slot_k lengths that caused
vectorized_gather_kernel index out of bounds.

cutlass_grouped_nvfp4_gemm now:
- Takes 1D slot_expert_ids + optional slot_token
- Gathers x_fp4 by slot_token when needed (L1: tokens→slots)
- Skips gather when x_fp4 already has num_slots rows (L2)
2026-05-15 09:56:46 +00:00
c7db2242ee fix: pass slot_expert_ids directly to L2 instead of rebuilding from topk_ids
The L2 function was rebuilding slot_expert_ids by scanning topk_ids with a
local_mask. This produced mismatched slot_k (all-expert mask) vs slot_token
(rank-local mask), causing vectorized_gather_kernel index out of bounds.

Now slot_expert_local is passed directly from the outer routing logic, matching
the same slot ordering as L1.
2026-05-15 09:42:13 +00:00
f29b96de09 bug fixes 2026-05-15 09:25:33 +00:00
a780bb5fde bug fix 2026-05-15 09:11:05 +00:00
91338428d9 some optimizations 2026-05-15 09:09:35 +00:00
fae418c3a3 final scatter 2026-05-15 08:57:43 +00:00
f2cacfc2f2 fix the L2 path and the clamping math 2026-05-15 08:51:23 +00:00
d22dae2df3 were getting close 2026-05-15 08:28:40 +00:00
d493193d25 fix the god damn projections 2026-05-15 08:02:02 +00:00
9810de7109 more debug 2026-05-15 07:45:34 +00:00
1a37b66922 dang python 2026-05-15 07:23:10 +00:00
7b3a853465 more debugging 2026-05-15 07:10:13 +00:00
6b4b59c6a4 double check that weird line 2026-05-15 06:40:38 +00:00
beacc31569 is paris in the top n? 2026-05-15 06:38:20 +00:00
f17efa340d are the weights ever not zero? 2026-05-15 05:48:38 +00:00
c5d800f133 can we see the wt in? 2026-05-15 05:41:12 +00:00
6a4f52cedc god dam i just want the gemm in 2026-05-15 05:31:13 +00:00
3b3c506af5 whoops 2026-05-15 05:21:42 +00:00
76e9b078a2 more debug2 2026-05-15 05:08:53 +00:00
912e4622d7 more debug 2026-05-15 04:53:26 +00:00
c7f6a1dc4d fix: transpose B and SFB on the Python side at weight-load time, and adjust the SFB remap kernel to read from column-major source layout 2026-05-15 04:35:45 +00:00
c56cc34ae1 fix: LayoutBTag is now RowMajor 2026-05-15 04:30:27 +00:00
9975558c23 Add always-on alpha/x_sf debug prints for L1 and L2 GEMM calls 2026-05-15 03:59:07 +00:00
ff6bb32684 Plumb global scale as GEMM alpha instead of folding into UE4M3
stage_activation now returns (x_fp4, x_sf, input_global_scale).
The global scale is applied as the CUTLASS GEMM alpha parameter
in the epilogue: D = alpha * A @ B, avoiding the fp32→UE4M3
round-trip that folding would introduce.

Changes:
- stage_activation: returns global scale as 3rd value
- cutlass_nvfp4_gemm C++ binding: alpha param (was hardcoded 1.0)
- cutlass_grouped_nvfp4_gemm: passes alpha to per-expert GEMM
- nvfp4_mega_moe_l1/l2: accept alpha, pass to grouped GEMM
- nvfp4_moe_full: reads symm_buffer.input_global_scale for L1,
  uses stage_activation's returned global scale for L2
- SymmBuffer: added input_global_scale field
- vllm patch: stores global scale from stage_activation
2026-05-15 03:32:19 +00:00
d547da2948 stage_activation: add per-tensor global scale matching NVFP4 spec
Without a global scale, block scales (block_max / 6.0) could exceed
UE4M3 max (448.0) for large activations, causing saturation and garbage
MoE outputs. The degeneration pattern (positions 1-5 OK, then constant
spaces) is consistent with UE4M3 overflow: first few tokens have small
activations that fit, but once SiLU(mul(gate, up)) produces larger
values, block scales overflow and the GEMM produces zeros/garbage.

Fix: compute input_global_scale = amax / (6.0 * 448.0), normalize
before block quantization, then fold global scale back into block
scales (same as weight_transform.py folds weight_scale_2). This
ensures block scales are always ≤ 448.0 in UE4M3 range.
2026-05-15 03:27:47 +00:00
ce4c4b6fcb debug empty output 2026-05-14 22:13:32 +00:00
09d1307d78 damn clankers2 2026-05-14 20:34:51 +00:00
5bbe51357c damn clankers 2026-05-14 20:23:42 +00:00
57512d5f0d clean up 2026-05-14 19:20:08 +00:00
2687d1fc53 fix: convert global expert IDs to local before GEMM
vLLM's symm_buffer stores topk_ids as GLOBAL expert IDs (0..383).
Our weight tensors are indexed by LOCAL IDs (0..47 per rank).
Each rank r handles experts [r*48, r*48+47]. Without conversion,
topk_ids like 137, 222, 378 would index way out of bounds in the
weight tensor (shape (48, N, K)), producing garbage.

Derive experts_start_idx from the topk_ids and subtract to get
local IDs. This was why all ranks except rank 0 produced zero
expert matches → zero output → garbage text.
2026-05-14 17:43:58 +00:00
128ff84358 fix: 384 experts (not 256), clarify cross-rank reduce is in caller
DeepSeek-V4-Pro has 384 routed experts, 48 per rank (384/8).
The cross-rank all-reduce happens in the parent DeepseekV4MoE.forward,
not in our kernel. Our kernel writes local output; caller does reduce.
Fixed README, nvfp4_mega_moe.py comments.
2026-05-14 17:33:59 +00:00
1c15dadaa5 cleanup: remove dead _pack_ue4m3_to_uint32, fix data format docs
weight_transform.py returns float8_e4m3fn scales, NOT packed uint32.
The _pack_ue4m3_to_uint32 function was never called. Removed it.
Updated README data formats to accurately reflect the pipeline:
- Weight scales: float8_e4m3fn (direct to CUTLASS, no unpack)
- Activation scales: uint32 packed (from staging kernel, unpacked to float8)
2026-05-14 17:28:12 +00:00
1e0cea055c cleanup: remove all debug printfs from CUDA kernel and weight_transform
Removed printf from remap kernel (flat_rank dump, coordinate probes,
first-coord log). Removed weight_scale_2 debug prints from
weight_transform.py. Production-ready now.
2026-05-14 16:57:32 +00:00
839835cba4 fix: correct SF remap coordinate extraction for flat_rank=8
m = f0 + f1*32 + f2*128  (CuTe 'first sub varies fastest')
k_sf = f4 + f5*4
f3 is the Step<2> stride (degenerate, always=total), NOT a coordinate.
Previous formula (f3*2+f2)*128 was catastrophically wrong — mapped
everything to m=0 or m=huge.
2026-05-14 16:40:48 +00:00
1ef2fbc2fd debug: more indices for SF layout dump 2026-05-14 16:26:15 +00:00
c4b5b52a33 debug: single-thread SF layout dump at specific indices 2026-05-14 16:13:05 +00:00
17e6033ade debug: print specific indices for SF layout coordinate decomposition 2026-05-14 15:57:55 +00:00
8ee3f90e44 debug: handle flat_rank=8 for SF remap, add coordinate dump
Previous approach assumed rank 2-6, but actual rank is 8.
For R==8: 4 M sub-indices (inner_32, inner_4, tile_interleave, tile_m)
          4 K sub-indices (inner_16, inner_4_k, tile_k_interleave, tile_k)
m = (f3*2 + f2)*128 + f0*4 + f1
k_sf = f5 + f6*4  (tentative, needs printf verification)
Added printf of all 8 flat values for first 3 indices.
2026-05-14 15:45:52 +00:00
d2c1c76f5b debug: idx2crd+flatten approach with printf to determine flat_rank
Going back to the idx2crd approach which compiles and runs.
Added printf for flat_rank, MN, K_sf, and first coordinate extraction.
Handles ranks 2-6 with logical (m, k_sf) extraction.
This will tell us the actual flat_rank and whether our extraction is correct.
2026-05-14 15:34:46 +00:00
2ac3a7d631 fix: construct nested coordinate for CuTe layout shape ((32,4), K)
layout_sf(m, k_elem) with flat ints fails: Mismatched Ranks because
the layout shape is ((32,4), K_padded), not (M, K).
Decompose m into (inner_m, sub_m) = (m/4, m%4) to match the (32,4)
sub-shape, and pass as make_tuple(make_tuple(inner, sub), k_elem).
2026-05-14 15:32:12 +00:00
593ae998f8 fix: clean rewrite of cutlass_nvfp4_gemm.cu — no more file splicing
Removed dead code from old idx2crd approach. File is now clean:
- Source-iterating SF remap kernel with layout_sf(m, k_elem)
- Zero-init dest buffers before remap
- Proper extern C wrapping
2026-05-14 15:31:03 +00:00
196ee37fdb fix: rewrite SF remap kernel — source-iterating with layout_sf(m, k_elem)
Ripped out idx2crd + flatten + get<> approach entirely. New kernel
iterates over source indices (m, k_group) and uses layout_sf(m, k_elem)
to compute the CUTLASS destination offset. CuTe handles nested shape
decomposition internally — no rank inspection needed.

K coordinate is in element-space (k_group * SFVecSize) as the layout
expects. Iterates over groups (not every element) since all 16 elements
within a group share one SF byte — avoids 16x redundant writes.

Grid size based on source count (MN * K_sf), not dest buffer size.
2026-05-14 15:28:44 +00:00
fb390b24e2 debug: add printf to SF remap kernel to check flat_rank and layout shape 2026-05-14 15:24:18 +00:00
8f5322ca31 fix: add missing extern "C" opening brace lost during file reconstruction 2026-05-14 15:04:43 +00:00
a8bd962452 fix: SF remap — iterate dest indices, extract logical (m, k_sf) from nested coord
The forward-map approach (src -> layout_sf(m, k)) failed because CuTe's
layout operator requires coordinates matching the nested shape rank, and
passing flat (int, int) to a ((32,4),K) shape triggers Mismatched Ranks.

New approach: iterate over CUTLASS dest indices, use idx2crd to get the
hierarchical coordinate, flatten it, then extract logical (m, k_sf) by
interpreting the flattened sub-coordinates correctly:
  flat[0..2] = (inner_M, sub_M, tile_M) -> m = tile_M*128 + inner_M*4 + sub_M
  flat[3..5] = (inner_K, sub_K, tile_K) -> k_sf = tile_K*4 + sub_K
  (inner_K is within one SF group — same byte, so ignored for k_sf)

Previous bug: get<0> and get<1> of flatten gave (inner_M, sub_M) — both
M sub-indices. K information was never extracted, so only k_group=0 worked.

Dest buffer is zero-initialized so padding slots (where m >= MN or
k_sf >= K_sf) stay zero.
2026-05-14 15:01:47 +00:00
395cc31883 fix: use layout_sf(m, k_elem) instead of make_coord for nested shapes
make_coord(m, k_elem) produces rank-2 coord, but tile_to_shape creates
nested shapes like ((32,4,tiles_m), (16,4,tiles_k)) which expect
matching nested coords. layout_sf(m, k) operator handles hierarchical
projection automatically.
2026-05-14 14:57:13 +00:00