Commit Graph

57 Commits

Author SHA1 Message Date
008f8cccbd docs: comprehensive README with SF remap probe data, bug history, coordinate table
Added detailed SF remap section with the empirical coordinate dump table
showing flat_rank=8 decomposition. Documented all 5 bugs found/fixed,
the diagnostic trail (constant-scale test, single-element probes), and
the 6 verification probes confirming the extraction formula.
2026-05-14 17:02:53 +00:00
1e0cea055c cleanup: remove all debug printfs from CUDA kernel and weight_transform
Removed printf from remap kernel (flat_rank dump, coordinate probes,
first-coord log). Removed weight_scale_2 debug prints from
weight_transform.py. Production-ready now.
2026-05-14 16:57:32 +00:00
839835cba4 fix: correct SF remap coordinate extraction for flat_rank=8
m = f0 + f1*32 + f2*128  (CuTe 'first sub varies fastest')
k_sf = f4 + f5*4
f3 is the Step<2> stride (degenerate, always=total), NOT a coordinate.
Previous formula (f3*2+f2)*128 was catastrophically wrong — mapped
everything to m=0 or m=huge.
2026-05-14 16:40:48 +00:00
1ef2fbc2fd debug: more indices for SF layout dump 2026-05-14 16:26:15 +00:00
c4b5b52a33 debug: single-thread SF layout dump at specific indices 2026-05-14 16:13:05 +00:00
17e6033ade debug: print specific indices for SF layout coordinate decomposition 2026-05-14 15:57:55 +00:00
8ee3f90e44 debug: handle flat_rank=8 for SF remap, add coordinate dump
Previous approach assumed rank 2-6, but actual rank is 8.
For R==8: 4 M sub-indices (inner_32, inner_4, tile_interleave, tile_m)
          4 K sub-indices (inner_16, inner_4_k, tile_k_interleave, tile_k)
m = (f3*2 + f2)*128 + f0*4 + f1
k_sf = f5 + f6*4  (tentative, needs printf verification)
Added printf of all 8 flat values for first 3 indices.
2026-05-14 15:45:52 +00:00
d2c1c76f5b debug: idx2crd+flatten approach with printf to determine flat_rank
Going back to the idx2crd approach which compiles and runs.
Added printf for flat_rank, MN, K_sf, and first coordinate extraction.
Handles ranks 2-6 with logical (m, k_sf) extraction.
This will tell us the actual flat_rank and whether our extraction is correct.
2026-05-14 15:34:46 +00:00
2ac3a7d631 fix: construct nested coordinate for CuTe layout shape ((32,4), K)
layout_sf(m, k_elem) with flat ints fails: Mismatched Ranks because
the layout shape is ((32,4), K_padded), not (M, K).
Decompose m into (inner_m, sub_m) = (m/4, m%4) to match the (32,4)
sub-shape, and pass as make_tuple(make_tuple(inner, sub), k_elem).
2026-05-14 15:32:12 +00:00
593ae998f8 fix: clean rewrite of cutlass_nvfp4_gemm.cu — no more file splicing
Removed dead code from old idx2crd approach. File is now clean:
- Source-iterating SF remap kernel with layout_sf(m, k_elem)
- Zero-init dest buffers before remap
- Proper extern C wrapping
2026-05-14 15:31:03 +00:00
196ee37fdb fix: rewrite SF remap kernel — source-iterating with layout_sf(m, k_elem)
Ripped out idx2crd + flatten + get<> approach entirely. New kernel
iterates over source indices (m, k_group) and uses layout_sf(m, k_elem)
to compute the CUTLASS destination offset. CuTe handles nested shape
decomposition internally — no rank inspection needed.

K coordinate is in element-space (k_group * SFVecSize) as the layout
expects. Iterates over groups (not every element) since all 16 elements
within a group share one SF byte — avoids 16x redundant writes.

Grid size based on source count (MN * K_sf), not dest buffer size.
2026-05-14 15:28:44 +00:00
fb390b24e2 debug: add printf to SF remap kernel to check flat_rank and layout shape 2026-05-14 15:24:18 +00:00
8f5322ca31 fix: add missing extern "C" opening brace lost during file reconstruction 2026-05-14 15:04:43 +00:00
a8bd962452 fix: SF remap — iterate dest indices, extract logical (m, k_sf) from nested coord
The forward-map approach (src -> layout_sf(m, k)) failed because CuTe's
layout operator requires coordinates matching the nested shape rank, and
passing flat (int, int) to a ((32,4),K) shape triggers Mismatched Ranks.

New approach: iterate over CUTLASS dest indices, use idx2crd to get the
hierarchical coordinate, flatten it, then extract logical (m, k_sf) by
interpreting the flattened sub-coordinates correctly:
  flat[0..2] = (inner_M, sub_M, tile_M) -> m = tile_M*128 + inner_M*4 + sub_M
  flat[3..5] = (inner_K, sub_K, tile_K) -> k_sf = tile_K*4 + sub_K
  (inner_K is within one SF group — same byte, so ignored for k_sf)

Previous bug: get<0> and get<1> of flatten gave (inner_M, sub_M) — both
M sub-indices. K information was never extracted, so only k_group=0 worked.

Dest buffer is zero-initialized so padding slots (where m >= MN or
k_sf >= K_sf) stay zero.
2026-05-14 15:01:47 +00:00
395cc31883 fix: use layout_sf(m, k_elem) instead of make_coord for nested shapes
make_coord(m, k_elem) produces rank-2 coord, but tile_to_shape creates
nested shapes like ((32,4,tiles_m), (16,4,tiles_k)) which expect
matching nested coords. layout_sf(m, k) operator handles hierarchical
projection automatically.
2026-05-14 14:57:13 +00:00
d90967d6e9 fix: SF remap — element-space K coords + zero-init dest buffer
Two fixes:
1. CuTe layout uses element-space K, not group-space. k_group=3 with
   SFVecSize=16 maps to k_elem=48 in the layout, not k=3.
   Added SFVecSize param to remap kernel, multiply k_sf * SFVecSize
   before passing to layout_sf().

2. Zero-init CUTLASS dest buffer before remap. The layout pads to
   tile boundaries (128x64), so dest is larger than M*K_sf. Unmapped
   padding slots reading garbage causes sporadic wrong results.
   Also fixed grid size to use source count (M*K_sf), not dest size.
2026-05-14 14:54:18 +00:00
5968ebad9f fix: SF remap was using idx2crd+flatten which gives atom sub-indices, not logical (m,k)
The remap kernel iterated over CUTLASS linear indices and tried to
reverse-map with idx2crd + flatten. But flatten() on the nested CuTe
coordinate (from tile_to_shape(SfAtom{}, ...)) gives atom-level
sub-indices, not logical (m, k). This caused all K-groups > 0 in SFA
to map to m*K_sf+0, losing K-group information entirely.

Proof: setting SFA[0,0]=2.0 changed row 0, but SFA[0,3]=2.0 produced
zero change. Only K-group 0 was being read.

Fix: iterate over SOURCE indices (row-major m, k) and use the CuTe
layout forward: layout_sf(make_coord(m, k)) -> CUTLASS dst index.
This is the correct forward direction that CuTe handles natively.

Constant-scale test (all SF=1.0) gave cosine=1.0, confirming the FP4
data path is correct. The bug was purely in the SF remap.
2026-05-14 14:51:02 +00:00
cf796e37cf debug: add weight_scale_2 shape/value logging in weight transform 2026-05-14 14:19:35 +00:00
879adc324d fix: _fold_global_scale — remove broken logical_widths branch
The logical_widths branch took expert 0 and 1's global scales and
applied them to ALL experts. For L1 with logical_widths=[3072,3072],
every expert got expert-0's scale on its gate half and expert-1's
scale on its up half. All other experts' global scales were discarded.

The else branch correctly broadcasts each expert's own (E,1) global
scale across (E, N, K//16). Removed the dead logical_widths code.
2026-05-14 14:17:44 +00:00
ef9cd023a9 fix: unpack_ue4m3_u32 — uint32 lacks CUDA bitwise ops, use int32
PyTorch doesn't implement bitwise_and/shift for UInt32 on CUDA.
Cast to int32 first, then extract bytes, then uint8 → view float8.
2026-05-14 13:44:42 +00:00
1c39e21d87 fix: remove broken L1 weight interleave
The interleave assumed gate/up were pre-interleaved in groups of 16
and that we needed 2CTA UMMA layout. Both wrong:
1. vLLM w13_weight is plain concat [gate; up] along output dim
2. Our CUTLASS kernel uses ClusterShape 1x1x1, not 2CTA

The interleave was shuffling weights into nonsense, making L1 GEMM
compute the wrong thing, and chunk(2) would split wrong halves.
2026-05-14 13:05:45 +00:00
80495c0cd6 docs: clarify SF layout remap is in CUDA, not sf_layout.py
sf_layout.py was a no-op (return sf) but the actual remap happens
in remap_sf_to_cutlass_kernel in cutlass_nvfp4_gemm.cu. Updated
sf_layout.py to pure reference docs so nobody gets confused again.
2026-05-14 13:04:31 +00:00
16f91ff0e1 fix: rewrite stage_activation with proper E2M1 quantization
Three bugs fixed:
1. clamp(0,15) was destroying sign bits — E2M1 is sign-magnitude 4-bit
   nibbles, not unsigned. Half the activation was zeroed.
2. Scale stored block_max but divided by block_max/6, so stored scale was
   6× too large. Now correctly stores block_max/6 (the actual dequant factor).
3. Uniform 0.5 step doesn't match E2M1 values {0,0.5,1,1.5,2,3,4,6}.
   Now snaps to nearest E2M1 representable magnitude.

New _quantize_to_e2m1 helper handles all three correctly:
- Sign-magnitude 4-bit nibble packing (bit3=sign, bits2:0=mag index)
- Correct block scale (block_max / 6.0)
- Nearest-neighbor to actual E2M1 values
2026-05-14 13:02:10 +00:00
3bcc0ac057 fix: unpack_ue4m3_u32 was value-casting instead of bit-reinterpreting
Byte 0x3F was becoming float8(63.0) instead of the float8 whose bit
pattern IS 0x3F (~0.984). Pack uses .view() (correct), unpack used
.to() (wrong) — they were not inverses. This corrupted every activation
scale fed to the L1 GEMM while weight scales were fine.
2026-05-14 12:59:20 +00:00
8b7fa0c91e add README: pipeline diagram, file map, data formats, known issues 2026-05-14 12:48:08 +00:00
d3f35c9465 cleanup: remove abandoned TileLang and Mojo files
- Deleted: layout.mojo, mega_moe.mojo, quantize.mojo (Mojo attempt)
- Deleted: nvfp4_blockscaled_gemm.py, staging.py, nvfp4_mega_moe.py (TileLang top-level)
- Deleted: tilelang_nvfp4_gemm.py, tilelang_kernels.py, nvfp4_dequant.py (TileLang package)
- Deleted: src/weight_transform.py (duplicate of package version)
- Fixed nvfp4_mega_moe.py: inlined unpack_ue4m3_u32, removed TileLang fallback imports
- Fixed weight_transform.py: renamed function, removed TileLang alias, updated docs
- Fixed __init__.py: removed TileLang alias, updated docstring
- CUTLASS is the only kernel path now
2026-05-14 12:44:47 +00:00
802c4ee12c Revert stage_activation to simple quantize (staging kernel API incompatible with L1 output dims) 2026-05-14 12:14:01 +00:00
69e0174792 Fix stage_activation: use Triton staging kernel instead of broken simple quantize 2026-05-14 12:01:34 +00:00
c016e66e23 Add CUDA sync + NaN/Inf check after each expert GEMM in grouped kernel 2026-05-14 11:27:58 +00:00
1dfe5ffd05 Add comprehensive README documenting quirks, pitfalls, and setup 2026-05-14 11:23:32 +00:00
904fc37ad8 Fix: use idx2crd instead of get_coord for CuTe layout coordinate lookup 2026-05-14 10:50:26 +00:00
494d30b6ab Fix: use CuTe get_coord for proper scale factor remap to CUTLASS interleaved layout 2026-05-14 10:48:58 +00:00
869151d211 Fix kernel.py: remove broken expand on scale factors (was expanding sf to weight size) 2026-05-14 10:36:16 +00:00
84becfac93 Test: pass scales directly to CUTLASS (no remap) to diagnose layout issue 2026-05-14 10:23:02 +00:00
a272bc49b0 Fix: torch::kBFloat16 2026-05-14 10:21:10 +00:00
3f62e49e6e Fix PyTorch API: use c10::cuda and at::kBF16 2026-05-14 10:20:00 +00:00
2ee4e26772 Fix: remove compile-time SM100 guard from pytorch binding, use runtime check instead 2026-05-14 10:18:36 +00:00
540e68593f Add scale factor remap kernel: remap simple row-major SFs to CUTLASS interleaved layout 2026-05-14 10:05:38 +00:00
2998c889e7 Implement simple FP4 quantization for L1→L2 re-quant step (no vLLM fp4_utils dependency) 2026-05-14 09:50:52 +00:00
98913c9b1a Fix stage_activation: use Triton staging kernel from vLLM patch instead of fp4_utils 2026-05-14 09:38:50 +00:00
25cbc85afe Replace kernel.py with thin wrapper around pre-compiled _C extension 2026-05-14 09:25:56 +00:00
33e5d67326 Add CUTLASS_CHECK macro 2026-05-13 23:28:03 +00:00
b7c5cba407 Fix device_memory include path 2026-05-13 23:27:06 +00:00
3299d22ad6 Fix type casts and includes for CUTLASS NVFP4 GEMM 2026-05-13 23:26:18 +00:00
1eb9c43217 Rewrite CUTLASS kernel based on NVIDIA example 72b (nv_float4_t, CollectiveBuilder, OpClassBlockScaledTensorOp) 2026-05-13 23:25:20 +00:00
8a9af441dc Fix includes: use cutlass/float_subbyte.h (has float_e2m1_t and float_ue4m3_t), point to latest CUTLASS 2026-05-13 23:23:01 +00:00
d789f5e3e0 Add CCCL include path for CUTLASS 3.x 2026-05-13 23:18:26 +00:00
12588047fd Fix setup.py: use include_dirs and extra_compile_args (correct PyTorch extension API) 2026-05-13 23:17:30 +00:00
1b1c3a42fe Fix setup.py source paths 2026-05-13 23:14:05 +00:00
f375c80bfe feat: CUTLASS NVFP4 block-scaled GEMM kernel (native SM100 Blackwell)
- Native NVFP4 block-scaled MMA using CUTLASS MainloopSm100TmaUmmaWarpSpecializedBlockScaled
- Invokes mxf8f6f4.block_scale tensor core instructions (tcgen05.mma)
- E2M1 (packed int8) + UE4M3 (float8_e4m3fn) block-16 scales → BF16 output
- No dequantization: hardware block-scaled MMA avoids costly dequantize+BF16 path
- PyTorch CUDA extension with CollectiveBuilder auto-deduction
- Grouped expert GEMM for MoE dispatch (32 experts/rank, top-6 routing)
- Integrated into nvfp4_mega_moe.py as primary path with TileLang fallback
- Standalone C API (cutlass_nvfp4_gemm.cu) for direct B200 compilation
- Build script, setup.py, and test script for B200 deployment

Files:
  cutlass_nvfp4_gemm/ — Kernel source, PyTorch binding, build/test scripts
  nvfp4_mega_moe.py — Updated to use CUTLASS kernel when available
2026-05-13 23:11:15 +00:00