fix: L1 weight N dimension is 2*intermediate (gate+up), not intermediate

float4_e2m1fn_x2 packs 2 values per byte along K, not N.
The GEMM output N dimension is the logical N from mat_b.shape[2],
not 2x packed. Previous n_dim*2 was wrong — it accidentally worked
in the test because the dummy L1 weight was built with N = intermediate_size,
so doubling it happened to give the correct 2*intermediate_size (gate+up).
Real model with N=9216 exposed the bug.
This commit is contained in:
2026-05-16 19:07:08 +00:00
parent f7e29fdf1e
commit 28788c6f55
2 changed files with 3 additions and 6 deletions

View File

@@ -388,14 +388,11 @@ def run_nvfp4_grouped_gemm(
no cute.compile() in the forward path.
"""
num_experts = mat_b.shape[0]
n_dim = mat_b.shape[2] # packed N (in float4_e2m1fn_x2 elements)
# float4_e2m1fn_x2 packs 2 FP4 values per storage element
# The GEMM accumulates in BF16, so output has 2x the packed dimension
n_dim_logical = n_dim * 2
n_dim = mat_b.shape[2] # N dimension (logical, not packed: float4_e2m1fn_x2 packs along K, not N)
tokens_sum = mat_a.shape[0]
device = mat_a.device
out = torch.zeros(tokens_sum, n_dim_logical, dtype=torch.bfloat16, device=device)
out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=device)
compiled, kernel, max_active_clusters = _get_compiled_kernel(
num_experts, device, mma_tiler_mn, cluster_shape_mn

View File

@@ -113,7 +113,7 @@ def make_dummy_runner(num_experts=32, hidden_size=7168, intermediate_size=3072,
def rand_sf(*shape, device="cuda"):
return torch.rand(shape, dtype=torch.float16, device=device).to(torch.float8_e4m3fn)
l1_fp4 = [rand_fp4(3584, intermediate_size, device=device) for _ in range(num_experts)]
l1_fp4 = [rand_fp4(3584, intermediate_size * 2, device=device) for _ in range(num_experts)]
l1_sf = [rand_sf(3584 // 16, intermediate_size * 2, device=device) for _ in range(num_experts)]
l1_gs = [0.1] * num_experts
l2_fp4 = [rand_fp4(1536, hidden_size // 2, device=device) for _ in range(num_experts)]