fix: compute K_packed/N_packed before passing to _get_compiled_kernel
This commit is contained in:
@@ -387,7 +387,9 @@ def run_nvfp4_grouped_gemm(
|
||||
no cute.compile() in the forward path.
|
||||
"""
|
||||
num_experts = mat_b.shape[0]
|
||||
n_dim = mat_b.shape[2] # N dimension (logical, not packed — float4_e2m1fn_x2 packs along K, not N)
|
||||
K_packed = mat_a.shape[1]
|
||||
N_packed = mat_b.shape[2] # N dimension (logical, not packed — float4_e2m1fn_x2 packs along K, not N)
|
||||
n_dim = N_packed
|
||||
tokens_sum = mat_a.shape[0]
|
||||
device = mat_a.device
|
||||
|
||||
|
||||
Reference in New Issue
Block a user