Files
nvfp4-megamoe-kernel/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py

110 lines
4.7 KiB
Python

"""
CUTLASS NVFP4 Block-Scaled GEMM — Native Blackwell SM100 kernel.
Uses the pre-compiled PyTorch CUDA extension (cutlass_nvfp4_gemm._C)
which invokes native mxf8f6f4.block_scale tensor core instructions.
"""
import os
import torch
# Debug verbosity read once at import time; any non-zero value enables the
# diagnostic prints in this module. int() raises ValueError at import if the
# variable holds a non-integer string.
MEGA_MOE_DEBUG = int(os.getenv("MEGA_MOE_DEBUG", "0"))

# Probe for the compiled CUTLASS extension exactly once. On failure `_C` is
# left undefined, so code must consult _CUTLASS_AVAILABLE before touching it.
try:
    from cutlass_nvfp4_gemm import _C
except ImportError:
    _CUTLASS_AVAILABLE = False
else:
    _CUTLASS_AVAILABLE = True
def cutlass_nvfp4_blockscaled_gemm(
    A_packed,  # (M, K_half) int8 packed E2M1
    SFA,  # scale factors for A (float8_e4m3fn)
    B_packed,  # (K_half, N) int8 packed E2M1, column-major for CUTLASS
    SFB,  # scale factors for B (sf_k, N) float8_e4m3fn, column-major for CUTLASS
    M, N, K,  # problem dimensions (K counted in FP4 elements)
    alpha=1.0,  # fp32 scalar applied in the epilogue: D = alpha * (A @ B); no beta/C term is exposed
):
    """Run a single NVFP4 block-scaled GEMM through the compiled CUTLASS extension.

    Every argument is forwarded unchanged, in the same order, to ``_C.forward``.
    The tensor layouts noted beside the parameters are the ones this module's
    callers assume; the extension itself is not visible here.

    Returns:
        Whatever ``_C.forward`` returns, presumably an ``(M, N)`` bfloat16
        tensor (TODO confirm against the extension).

    Raises:
        RuntimeError: if ``cutlass_nvfp4_gemm._C`` could not be imported.
    """
    if _CUTLASS_AVAILABLE:
        return _C.forward(A_packed, SFA, B_packed, SFB, M, N, K, alpha)
    raise RuntimeError("CUTLASS NVFP4 GEMM extension not available")
def _nonfinite_report(e, num_experts, M_expert, N, K, expert_x, expert_x_sf, expert_w_sf):
    """Build the diagnostic message raised when expert ``e``'s GEMM emits NaN/Inf."""
    x_sf = expert_x_sf.to(torch.float32)
    w_sf = expert_w_sf.to(torch.float32)
    # Widen before abs(): in int8, abs(-128) wraps back to -128.
    x_abs_max = expert_x.view(torch.int8).to(torch.int16).abs().max().item()
    return (
        f"expert {e} of {num_experts}: GEMM emitted NaN/Inf. "
        f"M={M_expert} N={N} K={K} | "
        f"x abs range [{x_abs_max}], "
        f"x_sf range [{x_sf.min().item():.4e}, "
        f"{x_sf.max().item():.4e}], "
        f"w_sf range [{w_sf.min().item():.4e}, "
        f"{w_sf.max().item():.4e}], "
        f"x_sf nan_frac={torch.isnan(x_sf).float().mean().item():.4f}, "
        f"w_sf nan_frac={torch.isnan(w_sf).float().mean().item():.4f}"
    )


def cutlass_grouped_nvfp4_gemm(
    x_fp4,  # (num_tokens, K_half) int8 packed E2M1
    x_sf,  # (num_tokens, sf_k) float8_e4m3fn block scales
    weights,  # (E_per_rank, K_half, N) int8 packed E2M1, column-major for CUTLASS
    weight_sf,  # (E_per_rank, sf_k, N) float8_e4m3fn, column-major for CUTLASS
    topk_ids,  # (num_tokens, NUM_TOPK) int32
    topk_weights,  # (num_tokens, NUM_TOPK) float32
    alpha=1.0,  # fp32 scalar: D = alpha * A @ B (from stage_activation global scale)
):
    """Per-expert grouped GEMM for MoE dispatch using CUTLASS NVFP4.

    For each local expert ``e`` in ``range(weights.shape[0])``:
    1. gather the tokens whose ``topk_ids`` row contains ``e``;
    2. run the block-scaled GEMM on them;
    3. add the routing-weighted result into a bfloat16 output.

    Expert ids outside that range (e.g. negative or remote ids) never match and
    are ignored. If a token lists the same expert in several top-k slots, the
    weights of those slots are summed. This is mathematically the same as
    adding the expert output once per slot.

    Returns:
        ``(num_tokens, N)`` bfloat16 tensor on ``x_fp4.device``. Rows of tokens
        that hit no local expert stay zero.

    Raises:
        RuntimeError: if an expert GEMM produces NaN/Inf; the message carries a
            summary of that expert's inputs.
    """
    num_tokens = x_fp4.shape[0]
    K_half = x_fp4.shape[1]
    K = K_half * 2  # two FP4 values per packed byte
    # Weights are (E, K_half, N): transposed at load time for CUTLASS ColumnMajor B.
    num_experts = weights.shape[0]
    N = weights.shape[2]
    num_topk = topk_ids.shape[1]
    if MEGA_MOE_DEBUG:
        print(f"[cutlass_grouped_gemm] tokens={num_tokens} K={K} N={N} "
              f"experts={num_experts} topk={num_topk}")

    output = torch.zeros(num_tokens, N, dtype=torch.bfloat16, device=x_fp4.device)
    for e in range(num_experts):
        expert_mask = topk_ids == e  # (num_tokens, num_topk)
        # nonzero() of a 1-D mask yields unique, ascending token indices.
        token_indices = expert_mask.any(dim=1).nonzero(as_tuple=True)[0]
        if token_indices.numel() == 0:
            continue

        expert_x = x_fp4[token_indices]  # (M_expert, K_half)
        expert_x_sf = x_sf[token_indices]  # (M_expert, sf_k)
        expert_w = weights[e]  # per-expert slice of `weights`, passed as-is
        expert_w_sf = weight_sf[e]  # per-expert slice of `weight_sf`, passed as-is
        M_expert = token_indices.shape[0]

        expert_out = cutlass_nvfp4_blockscaled_gemm(
            expert_x, expert_x_sf,
            expert_w, expert_w_sf,
            M_expert, N, K,
            alpha=alpha,
        )
        # Surface asynchronous CUDA errors from this expert's GEMM here,
        # instead of at some later, unrelated operation.
        torch.cuda.current_stream().synchronize()
        # Hard-fail on NaN/Inf; a silent skip was hiding bugs.
        if not torch.isfinite(expert_out).all():
            raise RuntimeError(_nonfinite_report(
                e, num_experts, M_expert, N, K, expert_x, expert_x_sf, expert_w_sf,
            ))

        # Total routing weight each gathered token assigns to expert e. Slots
        # routed elsewhere are zeroed with masked_fill rather than multiplied by
        # the mask, so a NaN in an unrelated slot cannot leak in.
        routed = expert_mask[token_indices]  # (M_expert, num_topk)
        expert_weight = topk_weights[token_indices].masked_fill(~routed, 0).sum(dim=1)
        weighted = expert_out.float() * expert_weight.float().unsqueeze(1)
        # Vectorised scatter. token_indices are unique, so indexed += is exact.
        # This replaces a per-(token, slot) Python loop that forced a host sync
        # on every comparison.
        output[token_indices] += weighted.to(output.dtype)
    return output