diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py index 3b6a57d5..f62ae2eb 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py @@ -54,10 +54,6 @@ def cutlass_grouped_nvfp4_gemm( print(f"[cutlass_grouped_gemm] tokens={num_tokens} K={K} N={N} " f"experts={num_experts} topk={num_topk}") - # For now, run all tokens through all experts and select using topk_ids - # This is simpler than per-expert gather and works for moderate token counts - # TODO: optimize with per-expert gather for large batch sizes - output = torch.zeros(num_tokens, N, dtype=torch.bfloat16, device=x_fp4.device) for e in range(num_experts): @@ -72,21 +68,19 @@ def cutlass_grouped_nvfp4_gemm( expert_x = x_fp4[token_indices] # (num_expert_tokens, K_half) expert_x_sf = x_sf[token_indices] # (num_expert_tokens, sf_k) expert_w = weights[e] # (N, K_half) - expert_w_sf = weight_sf[e] # (N, sf_k) + expert_w_sf = weight_sf[e] # (N, sf_k) — THIS IS SCALES, NOT WEIGHTS M_expert = token_indices.shape[0] # Run CUTLASS NVFP4 block-scaled GEMM expert_out = cutlass_nvfp4_blockscaled_gemm( expert_x, expert_x_sf, - expert_w.unsqueeze(0).expand(1, N, K_half).reshape(N, K_half) if expert_w.dim() == 2 else expert_w, - expert_w_sf.unsqueeze(0).expand(1, N, K_half).reshape(N, K_half) if expert_w_sf.dim() == 2 else expert_w_sf, + expert_w, expert_w_sf, # Pass directly — already (N, K_half) and (N, sf_k) M_expert, N, K, ) # (M_expert, N) bfloat16 # Scatter back with routing weights for t_idx, token_idx in enumerate(token_indices): - # Find which topk slot(s) route to this expert for k_idx in range(num_topk): if topk_ids[token_idx, k_idx] == e: output[token_idx] += topk_weights[token_idx, k_idx] * expert_out[t_idx]