fix: L1 weight N dimension is 2*intermediate (gate+up), not intermediate
float4_e2m1fn_x2 packs 2 values per byte along K, not N. The GEMM output N dimension is the logical N from mat_b.shape[2], not 2x packed. Previous n_dim*2 was wrong — it accidentally worked in the test because intermediate_size*2 == 2*intermediate_size. Real model with N=9216 exposed the bug.
This commit is contained in:
@@ -388,14 +388,11 @@ def run_nvfp4_grouped_gemm(
|
||||
no cute.compile() in the forward path.
|
||||
"""
|
||||
num_experts = mat_b.shape[0]
|
||||
n_dim = mat_b.shape[2] # packed N (in float4_e2m1fn_x2 elements)
|
||||
# float4_e2m1fn_x2 packs 2 FP4 values per storage element
|
||||
# The GEMM accumulates in BF16, so output has 2x the packed dimension
|
||||
n_dim_logical = n_dim * 2
|
||||
n_dim = mat_b.shape[2] # N dimension (logical, not packed — float4_e2m1fn_x2 packs along K, not N)
|
||||
tokens_sum = mat_a.shape[0]
|
||||
device = mat_a.device
|
||||
|
||||
out = torch.zeros(tokens_sum, n_dim_logical, dtype=torch.bfloat16, device=device)
|
||||
out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=device)
|
||||
|
||||
compiled, kernel, max_active_clusters = _get_compiled_kernel(
|
||||
num_experts, device, mma_tiler_mn, cluster_shape_mn
|
||||
|
||||
@@ -113,7 +113,7 @@ def make_dummy_runner(num_experts=32, hidden_size=7168, intermediate_size=3072,
|
||||
def rand_sf(*shape, device="cuda"):
|
||||
return torch.rand(shape, dtype=torch.float16, device=device).to(torch.float8_e4m3fn)
|
||||
|
||||
l1_fp4 = [rand_fp4(3584, intermediate_size, device=device) for _ in range(num_experts)]
|
||||
l1_fp4 = [rand_fp4(3584, intermediate_size * 2, device=device) for _ in range(num_experts)]
|
||||
l1_sf = [rand_sf(3584 // 16, intermediate_size * 2, device=device) for _ in range(num_experts)]
|
||||
l1_gs = [0.1] * num_experts
|
||||
l2_fp4 = [rand_fp4(1536, hidden_size // 2, device=device) for _ in range(num_experts)]
|
||||
|
||||
Reference in New Issue
Block a user