fix: transpose B and SFB on the Python side at weight-load time, and adjust the SFB remap kernel to read from column-major source layout

This commit is contained in:
2026-05-15 04:35:45 +00:00
parent c56cc34ae1
commit c7f6a1dc4d
5 changed files with 41 additions and 35 deletions

View File

@@ -38,7 +38,7 @@ using LayoutATag = cutlass::layout::RowMajor;
constexpr int AlignmentA = 32;
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
using LayoutBTag = cutlass::layout::RowMajor;
using LayoutBTag = cutlass::layout::ColumnMajor;
constexpr int AlignmentB = 32;
using ElementD = cutlass::bfloat16_t;
@@ -94,7 +94,7 @@ using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
/////////////////////////////////////////////////////////////////////////////////////////////////
// Scale factor remap: row-major source -> CUTLASS interleaved layout
// Scale factor remap: source (row-major or col-major) -> CUTLASS interleaved layout
//
// Iterates over CUTLASS dest indices, uses idx2crd to get the hierarchical coordinate,
// then extracts logical (m, k_sf) from the flattened result.
@@ -112,10 +112,11 @@ using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
template<typename LayoutSF>
__global__ void remap_sf_to_cutlass_kernel(
const cutlass::float_ue4m3_t* __restrict__ src, // (MN, K_sf) row-major
const cutlass::float_ue4m3_t* __restrict__ src, // (MN, K_sf) or (K_sf, MN) depending on col_major_src
cutlass::float_ue4m3_t* __restrict__ dst, // CUTLASS interleaved layout (zero-initialized)
LayoutSF layout_sf, // CuTe layout for dst
int MN, int K_sf // Source dimensions (in SF groups)
int MN, int K_sf, // Source dimensions (in SF groups)
bool col_major_src = false // true if source is (K_sf, MN) row-major = (MN, K_sf) col-major
) {
int dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
int total = cute::size(layout_sf);
@@ -129,14 +130,6 @@ __global__ void remap_sf_to_cutlass_kernel(
int m = 0, k_sf = 0;
if constexpr (R == 8) {
// 8 flattened coordinates from idx2crd:
// f0 = inner_m (0..31), f1 = sub_m (0..3), f2 = tile_m (0..)
// f3 = step_m stride (degenerate — always equals total, not a coordinate)
// f4 = sub_k (0..3), f5 = tile_k (0..), f6 = 0, f7 = 0
//
// CuTe "first sub varies fastest" for Shape<32, 4>:
// m = f0 + f1 * 32 + f2 * 128
// k_sf = f4 + f5 * 4
m = cute::get<0>(flat) + cute::get<1>(flat) * 32 + cute::get<2>(flat) * 128;
k_sf = cute::get<4>(flat) + cute::get<5>(flat) * 4;
} else {
@@ -144,7 +137,9 @@ __global__ void remap_sf_to_cutlass_kernel(
}
if (m < MN && k_sf < K_sf) {
dst[dst_idx] = src[m * K_sf + k_sf];
// SFA: source is stored as (MN, K_sf) row-major → src[m * K_sf + k_sf]
// SFB: source is stored as (K_sf, MN) row-major, i.e. the (MN, K_sf) matrix in column-major order → src[k_sf * MN + m]
dst[dst_idx] = col_major_src ? src[k_sf * MN + m] : src[m * K_sf + k_sf];
}
}
@@ -186,9 +181,9 @@ int cutlass_nvfp4_gemm_run(
int block = 256;
remap_sf_to_cutlass_kernel<<<(sfa_size + block - 1) / block, block, 0, stream>>>(
static_cast<const ElementSF*>(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf);
static_cast<const ElementSF*>(SFA_ptr), sfa_cutlass.get(), layout_SFA, M, K_sf, false);
remap_sf_to_cutlass_kernel<<<(sfb_size + block - 1) / block, block, 0, stream>>>(
static_cast<const ElementSF*>(SFB_ptr), sfb_cutlass.get(), layout_SFB, N, K_sf);
static_cast<const ElementSF*>(SFB_ptr), sfb_cutlass.get(), layout_SFB, N, K_sf, true);
typename Gemm::Arguments arguments {
cutlass::gemm::GemmUniversalMode::kGemm,

View File

@@ -20,8 +20,8 @@ except ImportError:
def cutlass_nvfp4_blockscaled_gemm(
A_packed, # (M, K_half) int8 packed E2M1
SFA, # scale factors for A (float8_e4m3fn)
B_packed, # (N, K_half) int8 packed E2M1
SFB, # scale factors for B (float8_e4m3fn)
B_packed, # (K_half, N) int8 packed E2M1, column-major for CUTLASS
SFB, # scale factors for B (sf_k, N) float8_e4m3fn, column-major for CUTLASS
M, N, K, # Problem dimensions (K in FP4 elements)
alpha=1.0, # fp32 scalar applied in epilogue: D = alpha * A @ B + beta * C
):
@@ -34,8 +34,8 @@ def cutlass_nvfp4_blockscaled_gemm(
def cutlass_grouped_nvfp4_gemm(
x_fp4, # (num_tokens, K_half) int8 packed E2M1
x_sf, # (num_tokens, sf_k) float8_e4m3fn block scales
weights, # (E_per_rank, N, K_half) int8 packed E2M1
weight_sf, # (E_per_rank, N, sf_k) float8_e4m3fn block scales
weights, # (E_per_rank, K_half, N) int8 packed E2M1, column-major for CUTLASS
weight_sf, # (E_per_rank, sf_k, N) float8_e4m3fn, column-major for CUTLASS
topk_ids, # (num_tokens, NUM_TOPK) int32
topk_weights, # (num_tokens, NUM_TOPK) float32
alpha=1.0, # fp32 scalar: D = alpha * A @ B (from stage_activation global scale)
@@ -48,7 +48,8 @@ def cutlass_grouped_nvfp4_gemm(
num_tokens = x_fp4.shape[0]
K_half = x_fp4.shape[1]
K = K_half * 2 # Actual K dimension (2 FP4 per byte)
N = weights.shape[1] # Output dimension
# Weights are (E, K_half, N) column-major (transposed at load time for CUTLASS ColumnMajor B)
N = weights.shape[2] # Output dimension
num_experts = weights.shape[0]
num_topk = topk_ids.shape[1]
@@ -69,8 +70,8 @@ def cutlass_grouped_nvfp4_gemm(
# Gather tokens for this expert
expert_x = x_fp4[token_indices] # (num_expert_tokens, K_half)
expert_x_sf = x_sf[token_indices] # (num_expert_tokens, sf_k)
expert_w = weights[e] # (N, K_half)
expert_w_sf = weight_sf[e] # (N, sf_k) — THIS IS SCALES, NOT WEIGHTS
expert_w = weights[e] # (K_half, N) column-major for CUTLASS
expert_w_sf = weight_sf[e] # (sf_k, N) column-major for CUTLASS
M_expert = token_indices.shape[0]

View File

@@ -40,10 +40,14 @@ def test_cutlass_nvfp4_gemm():
B_bf16 = unpack_e2m1_to_bf16(B_packed, B_scales)
C_ref = torch.matmul(A_bf16, B_bf16.t())
# CUTLASS expects B in column-major: (K_half, N) for weights, (sf_k, N) for scales
B_packed_cm = B_packed.t().contiguous()
B_scales_cm = B_scales.t().contiguous()
# CUTLASS native NVFP4 GEMM
try:
from nvfp4_megamoe_kernel.cutlass_nvfp4_gemm.kernel import cutlass_nvfp4_blockscaled_gemm
C_cutlass = cutlass_nvfp4_blockscaled_gemm(A_packed, A_scales, B_packed, B_scales)
C_cutlass = cutlass_nvfp4_blockscaled_gemm(A_packed, A_scales, B_packed_cm, B_scales_cm)
# Compare (NVFP4 has low precision, so use loose tolerance)
diff = (C_cutlass.float() - C_ref.float()).abs()
@@ -78,8 +82,9 @@ def test_grouped_gemm():
# Create inputs
x_packed = torch.randint(-128, 127, (num_tokens, K // 2), dtype=torch.int8, device=device)
x_scales = torch.randn(num_tokens, K // 16, dtype=torch.float8_e4m3fn, device=device).abs().clamp(min=0.0625, max=448.0)
weights = torch.randint(-128, 127, (E, N, K // 2), dtype=torch.int8, device=device)
weight_scales = torch.randn(E, N, K // 16, dtype=torch.float8_e4m3fn, device=device).abs().clamp(min=0.0625, max=448.0)
# Weights in column-major for CUTLASS: (E, K_half, N) and (E, sf_k, N)
weights = torch.randint(-128, 127, (E, K // 2, N), dtype=torch.int8, device=device)
weight_scales = torch.randn(E, K // 16, N, dtype=torch.float8_e4m3fn, device=device).abs().clamp(min=0.0625, max=448.0)
topk_ids = torch.randint(0, E, (num_tokens, top_k), dtype=torch.int32, device=device)
topk_weights = torch.rand(num_tokens, top_k, dtype=torch.float32, device=device)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

View File

@@ -91,8 +91,8 @@ MEGA_MOE_DEBUG = int(os.environ.get("MEGA_MOE_DEBUG", "0"))
def nvfp4_mega_moe_l1(
x_fp4, # (num_tokens, K//2) int8 packed E2M1
x_sf, # (num_tokens, sf_k_groups) uint32 packed UE4M3
l1_weights, # (E_per_rank, 2*INTER, K//2) int8 K-major
l1_scales, # (E_per_rank, 2*INTER, sf_k_groups) uint32 packed UE4M3
l1_weights, # (E_per_rank, K//2, 2*INTER) int8, column-major for CUTLASS
l1_scales, # (E_per_rank, sf_k_groups, 2*INTER) float8_e4m3fn, column-major
topk_ids, # (num_tokens, NUM_TOPK) int32
topk_weights, # (num_tokens, NUM_TOPK) float32
num_experts_per_rank,
@@ -108,7 +108,7 @@ def nvfp4_mega_moe_l1(
num_tokens = x_fp4.shape[0]
K_half = x_fp4.shape[1]
K = K_half * 2 # HIDDEN = 7168
N = l1_weights.shape[1] # 2 * INTERMEDIATE = 6144
N = l1_weights.shape[2] # 2 * INTERMEDIATE = 6144 (column-major: shape is E, K_half, N)
if MEGA_MOE_DEBUG:
print(f"[nvfp4_moe_l1] tokens={num_tokens} K={K} N={N} "
@@ -130,8 +130,8 @@ def nvfp4_mega_moe_l1(
def nvfp4_mega_moe_l2(
x_fp4, # (num_tokens, INTER//2) int8 packed E2M1
x_sf, # (num_tokens, sf_k_groups) uint32 packed UE4M3
l2_weights, # (E_per_rank, HIDDEN, INTER//2) int8 K-major
l2_scales, # (E_per_rank, HIDDEN, sf_k_groups) uint32 packed UE4M3
l2_weights, # (E_per_rank, INTER//2, HIDDEN) int8, column-major for CUTLASS
l2_scales, # (E_per_rank, sf_k_groups, HIDDEN) float8_e4m3fn, column-major
topk_ids, # (num_tokens, NUM_TOPK) int32
topk_weights, # (num_tokens, NUM_TOPK) float32
num_experts_per_rank,
@@ -144,7 +144,7 @@ def nvfp4_mega_moe_l2(
num_tokens = x_fp4.shape[0]
K_half = x_fp4.shape[1]
K = K_half * 2 # INTERMEDIATE = 3072
N = l2_weights.shape[1] # HIDDEN = 7168
N = l2_weights.shape[2] # HIDDEN = 7168 (column-major: shape is E, K_half, N)
if MEGA_MOE_DEBUG:
print(f"[nvfp4_moe_l2] tokens={num_tokens} K={K} N={N} "

View File

@@ -97,9 +97,14 @@ def transform_nvfp4_weights_for_mega_moe(
l1_sf_out = l1_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous()
l2_sf_out = l2_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous()
# L1 weights: plain concat [gate; up] — no interleave needed
# (Our CUTLASS kernel uses 1x1x1 ClusterShape, not 2CTA)
l1_weight_out = l1_weight.contiguous()
l2_weight_out = l2_weight.contiguous()
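# CUTLASS B is declared ColumnMajor; the intent here is to hand it a buffer laid out as (K_half, N).
# Checkpoint weights are (N, K_half) row-major, so we transpose to a contiguous (K_half, N) tensor,
# i.e. the (N, K_half) matrix stored column-major. This is a one-time cost at load time.
# NOTE(review): CUTLASS treats B as a logical (K, N) operand, and ColumnMajor on a (K, N) matrix
# makes K the contiguous dimension, which is the same memory as (N, K) row-major. Confirm against
# the kernel's stride setup that this transpose is actually required and not a double transpose.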
l1_weight_out = l1_weight.transpose(-2, -1).contiguous()
l2_weight_out = l2_weight.transpose(-2, -1).contiguous()
# Same for scale factors: (N, sf_k) row-major → (sf_k, N) column-major
l1_sf_out = l1_sf_out.transpose(-2, -1).contiguous()
l2_sf_out = l2_sf_out.transpose(-2, -1).contiguous()
return (l1_weight_out, l1_sf_out), (l2_weight_out, l2_sf_out)