From c7f6a1dc4d4e5f9deadfc4e017cdb6dcf64cd439 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 15 May 2026 04:35:45 +0000 Subject: [PATCH] fix: transpose B and SFB on the Python side at weight-load time, and adjust the SFB remap kernel to read from column-major source layout --- .../cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu | 25 ++++++++----------- .../cutlass_nvfp4_gemm/kernel.py | 15 +++++------ .../cutlass_nvfp4_gemm/test_gemm.py | 11 +++++--- src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py | 12 ++++----- src/nvfp4_megamoe_kernel/weight_transform.py | 13 +++++++--- 5 files changed, 41 insertions(+), 35 deletions(-) diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu index 88bfd550..04b262b4 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu @@ -38,7 +38,7 @@ using LayoutATag = cutlass::layout::RowMajor; constexpr int AlignmentA = 32; using ElementB = cutlass::nv_float4_t; -using LayoutBTag = cutlass::layout::RowMajor; +using LayoutBTag = cutlass::layout::ColumnMajor; constexpr int AlignmentB = 32; using ElementD = cutlass::bfloat16_t; @@ -94,7 +94,7 @@ using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; ///////////////////////////////////////////////////////////////////////////////////////////////// -// Scale factor remap: row-major source -> CUTLASS interleaved layout +// Scale factor remap: source (row-major or col-major) -> CUTLASS interleaved layout // // Iterates over CUTLASS dest indices, uses idx2crd to get the hierarchical coordinate, // then extracts logical (m, k_sf) from the flattened result. @@ -112,10 +112,11 @@ using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; template __global__ void remap_sf_to_cutlass_kernel( - const cutlass::float_ue4m3_t* __restrict__ src, // (MN, K_sf) row-major + const cutlass::float_ue4m3_t* __restrict__ src, // (MN, K_sf) or (K_sf, MN) depending on col_major_src cutlass::float_ue4m3_t* __restrict__ dst, // CUTLASS interleaved layout (zero-initialized) LayoutSF layout_sf, // CuTe layout for dst - int MN, int K_sf // Source dimensions (in SF groups) + int MN, int K_sf, // Source dimensions (in SF groups) + bool col_major_src = false // true if source is (K_sf, MN) row-major = (MN, K_sf) col-major ) { int dst_idx = blockIdx.x * blockDim.x + threadIdx.x; int total = cute::size(layout_sf); @@ -129,14 +130,6 @@ __global__ void remap_sf_to_cutlass_kernel( int m = 0, k_sf = 0; if constexpr (R == 8) { - // 8 flattened coordinates from idx2crd: - // f0 = inner_m (0..31), f1 = sub_m (0..3), f2 = tile_m (0..) - // f3 = step_m stride (degenerate — always equals total, not a coordinate) - // f4 = sub_k (0..3), f5 = tile_k (0..), f6 = 0, f7 = 0 - // - // CuTe "first sub varies fastest" for Shape<32, 4>: - // m = f0 + f1 * 32 + f2 * 128 - // k_sf = f4 + f5 * 4 m = cute::get<0>(flat) + cute::get<1>(flat) * 32 + cute::get<2>(flat) * 128; k_sf = cute::get<4>(flat) + cute::get<5>(flat) * 4; } else { @@ -144,7 +137,9 @@ __global__ void remap_sf_to_cutlass_kernel( } if (m < MN && k_sf < K_sf) { - dst[dst_idx] = src[m * K_sf + k_sf]; + // SFA: source is (MN, K_sf) row-major → src[m * K_sf + k_sf] + // SFB: source is (K_sf, MN) row-major (col-major (MN, K_sf)) → src[k_sf * MN + m] + dst[dst_idx] = col_major_src ? src[k_sf * MN + m] : src[m * K_sf + k_sf]; } } @@ -186,9 +181,9 @@ int cutlass_nvfp4_gemm_run( int block = 256; remap_sf_to_cutlass_kernel<<<(sfa_size + block - 1) / block, block, 0, stream>>>( - static_cast(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf); + static_cast(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf, false); remap_sf_to_cutlass_kernel<<<(sfb_size + block - 1) / block, block, 0, stream>>>( - static_cast(SFB_ptr), sfb_cutlass.get(), layout_SFB, N, K_sf); + static_cast(SFB_ptr), sfb_cutlass.get(), layout_SFB, N, K_sf, true); typename Gemm::Arguments arguments { cutlass::gemm::GemmUniversalMode::kGemm, diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py index f43ccae9..e4462735 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py @@ -20,8 +20,8 @@ except ImportError: def cutlass_nvfp4_blockscaled_gemm( A_packed, # (M, K_half) int8 packed E2M1 SFA, # scale factors for A (float8_e4m3fn) - B_packed, # (N, K_half) int8 packed E2M1 - SFB, # scale factors for B (float8_e4m3fn) + B_packed, # (K_half, N) int8 packed E2M1, column-major for CUTLASS + SFB, # scale factors for B (sf_k, N) float8_e4m3fn, column-major for CUTLASS M, N, K, # Problem dimensions (K in FP4 elements) alpha=1.0, # fp32 scalar applied in epilogue: D = alpha * A @ B + beta * C ): @@ -34,8 +34,8 @@ def cutlass_nvfp4_blockscaled_gemm( def cutlass_grouped_nvfp4_gemm( x_fp4, # (num_tokens, K_half) int8 packed E2M1 x_sf, # (num_tokens, sf_k) float8_e4m3fn block scales - weights, # (E_per_rank, N, K_half) int8 packed E2M1 - weight_sf, # (E_per_rank, N, sf_k) float8_e4m3fn block scales + weights, # (E_per_rank, K_half, N) int8 packed E2M1, column-major for CUTLASS + weight_sf, # (E_per_rank, sf_k, N) float8_e4m3fn, column-major for CUTLASS topk_ids, # (num_tokens, NUM_TOPK) int32 topk_weights, # (num_tokens, NUM_TOPK) float32 alpha=1.0, # fp32 scalar: D = alpha * A @ B (from stage_activation global scale) @@ -48,7 +48,8 @@ def cutlass_grouped_nvfp4_gemm( num_tokens = x_fp4.shape[0] K_half = x_fp4.shape[1] K = K_half * 2 # Actual K dimension (2 FP4 per byte) - N = weights.shape[1] # Output dimension + # Weights are (E, K_half, N) column-major (transposed at load time for CUTLASS ColumnMajor B) + N = weights.shape[2] # Output dimension num_experts = weights.shape[0] num_topk = topk_ids.shape[1] @@ -69,8 +70,8 @@ def cutlass_grouped_nvfp4_gemm( # Gather tokens for this expert expert_x = x_fp4[token_indices] # (num_expert_tokens, K_half) expert_x_sf = x_sf[token_indices] # (num_expert_tokens, sf_k) - expert_w = weights[e] # (N, K_half) - expert_w_sf = weight_sf[e] # (N, sf_k) — THIS IS SCALES, NOT WEIGHTS + expert_w = weights[e] # (K_half, N) column-major for CUTLASS + expert_w_sf = weight_sf[e] # (sf_k, N) column-major for CUTLASS M_expert = token_indices.shape[0] diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/test_gemm.py b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/test_gemm.py index 692c56c4..fcfe2a1b 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/test_gemm.py +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/test_gemm.py @@ -40,10 +40,14 @@ def test_cutlass_nvfp4_gemm(): B_bf16 = unpack_e2m1_to_bf16(B_packed, B_scales) C_ref = torch.matmul(A_bf16, B_bf16.t()) + # CUTLASS expects B in column-major: (K_half, N) for weights, (sf_k, N) for scales + B_packed_cm = B_packed.t().contiguous() + B_scales_cm = B_scales.t().contiguous() + # CUTLASS native NVFP4 GEMM try: from nvfp4_megamoe_kernel.cutlass_nvfp4_gemm.kernel import cutlass_nvfp4_blockscaled_gemm - C_cutlass = cutlass_nvfp4_blockscaled_gemm(A_packed, A_scales, B_packed, B_scales) + C_cutlass = cutlass_nvfp4_blockscaled_gemm(A_packed, A_scales, B_packed_cm, B_scales_cm) # Compare (NVFP4 has low precision, so use loose tolerance) diff = (C_cutlass.float() - C_ref.float()).abs() @@ -78,8 +82,9 @@ def test_grouped_gemm(): # Create inputs x_packed = torch.randint(-128, 127, (num_tokens, K // 2), dtype=torch.int8, device=device) x_scales = torch.randn(num_tokens, K // 16, dtype=torch.float8_e4m3fn, device=device).abs().clamp(min=0.0625, max=448.0) - weights = torch.randint(-128, 127, (E, N, K // 2), dtype=torch.int8, device=device) - weight_scales = torch.randn(E, N, K // 16, dtype=torch.float8_e4m3fn, device=device).abs().clamp(min=0.0625, max=448.0) + # Weights in column-major for CUTLASS: (E, K_half, N) and (E, sf_k, N) + weights = torch.randint(-128, 127, (E, K // 2, N), dtype=torch.int8, device=device) + weight_scales = torch.randn(E, K // 16, N, dtype=torch.float8_e4m3fn, device=device).abs().clamp(min=0.0625, max=448.0) topk_ids = torch.randint(0, E, (num_tokens, top_k), dtype=torch.int32, device=device) topk_weights = torch.rand(num_tokens, top_k, dtype=torch.float32, device=device) topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) diff --git a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py index 5d7ff937..ae5faafd 100644 --- a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +++ b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py @@ -91,8 +91,8 @@ MEGA_MOE_DEBUG = int(os.environ.get("MEGA_MOE_DEBUG", "0")) def nvfp4_mega_moe_l1( x_fp4, # (num_tokens, K//2) int8 packed E2M1 x_sf, # (num_tokens, sf_k_groups) uint32 packed UE4M3 - l1_weights, # (E_per_rank, 2*INTER, K//2) int8 K-major - l1_scales, # (E_per_rank, 2*INTER, sf_k_groups) uint32 packed UE4M3 + l1_weights, # (E_per_rank, K//2, 2*INTER) int8, column-major for CUTLASS + l1_scales, # (E_per_rank, sf_k_groups, 2*INTER) float8_e4m3fn, column-major topk_ids, # (num_tokens, NUM_TOPK) int32 topk_weights, # (num_tokens, NUM_TOPK) float32 num_experts_per_rank, @@ -108,7 +108,7 @@ def nvfp4_mega_moe_l1( num_tokens = x_fp4.shape[0] K_half = x_fp4.shape[1] K = K_half * 2 # HIDDEN = 7168 - N = l1_weights.shape[1] # 2 * INTERMEDIATE = 6144 + N = l1_weights.shape[2] # 2 * INTERMEDIATE = 6144 (column-major: shape is E, K_half, N) if MEGA_MOE_DEBUG: print(f"[nvfp4_moe_l1] tokens={num_tokens} K={K} N={N} " @@ -130,8 +130,8 @@ def nvfp4_mega_moe_l1( def nvfp4_mega_moe_l2( x_fp4, # (num_tokens, INTER//2) int8 packed E2M1 x_sf, # (num_tokens, sf_k_groups) uint32 packed UE4M3 - l2_weights, # (E_per_rank, HIDDEN, INTER//2) int8 K-major - l2_scales, # (E_per_rank, HIDDEN, sf_k_groups) uint32 packed UE4M3 + l2_weights, # (E_per_rank, INTER//2, HIDDEN) int8, column-major for CUTLASS + l2_scales, # (E_per_rank, sf_k_groups, HIDDEN) float8_e4m3fn, column-major topk_ids, # (num_tokens, NUM_TOPK) int32 topk_weights, # (num_tokens, NUM_TOPK) float32 num_experts_per_rank, @@ -144,7 +144,7 @@ def nvfp4_mega_moe_l2( num_tokens = x_fp4.shape[0] K_half = x_fp4.shape[1] K = K_half * 2 # INTERMEDIATE = 3072 - N = l2_weights.shape[1] # HIDDEN = 7168 + N = l2_weights.shape[2] # HIDDEN = 7168 (column-major: shape is E, K_half, N) if MEGA_MOE_DEBUG: print(f"[nvfp4_moe_l2] tokens={num_tokens} K={K} N={N} " diff --git a/src/nvfp4_megamoe_kernel/weight_transform.py b/src/nvfp4_megamoe_kernel/weight_transform.py index 48c0c1f9..0e01033e 100644 --- a/src/nvfp4_megamoe_kernel/weight_transform.py +++ b/src/nvfp4_megamoe_kernel/weight_transform.py @@ -97,9 +97,14 @@ def transform_nvfp4_weights_for_mega_moe( l1_sf_out = l1_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous() l2_sf_out = l2_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous() - # L1 weights: plain concat [gate; up] — no interleave needed - # (Our CUTLASS kernel uses 1x1x1 ClusterShape, not 2CTA) - l1_weight_out = l1_weight.contiguous() - l2_weight_out = l2_weight.contiguous() + # CUTLASS B is declared ColumnMajor — it expects (K, N) in memory. + # Checkpoint weights are (N, K_half) row-major, so we transpose to (K_half, N) + # which is column-major (N, K_half). This is a one-time cost at load time. + l1_weight_out = l1_weight.transpose(-2, -1).contiguous() + l2_weight_out = l2_weight.transpose(-2, -1).contiguous() + + # Same for scale factors: (N, sf_k) row-major → (sf_k, N) column-major + l1_sf_out = l1_sf_out.transpose(-2, -1).contiguous() + l2_sf_out = l2_sf_out.transpose(-2, -1).contiguous() return (l1_weight_out, l1_sf_out), (l2_weight_out, l2_sf_out)