Fix kernel.py: remove broken expand on scale factors (was expanding sf to weight size)

This commit is contained in:
2026-05-14 10:36:16 +00:00
parent 84becfac93
commit 869151d211

View File

@@ -54,10 +54,6 @@ def cutlass_grouped_nvfp4_gemm(
print(f"[cutlass_grouped_gemm] tokens={num_tokens} K={K} N={N} "
f"experts={num_experts} topk={num_topk}")
# For now, run all tokens through all experts and select using topk_ids
# This is simpler than per-expert gather and works for moderate token counts
# TODO: optimize with per-expert gather for large batch sizes
output = torch.zeros(num_tokens, N, dtype=torch.bfloat16, device=x_fp4.device)
for e in range(num_experts):
@@ -72,21 +68,19 @@ def cutlass_grouped_nvfp4_gemm(
expert_x = x_fp4[token_indices] # (num_expert_tokens, K_half)
expert_x_sf = x_sf[token_indices] # (num_expert_tokens, sf_k)
expert_w = weights[e] # (N, K_half)
expert_w_sf = weight_sf[e] # (N, sf_k)
expert_w_sf = weight_sf[e] # (N, sf_k) — THIS IS SCALES, NOT WEIGHTS
M_expert = token_indices.shape[0]
# Run CUTLASS NVFP4 block-scaled GEMM
expert_out = cutlass_nvfp4_blockscaled_gemm(
expert_x, expert_x_sf,
expert_w.unsqueeze(0).expand(1, N, K_half).reshape(N, K_half) if expert_w.dim() == 2 else expert_w,
expert_w_sf.unsqueeze(0).expand(1, N, K_half).reshape(N, K_half) if expert_w_sf.dim() == 2 else expert_w_sf,
expert_w, expert_w_sf, # Pass directly — already (N, K_half) and (N, sf_k)
M_expert, N, K,
) # (M_expert, N) bfloat16
# Scatter back with routing weights
for t_idx, token_idx in enumerate(token_indices):
# Find which topk slot(s) route to this expert
for k_idx in range(num_topk):
if topk_ids[token_idx, k_idx] == e:
output[token_idx] += topk_weights[token_idx, k_idx] * expert_out[t_idx]