- Split provided_slot_token vs slot_token_out (returned to caller)
- No gather when slot_token=None (L2 path), no unnecessary alloc
- .contiguous() on gathered tensors for CUTLASS alignment
- Return slot_token_out consistently
Old cache used only tag ('l1'/'l2'), so layer 1 would reuse layer 0's
packed scales if the function object persisted. Now keyed by
(tag, data_ptr, shape, dtype, device, N, K) — safe across layers.
Shape-based check (x_fp4.shape[0] != num_slots) silently fails when
num_tokens == num_slots in L1 (topk=1). Now checks if slot_token is
the identity mapping — only gathers when slot ordering differs from
token ordering.
Both L1 and L2 now pass pre-built 1D slot_expert_ids and slot_token to
cutlass_grouped_nvfp4_gemm instead of the 2D topk_ids.
The 2D path was broken for expert parallelism — local_mask matched ALL
local experts, producing mismatched slot_token/slot_k lengths that caused
vectorized_gather_kernel index out of bounds.
cutlass_grouped_nvfp4_gemm now:
- Takes 1D slot_expert_ids + optional slot_token
- Gathers x_fp4 by slot_token when needed (L1: tokens→slots)
- Skips gather when x_fp4 already has num_slots rows (L2)
The L2 function was rebuilding slot_expert_ids by scanning topk_ids with a
local_mask. This produced mismatched slot_k (all-expert mask) vs slot_token
(rank-local mask), causing vectorized_gather_kernel index out of bounds.
Now slot_expert_local is passed directly from the outer routing logic, matching
the same slot ordering as L1.
stage_activation now returns (x_fp4, x_sf, input_global_scale).
The global scale is applied as the CUTLASS GEMM alpha parameter
in the epilogue: D = alpha * A @ B, avoiding the fp32→UE4M3
round-trip that folding would introduce.
Changes:
- stage_activation: returns global scale as 3rd value
- cutlass_nvfp4_gemm C++ binding: alpha param (was hardcoded 1.0)
- cutlass_grouped_nvfp4_gemm: passes alpha to per-expert GEMM
- nvfp4_mega_moe_l1/l2: accept alpha, pass to grouped GEMM
- nvfp4_moe_full: reads symm_buffer.input_global_scale for L1,
uses stage_activation's returned global scale for L2
- SymmBuffer: added input_global_scale field
- vllm patch: stores global scale from stage_activation
Without a global scale, block scales (block_max / 6.0) could exceed
UE4M3 max (448.0) for large activations, causing saturation and garbage
MoE outputs. The degeneration pattern (positions 1-5 OK, then constant
spaces) is consistent with UE4M3 overflow: first few tokens have small
activations that fit, but once SiLU(mul(gate, up)) produces larger
values, block scales overflow and the GEMM produces zeros/garbage.
Fix: compute input_global_scale = amax / (6.0 * 448.0), normalize
before block quantization, then fold global scale back into block
scales (same as weight_transform.py folds weight_scale_2). This
ensures block scales are always ≤ 448.0 in UE4M3 range.
vLLM's symm_buffer stores topk_ids as GLOBAL expert IDs (0..383).
Our weight tensors are indexed by LOCAL IDs (0..47 per rank).
Each rank r handles experts [r*48, r*48+47]. Without conversion,
topk_ids like 137, 222, 378 would index way out of bounds in the
weight tensor (shape (48, N, K)), producing garbage.
Derive experts_start_idx from the topk_ids and subtract to get
local IDs. This was why all ranks except rank 0 produced zero
expert matches → zero output → garbage text.
DeepSeek-V4-Pro has 384 routed experts, 48 per rank (384/8).
The cross-rank all-reduce happens in the parent DeepseekV4MoE.forward,
not in our kernel. Our kernel writes local output; caller does reduce.
Fixed README, nvfp4_mega_moe.py comments.
weight_transform.py returns float8_e4m3fn scales, NOT packed uint32.
The _pack_ue4m3_to_uint32 function was never called. Removed it.
Updated README data formats to accurately reflect the pipeline:
- Weight scales: float8_e4m3fn (direct to CUTLASS, no unpack)
- Activation scales: uint32 packed (from staging kernel, unpacked to float8)
m = f0 + f1*32 + f2*128 (CuTe 'first sub varies fastest')
k_sf = f4 + f5*4
f3 is the Step<2> stride (degenerate, always=total), NOT a coordinate.
Previous formula (f3*2+f2)*128 was catastrophically wrong — mapped
everything to m=0 or m=huge.
Previous approach assumed rank 2-6, but actual rank is 8.
For R==8: 4 M sub-indices (inner_32, inner_4, tile_interleave, tile_m)
4 K sub-indices (inner_16, inner_4_k, tile_k_interleave, tile_k)
m = (f3*2 + f2)*128 + f0*4 + f1
k_sf = f5 + f6*4 (tentative, needs printf verification)
Added printf of all 8 flat values for first 3 indices.
Going back to the idx2crd approach which compiles and runs.
Added printf for flat_rank, MN, K_sf, and first coordinate extraction.
Handles ranks 2-6 with logical (m, k_sf) extraction.
This will tell us the actual flat_rank and whether our extraction is correct.
layout_sf(m, k_elem) with flat ints fails: Mismatched Ranks because
the layout shape is ((32,4), K_padded), not (M, K).
Decompose m into (inner_m, sub_m) = (m/4, m%4) to match the (32,4)
sub-shape, and pass as make_tuple(make_tuple(inner, sub), k_elem).
Removed dead code from old idx2crd approach. File is now clean:
- Source-iterating SF remap kernel with layout_sf(m, k_elem)
- Zero-init dest buffers before remap
- Proper extern C wrapping
Ripped out idx2crd + flatten + get<> approach entirely. New kernel
iterates over source indices (m, k_group) and uses layout_sf(m, k_elem)
to compute the CUTLASS destination offset. CuTe handles nested shape
decomposition internally — no rank inspection needed.
K coordinate is in element-space (k_group * SFVecSize) as the layout
expects. Iterates over groups (not every element) since all 16 elements
within a group share one SF byte — avoids 16x redundant writes.
Grid size based on source count (MN * K_sf), not dest buffer size.
The forward-map approach (src -> layout_sf(m, k)) failed because CuTe's
layout operator requires coordinates matching the nested shape rank, and
passing flat (int, int) to a ((32,4),K) shape triggers Mismatched Ranks.
New approach: iterate over CUTLASS dest indices, use idx2crd to get the
hierarchical coordinate, flatten it, then extract logical (m, k_sf) by
interpreting the flattened sub-coordinates correctly:
flat[0..2] = (inner_M, sub_M, tile_M) -> m = tile_M*128 + inner_M*4 + sub_M
flat[3..5] = (inner_K, sub_K, tile_K) -> k_sf = tile_K*4 + sub_K
(inner_K is within one SF group — same byte, so ignored for k_sf)
Previous bug: get<0> and get<1> of flatten gave (inner_M, sub_M) — both
M sub-indices. K information was never extracted, so only k_group=0 worked.
Dest buffer is zero-initialized so padding slots (where m >= MN or
k_sf >= K_sf) stay zero.