From 28788c6f558cc5c5158215a828810480e2f71b0d Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 16 May 2026 19:07:08 +0000 Subject: [PATCH] fix: L1 weight N dimension is 2*intermediate (gate+up), not intermediate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit float4_e2m1fn_x2 packs 2 values per byte along K, not N. The GEMM output N dimension is the logical N from mat_b.shape[2], not 2x packed. Previous n_dim*2 was wrong — it accidentally worked in the test because the test built the L1 weights with N = intermediate_size, so n_dim*2 happened to equal the true gate+up width of 2*intermediate_size. Real model with N=9216 exposed the bug. --- cutedsl/bridge.py | 7 ++----- tests/cudagraph_test.py | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/cutedsl/bridge.py b/cutedsl/bridge.py index f8bc021d..bd9969e5 100644 --- a/cutedsl/bridge.py +++ b/cutedsl/bridge.py @@ -388,14 +388,11 @@ def run_nvfp4_grouped_gemm( no cute.compile() in the forward path. """ num_experts = mat_b.shape[0] - n_dim = mat_b.shape[2] # packed N (in float4_e2m1fn_x2 elements) - # float4_e2m1fn_x2 packs 2 FP4 values per storage element - # The GEMM accumulates in BF16, so output has 2x the packed dimension - n_dim_logical = n_dim * 2 + n_dim = mat_b.shape[2] # N dimension (logical, not packed — float4_e2m1fn_x2 packs along K, not N) tokens_sum = mat_a.shape[0] device = mat_a.device - out = torch.zeros(tokens_sum, n_dim_logical, dtype=torch.bfloat16, device=device) + out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=device) compiled, kernel, max_active_clusters = _get_compiled_kernel( num_experts, device, mma_tiler_mn, cluster_shape_mn diff --git a/tests/cudagraph_test.py b/tests/cudagraph_test.py index 6b949d7d..ff929fdd 100644 --- a/tests/cudagraph_test.py +++ b/tests/cudagraph_test.py @@ -113,7 +113,7 @@ def make_dummy_runner(num_experts=32, hidden_size=7168, intermediate_size=3072, def rand_sf(*shape, device="cuda"): return torch.rand(shape, dtype=torch.float16, device=device).to(torch.float8_e4m3fn) - 
l1_fp4 = [rand_fp4(3584, intermediate_size, device=device) for _ in range(num_experts)] + l1_fp4 = [rand_fp4(3584, intermediate_size * 2, device=device) for _ in range(num_experts)] l1_sf = [rand_sf(3584 // 16, intermediate_size * 2, device=device) for _ in range(num_experts)] l1_gs = [0.1] * num_experts l2_fp4 = [rand_fp4(1536, hidden_size // 2, device=device) for _ in range(num_experts)]