Commit Graph

337 Commits

Author SHA1 Message Date
a1fd4d6233 revert: back to layout_sf(make_coord(...)) — crd2idx was unnecessary 2026-05-15 21:55:00 +00:00
ea678ece64 fix: remove duplicate template declaration 2026-05-15 21:54:10 +00:00
59dad8e2fb fix: use crd2idx instead of layout operator() for SF forward mapping 2026-05-15 21:52:02 +00:00
a09d8e477e fix: remove static_assert in constexpr else (build fix) 2026-05-15 21:27:27 +00:00
7285331395 fix: replace col_major_src with explicit source strides
SFA: src_stride_mn=K_sf, src_stride_ksf=1 (row-major M, K_sf)
SFB: src_stride_mn=1, src_stride_ksf=N (row-major K_sf, N after transpose)

Removes ambiguity about physical memory layout. The source indexing
now uses mn*src_stride_mn + k_sf*src_stride_ksf which works for
any contiguous or transposed layout.
2026-05-15 21:23:21 +00:00
f6fd549800 fix: restore col_major_src handling for SFB source layout
SFB scales arrive as (K_sf, N) row-major after transpose+contiguous
in weight_transform.py. The col_major_src flag correctly describes
this. Don't assume both sources are (MN, K_sf).
2026-05-15 21:19:58 +00:00
63e67e1025 fix: rewrite SF remap as forward mapping (source→dst)
- Iterate over source indices (MN * K_sf) instead of dst indices
- Use layout_sf forward mapping: layout_sf(make_coord(mn, k_sf*16))
- No more idx2crd reverse extraction or stride-0 ambiguity
- Cleaner, less error-prone, blog-compatible
2026-05-15 20:51:30 +00:00
30b6c89424 fix: correct SF remap coordinate extraction
- First flattened group IS M/N (not K as previously assumed)
- mn = f0 + 32*f1 + 128*f2
- k_sf = f4 + 4*f5 (f3 is stride-0 inner K, ignored)
- The atom stride-0 dimension (f3) maps to offset 0, not a meaningful
  K sub-index. The actual k_sf comes from f4 (sub_k) + f5*4 (tile_k)
- Original code had group assignment right but k_sf extraction wrong
2026-05-15 20:44:46 +00:00
ff5a0843dc fix: divide K element index by SFVecSize to get k_sf
Based on veitner bearblog analysis of CUTLASS SF layout:
- Shape is ((32,4,K_tiles), (SFVecSize,4,M_tiles)) for SFA
- get<0..2> covers K dimension, get<3..5> covers M dimension
- k_sf = K_element_index / SFVecSize
2026-05-15 20:17:24 +00:00
a09b9b53a3 cleanup: remove printf and diag function from CUDA kernel (build fix) 2026-05-15 20:11:40 +00:00
e7c3341317 docs: update DEBUG_LOG with M/K swap root cause 2026-05-15 20:03:20 +00:00
deb6b3231a debug: swap M/K in SF remap + add printf diagnostics 2026-05-15 20:01:47 +00:00
22f0457ccf test: isolate SFA vs SFB remap bug 2026-05-15 19:59:39 +00:00
9eaf6d07e8 test: quick random test 2026-05-15 19:58:57 +00:00
fa7b394571 docs: update DEBUG_LOG with root cause (size→cosize) and full debug timeline 2026-05-15 18:56:09 +00:00
c3841983a0 fix: SF remap uses cute::cosize() instead of cute::size()
The comment explicitly warned about this: allocation uses cosize (physical
size including tile padding) but the iteration bound used size (logical size).
This meant padding positions in the CUTLASS SF layout were never written,
leaving them as zero instead of their actual SF values. With uniform data
(all-ones), all SF values are the same so the bug was invisible. With
random data, different SF values are needed at different positions and
the missing writes corrupt the result.
2026-05-15 18:52:23 +00:00
67dcfa83f5 test: random data at small dims + alpha sweep 2026-05-15 18:51:52 +00:00
60f7f60818 test: ultra-minimal GEMM with all-ones 2026-05-15 18:51:31 +00:00
363dd893f0 test: dimension sweep to isolate GEMM bug 2026-05-15 18:51:09 +00:00
fee5a97ebb fix: cosine_similarity dim for M>0 2026-05-15 18:50:45 +00:00
f9330a1777 test: standalone M=1 GEMM test with deterministic data 2026-05-15 18:47:26 +00:00
1b63a46168 docs: update DEBUG_LOG with cosine≈0 finding + new hypotheses 2026-05-15 18:35:00 +00:00
773967452f debug: fix gs scalar conversion + add traceback 2026-05-15 18:27:44 +00:00
df916b87eb debug: fix gs.item() for multi-element tensor 2026-05-15 18:09:41 +00:00
755f9ad567 debug: fix per_expert_alpha ref + clean up BF16 reference scaling 2026-05-15 17:55:11 +00:00
de8acc7965 debug: dump raw GEMM inputs + first 8 output values 2026-05-15 17:02:40 +00:00
9159cb6bb3 docs: add debug log — current state, hypotheses, fixes 2026-05-15 15:48:57 +00:00
2fd55a94c6 fix: weight reshape bug + igs double-count in BF16 reference 2026-05-15 15:46:16 +00:00
c421a668f3 debug: BF16 reference GEMM + cosine comparison for L1 2026-05-15 14:16:24 +00:00
995589ac8a debug: add FP4 quantization round-trip diagnostic 2026-05-15 13:41:09 +00:00
d0ed3d84a8 debug: add L2, SiLU, and scatter pipeline prints 2026-05-15 13:21:25 +00:00
da5572f497 clean: remove diagnostic scripts from repo 2026-05-15 12:50:14 +00:00
fd59222fc0 fix: stop folding global scale into float8 block scales
The fold block_sf (float8) * global_sf (float32) -> float8 loses ~25% precision.
Product of ~56-448 block_sf * ~4.65e-05 global_sf lands in float8 low-precision
zone where step size is 25%. This makes model output garbage despite finite values.

Fix: keep block scales as original float8, return global scales separately as
float32 per-expert vectors. Apply global scale as per-expert GEMM alpha in
cutlass_grouped_nvfp4_gemm (already iterates per-expert). For L1 with separate
gate/up global scales, use gate_gs as alpha and apply up_correction ratio to
the up half post-GEMM.

weight_transform.py: no more _fold_global_scale, returns (w, sf, global_sf)
nvfp4_mega_moe.py: per-expert alpha = activation_gs * weight_gs
kernel.py: per_expert_alpha parameter in grouped GEMM
deepseek_v4.py: updated type hints and comments
2026-05-15 12:42:53 +00:00
56e62e916d revert: idx2crd remap approach — source-first needs hierarchical coords
cute::crd2idx requires hierarchical coordinates matching the layout's
nested shape, which we don't have from flat (m, k_sf). Reverted to
idx2crd dest-first approach. The real bug was cute::size vs
cute::cosize for allocation, not the remap direction.
2026-05-15 11:44:38 +00:00
d5949a23b4 fix: use cute::crd2idx for SF remap — layout_sf() not directly callable
CuTe Layout objects with hierarchical shapes can't be called directly
with flat (m, k_sf). Use cute::crd2idx(make_coord(m, k_sf), layout_sf)
to convert logical coordinates to physical indices.
2026-05-15 11:39:57 +00:00
9908fd64d9 feat: CUTLASS NVFP4 mega_moe kernel — slot-based L1/L2, source-first SF remap
Major changes from initial TileLang prototype:

Kernel:
- CUTLASS NVFP4 block-scaled GEMM (SM100 Blackwell, OpClassBlockScaledTensorOp)
- Slot-based dispatch: L1 GEMM → SiLU+Mul per-slot → L2 GEMM → index_add scatter
- 1D slot_expert_ids passed to both L1 and L2 (no 2D topk_ids rebuild)
- slot_token gathered in cutlass_grouped_nvfp4_gemm when provided

SF Remap (source-first):
- Iterates logical (m, k_sf) source grid, uses layout_sf(make_coord(m, k_sf))
  for CUTLASS dest index — no idx2crd/flatten coordinate extraction
- 2D kernel launch: dim3 block(32,8), grid over (K_sf, MN)
- Uses cute::cosize() for physical allocation size (not cute::size)
- SFA: (MN, K_sf) row-major; SFB: (K_sf, MN) row-major (col-major)

Weight transform:
- UE4M3 unpack with bit reinterpret (not value cast)
- Global scale folding (weight_scale_2) for gate/up split
- clamp(0,448) → float8_e4m3fn, transpose (N,K)→(K,N) for CUTLASS

No prepack cache:
- SFB remapped per-call inside CUTLASS (~µs, not the bottleneck)
- See README for why prepack cache must never return (OOM, CUDA graphs,
  M-dependent layout, cross-layer collisions)

Stage activation:
- Nearest-neighbor E2M1 quantization (no clamp, no uniform steps)
- Per-tensor global scale → alpha for L2 GEMM

Bug fixes:
- _fold_global_scale: removed broken logical_widths branch
- unpack_ue4m3_u32: int32 for CUDA bitwise, view not to, ND support
- Correct expert param mapping for NVFP4 checkpoint
- SiLU applied per-slot (not after summing expert paths)
2026-05-15 11:38:18 +00:00
c2b752c2fe Initial: TileLang NVFP4 mega_moe kernel package
- nvfp4_mega_moe_full: drop-in replacement for deep_gemm.mega.fp8_nvfp4_mega_moe
- transform_nvfp4_weights_for_mega_moe: weight transformation (tested)
- SymmBuffer + get_symm_buffer_for_nvfp4_mega_moe: API-matching stubs
- MEGA_MOE_STATIC=1 support for pipeline testing
- pyproject.toml for pip install
2026-05-13 15:44:51 +00:00