diff --git a/cutedsl/bridge.py b/cutedsl/bridge.py index f8bc021d..bd9969e5 100644 --- a/cutedsl/bridge.py +++ b/cutedsl/bridge.py @@ -388,14 +388,11 @@ def run_nvfp4_grouped_gemm( no cute.compile() in the forward path. """ num_experts = mat_b.shape[0] - n_dim = mat_b.shape[2] # packed N (in float4_e2m1fn_x2 elements) - # float4_e2m1fn_x2 packs 2 FP4 values per storage element - # The GEMM accumulates in BF16, so output has 2x the packed dimension - n_dim_logical = n_dim * 2 + n_dim = mat_b.shape[2] # N dimension (logical, not packed — float4_e2m1fn_x2 packs along K, not N) tokens_sum = mat_a.shape[0] device = mat_a.device - out = torch.zeros(tokens_sum, n_dim_logical, dtype=torch.bfloat16, device=device) + out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=device) compiled, kernel, max_active_clusters = _get_compiled_kernel( num_experts, device, mma_tiler_mn, cluster_shape_mn diff --git a/tests/cudagraph_test.py b/tests/cudagraph_test.py index 6b949d7d..ff929fdd 100644 --- a/tests/cudagraph_test.py +++ b/tests/cudagraph_test.py @@ -113,7 +113,7 @@ def make_dummy_runner(num_experts=32, hidden_size=7168, intermediate_size=3072, def rand_sf(*shape, device="cuda"): return torch.rand(shape, dtype=torch.float16, device=device).to(torch.float8_e4m3fn) - l1_fp4 = [rand_fp4(3584, intermediate_size, device=device) for _ in range(num_experts)] + l1_fp4 = [rand_fp4(3584, intermediate_size * 2, device=device) for _ in range(num_experts)] l1_sf = [rand_sf(3584 // 16, intermediate_size * 2, device=device) for _ in range(num_experts)] l1_gs = [0.1] * num_experts l2_fp4 = [rand_fp4(1536, hidden_size // 2, device=device) for _ in range(num_experts)]