diff --git a/cutedsl/bridge.py b/cutedsl/bridge.py index 08792612..7f54d407 100644 --- a/cutedsl/bridge.py +++ b/cutedsl/bridge.py @@ -387,7 +387,9 @@ def run_nvfp4_grouped_gemm( no cute.compile() in the forward path. """ num_experts = mat_b.shape[0] - n_dim = mat_b.shape[2] # N dimension (logical, not packed — float4_e2m1fn_x2 packs along K, not N) + K_packed = mat_a.shape[1] + N_packed = mat_b.shape[2] # N dimension (logical, not packed — float4_e2m1fn_x2 packs along K, not N) + n_dim = N_packed tokens_sum = mat_a.shape[0] device = mat_a.device