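Fix kernel.py: remove the broken expand on the weight scale factors (the scale-factor tensor, shape (N, sf_k), was being expanded/reshaped to the weight shape (N, K_half)); pass weights and scale factors to the GEMM directly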
This commit is contained in:
@@ -54,10 +54,6 @@ def cutlass_grouped_nvfp4_gemm(
|
||||
print(f"[cutlass_grouped_gemm] tokens={num_tokens} K={K} N={N} "
|
||||
f"experts={num_experts} topk={num_topk}")
|
||||
|
||||
# For now, run all tokens through all experts and select using topk_ids
|
||||
# This is simpler than per-expert gather and works for moderate token counts
|
||||
# TODO: optimize with per-expert gather for large batch sizes
|
||||
|
||||
output = torch.zeros(num_tokens, N, dtype=torch.bfloat16, device=x_fp4.device)
|
||||
|
||||
for e in range(num_experts):
|
||||
@@ -72,21 +68,19 @@ def cutlass_grouped_nvfp4_gemm(
|
||||
expert_x = x_fp4[token_indices] # (num_expert_tokens, K_half)
|
||||
expert_x_sf = x_sf[token_indices] # (num_expert_tokens, sf_k)
|
||||
expert_w = weights[e] # (N, K_half)
|
||||
expert_w_sf = weight_sf[e] # (N, sf_k)
|
||||
expert_w_sf = weight_sf[e] # (N, sf_k) — these are scale factors, not weights
|
||||
|
||||
M_expert = token_indices.shape[0]
|
||||
|
||||
# Run CUTLASS NVFP4 block-scaled GEMM
|
||||
expert_out = cutlass_nvfp4_blockscaled_gemm(
|
||||
expert_x, expert_x_sf,
|
||||
expert_w.unsqueeze(0).expand(1, N, K_half).reshape(N, K_half) if expert_w.dim() == 2 else expert_w,
|
||||
expert_w_sf.unsqueeze(0).expand(1, N, K_half).reshape(N, K_half) if expert_w_sf.dim() == 2 else expert_w_sf,
|
||||
expert_w, expert_w_sf, # Pass directly — already (N, K_half) and (N, sf_k)
|
||||
M_expert, N, K,
|
||||
) # (M_expert, N) bfloat16
|
||||
|
||||
# Scatter back with routing weights
|
||||
for t_idx, token_idx in enumerate(token_indices):
|
||||
# Find which topk slot(s) route to this expert
|
||||
for k_idx in range(num_topk):
|
||||
if topk_ids[token_idx, k_idx] == e:
|
||||
output[token_idx] += topk_weights[token_idx, k_idx] * expert_out[t_idx]
|
||||
|
||||
Reference in New Issue
Block a user