fix: compute K_packed/N_packed before passing to _get_compiled_kernel

This commit is contained in:
2026-05-16 20:00:35 +00:00
parent caf93d6c45
commit 79281b6fda

View File

@@ -387,7 +387,9 @@ def run_nvfp4_grouped_gemm(
no cute.compile() in the forward path.
"""
num_experts = mat_b.shape[0]
n_dim = mat_b.shape[2] # N dimension (logical, not packed — float4_e2m1fn_x2 packs along K, not N)
K_packed = mat_a.shape[1]
N_packed = mat_b.shape[2] # N dimension (logical, not packed — float4_e2m1fn_x2 packs along K, not N)
n_dim = N_packed
tokens_sum = mat_a.shape[0]
device = mat_a.device