Commit Graph

33 Commits

Author SHA1 Message Date
79b9becf9c revert: don't use checkpoint input_scale for activation normalization
Using checkpoint input_scale as the normalization scale saturates
FP4 values (all block scales = 448). The input_scale is a calibration
constant, NOT the amax/(6*448) normalization scale.

Reverted to dynamic amax/(6*448) for activation quantization.
The correct use of checkpoint input_scale is still under investigation.

Preserved: _w13_input_scale and _w2_input_scale in finalize_weights
for future use once we understand the correct alpha contract.
2026-05-16 00:12:41 +00:00
a7eae10ef4 fix: use checkpoint input_scale for activation quantization
Critical fix: the checkpoint's input_scale was used during weight
calibration but we were computing dynamic scale from data (amax/2688).
This was 13x off from the checkpoint value.

Changes:
- stage_activation() accepts optional input_global_scale parameter
- nvfp4_mega_moe_full() accepts l1_input_scale and l2_input_scale
- vLLM patch preserves w13/w2_input_scale in finalize_weights
- L1 activation uses checkpoint w13_input_scale for quantization
- L2 activation uses checkpoint w2_input_scale for quantization
- alpha = input_scale * weight_scale_2 (correct calibration contract)
2026-05-15 23:57:08 +00:00
bb5a1ba4c8 cleanup: remove unused slot_token from nvfp4_moe_l2
L2 input is already slot-major, so slot_token was accepted but never
passed to the GEMM. Made it explicit by removing the parameter.
2026-05-15 23:50:39 +00:00
aa209ddd21 debug: add SF remap roundtrip verifier
Checks that forward remap wrote the correct bytes by comparing
src[mn*stride_mn + k_sf*stride_ksf] against dst[layout_sf(make_coord(mn, k_sf*16, 0))].
Prints error count for SFA and SFB on first GEMM call.
2026-05-15 22:59:44 +00:00
6626b75a2f fix: use filter_zeros for SF allocation + no-branch forward mapping
- Allocation: cute::size(cute::filter_zeros(layout)) matches CUTLASS examples
- Kernel: layout_sf(make_coord(mn, k_sf*16, 0)) — no branching on LayoutRank
- Avoids silent fallthrough that wrote dst[0] for all threads
2026-05-15 22:58:51 +00:00
6fc8fa61e0 fix: use flat logical coordinate layout_sf(make_coord(mn, k_elem, 0))
CuTe maps compatible flat coordinates into the natural hierarchical
coordinate before applying strides. No manual decomposition needed.
k_elem = k_sf * 16 (logical K element, not compact SF index).
2026-05-15 22:53:57 +00:00
a48717ccf5 fix: remove duplicate dst_idx declaration 2026-05-15 22:31:05 +00:00
5ff1b9e401 fix: use hierarchical coordinates for layout_sf forward mapping
Flat make_coord(mn, k*16) doesn't decompose into the nested atom shape.
Must manually decompose:
  mn -> (m0, m1, mt) where m0=mn%32, m1=(mn/32)%4, mt=mn/128
  k_sf -> (k0, k1, kt) where k0=0 (stride-0), k1=k_sf%4, kt=k_sf/4
2026-05-15 22:11:14 +00:00
a1fd4d6233 revert: back to layout_sf(make_coord(...)) — crd2idx was unnecessary 2026-05-15 21:55:00 +00:00
ea678ece64 fix: remove duplicate template declaration 2026-05-15 21:54:10 +00:00
59dad8e2fb fix: use crd2idx instead of layout operator() for SF forward mapping 2026-05-15 21:52:02 +00:00
a09d8e477e fix: remove static_assert in constexpr else (build fix) 2026-05-15 21:27:27 +00:00
7285331395 fix: replace col_major_src with explicit source strides
SFA: src_stride_mn=K_sf, src_stride_ksf=1 (row-major M, K_sf)
SFB: src_stride_mn=1, src_stride_ksf=N (row-major K_sf, N after transpose)

Removes ambiguity about physical memory layout. The source indexing
now uses mn*src_stride_mn + k_sf*src_stride_ksf which works for
any contiguous or transposed layout.
2026-05-15 21:23:21 +00:00
f6fd549800 fix: restore col_major_src handling for SFB source layout
SFB scales arrive as (K_sf, N) row-major after transpose+contiguous
in weight_transform.py. The col_major_src flag correctly describes
this. Don't assume both sources are (MN, K_sf).
2026-05-15 21:19:58 +00:00
63e67e1025 fix: rewrite SF remap as forward mapping (source→dst)
- Iterate over source indices (MN * K_sf) instead of dst indices
- Use layout_sf forward mapping: layout_sf(make_coord(mn, k_sf*16))
- No more idx2crd reverse extraction or stride-0 ambiguity
- Cleaner, less error-prone, blog-compatible
2026-05-15 20:51:30 +00:00
30b6c89424 fix: correct SF remap coordinate extraction
- First flattened group IS M/N (not K as previously assumed)
- mn = f0 + 32*f1 + 128*f2
- k_sf = f4 + 4*f5 (f3 is stride-0 inner K, ignored)
- The atom stride-0 dimension (f3) maps to offset 0, not a meaningful
  K sub-index. The actual k_sf comes from f4 (sub_k) + f5*4 (tile_k)
- Original code had group assignment right but k_sf extraction wrong
2026-05-15 20:44:46 +00:00
ff5a0843dc fix: divide K element index by SFVecSize to get k_sf
Based on veitner bearblog analysis of CUTLASS SF layout:
- Shape is ((32,4,K_tiles), (SFVecSize,4,M_tiles)) for SFA
- get<0..2> covers K dimension, get<3..5> covers M dimension
- k_sf = K_element_index / SFVecSize
2026-05-15 20:17:24 +00:00
a09b9b53a3 cleanup: remove printf and diag function from CUDA kernel (build fix) 2026-05-15 20:11:40 +00:00
deb6b3231a debug: swap M/K in SF remap + add printf diagnostics 2026-05-15 20:01:47 +00:00
c3841983a0 fix: SF remap uses cute::cosize() instead of cute::size()
The comment explicitly warned about this: allocation uses cosize (physical
size including tile padding) but the iteration bound used size (logical size).
This meant padding positions in the CUTLASS SF layout were never written,
leaving them as zero instead of their actual SF values. With uniform data
(all-ones), all SF values are the same so the bug was invisible. With
random data, different SF values are needed at different positions and
the missing writes corrupt the result.
2026-05-15 18:52:23 +00:00
773967452f debug: fix gs scalar conversion + add traceback 2026-05-15 18:27:44 +00:00
df916b87eb debug: fix gs.item() for multi-element tensor 2026-05-15 18:09:41 +00:00
755f9ad567 debug: fix per_expert_alpha ref + clean up BF16 reference scaling 2026-05-15 17:55:11 +00:00
de8acc7965 debug: dump raw GEMM inputs + first 8 output values 2026-05-15 17:02:40 +00:00
2fd55a94c6 fix: weight reshape bug + igs double-count in BF16 reference 2026-05-15 15:46:16 +00:00
c421a668f3 debug: BF16 reference GEMM + cosine comparison for L1 2026-05-15 14:16:24 +00:00
995589ac8a debug: add FP4 quantization round-trip diagnostic 2026-05-15 13:41:09 +00:00
d0ed3d84a8 debug: add L2, SiLU, and scatter pipeline prints 2026-05-15 13:21:25 +00:00
fd59222fc0 fix: stop folding global scale into float8 block scales
The fold block_sf (float8) * global_sf (float32) -> float8 loses ~25% precision.
Product of ~56-448 block_sf * ~4.65e-05 global_sf lands in float8 low-precision
zone where step size is 25%. This makes model output garbage despite finite values.

Fix: keep block scales as original float8, return global scales separately as
float32 per-expert vectors. Apply global scale as per-expert GEMM alpha in
cutlass_grouped_nvfp4_gemm (already iterates per-expert). For L1 with separate
gate/up global scales, use gate_gs as alpha and apply up_correction ratio to
the up half post-GEMM.

weight_transform.py: no more _fold_global_scale, returns (w, sf, global_sf)
nvfp4_mega_moe.py: per-expert alpha = activation_gs * weight_gs
kernel.py: per_expert_alpha parameter in grouped GEMM
deepseek_v4.py: updated type hints and comments
2026-05-15 12:42:53 +00:00
56e62e916d revert: idx2crd remap approach — source-first needs hierarchical coords
cute::crd2idx requires hierarchical coordinates matching the layout's
nested shape, which we don't have from flat (m, k_sf). Reverted to
idx2crd dest-first approach. The real bug was cute::size vs
cute::cosize for allocation, not the remap direction.
2026-05-15 11:44:38 +00:00
d5949a23b4 fix: use cute::crd2idx for SF remap — layout_sf() not directly callable
CuTe Layout objects with hierarchical shapes can't be called directly
with flat (m, k_sf). Use cute::crd2idx(make_coord(m, k_sf), layout_sf)
to convert logical coordinates to physical indices.
2026-05-15 11:39:57 +00:00
9908fd64d9 feat: CUTLASS NVFP4 mega_moe kernel — slot-based L1/L2, source-first SF remap
Major changes from initial TileLang prototype:

Kernel:
- CUTLASS NVFP4 block-scaled GEMM (SM100 Blackwell, OpClassBlockScaledTensorOp)
- Slot-based dispatch: L1 GEMM → SiLU+Mul per-slot → L2 GEMM → index_add scatter
- 1D slot_expert_ids passed to both L1 and L2 (no 2D topk_ids rebuild)
- slot_token gathered in cutlass_grouped_nvfp4_gemm when provided

SF Remap (source-first):
- Iterates logical (m, k_sf) source grid, uses layout_sf(make_coord(m, k_sf))
  for CUTLASS dest index — no idx2crd/flatten coordinate extraction
- 2D kernel launch: dim3 block(32,8), grid over (K_sf, MN)
- Uses cute::cosize() for physical allocation size (not cute::size)
- SFA: (MN, K_sf) row-major; SFB: (K_sf, MN) row-major (col-major)

Weight transform:
- UE4M3 unpack with bit reinterpret (not value cast)
- Global scale folding (weight_scale_2) for gate/up split
- clamp(0,448) → float8_e4m3fn, transpose (N,K)→(K,N) for CUTLASS

No prepack cache:
- SFB remapped per-call inside CUTLASS (~µs, not the bottleneck)
- See README for why prepack cache must never return (OOM, CUDA graphs,
  M-dependent layout, cross-layer collisions)

Stage activation:
- Nearest-neighbor E2M1 quantization (no clamp, no uniform steps)
- Per-tensor global scale → alpha for L2 GEMM

Bug fixes:
- _fold_global_scale: removed broken logical_widths branch
- unpack_ue4m3_u32: int32 for CUDA bitwise, view not to, ND support
- Correct expert param mapping for NVFP4 checkpoint
- SiLU applied per-slot (not after summing expert paths)
2026-05-15 11:38:18 +00:00
c2b752c2fe Initial: TileLang NVFP4 mega_moe kernel package
- nvfp4_mega_moe_full: drop-in replacement for deep_gemm.mega.fp8_nvfp4_mega_moe
- transform_nvfp4_weights_for_mega_moe: weight transformation (tested)
- SymmBuffer + get_symm_buffer_for_nvfp4_mega_moe: API-matching stubs
- MEGA_MOE_STATIC=1 support for pipeline testing
- pyproject.toml for pip install
2026-05-13 15:44:51 +00:00