# Listing metadata (from the source viewer): 110 lines, 4.7 KiB, Python.
"""
CUTLASS NVFP4 Block-Scaled GEMM — Native Blackwell SM100 kernel.

Uses the pre-compiled PyTorch CUDA extension (cutlass_nvfp4_gemm._C),
which invokes native mxf8f6f4.block_scale tensor core instructions.
"""
|
|
|
|
import os
|
|
import torch
|
|
|
|
# Debug switch, read once at import time from the MEGA_MOE_DEBUG environment
# variable; unset means 0 (off). Any non-zero integer enables debug prints.
MEGA_MOE_DEBUG = int(os.getenv("MEGA_MOE_DEBUG", "0"))

# The compiled extension is optional at import time: record whether it loaded
# so callers get a clear RuntimeError later instead of an ImportError here.
try:
    from cutlass_nvfp4_gemm import _C
except ImportError:
    _CUTLASS_AVAILABLE = False
else:
    _CUTLASS_AVAILABLE = True
|
|
|
|
|
|
def cutlass_nvfp4_blockscaled_gemm(
    A_packed,
    SFA,
    B_packed,
    SFB,
    M, N, K,
    alpha=1.0,
):
    """Run one NVFP4 block-scaled GEMM through the CUTLASS extension.

    All arguments are forwarded unchanged, in order, to ``_C.forward``.

    Args:
        A_packed: (M, K_half) int8 tensor of packed E2M1 values.
        SFA: Scale factors for A (float8_e4m3fn).
        B_packed: (K_half, N) int8 tensor of packed E2M1 values, laid out
            column-major for CUTLASS.
        SFB: Scale factors for B, (sf_k, N) float8_e4m3fn, laid out
            column-major for CUTLASS.
        M, N, K: Problem dimensions; K is counted in FP4 elements.
        alpha: fp32 scalar applied in the kernel epilogue. This wrapper
            forwards no beta/C operand, so the result is alpha * A @ B as far
            as this file shows.

    Returns:
        Whatever ``_C.forward`` returns. The caller in this file treats it as
        an (M, N) bfloat16 tensor.

    Raises:
        RuntimeError: If the ``cutlass_nvfp4_gemm`` extension failed to import.
    """
    if _CUTLASS_AVAILABLE:
        return _C.forward(A_packed, SFA, B_packed, SFB, M, N, K, alpha)
    raise RuntimeError("CUTLASS NVFP4 GEMM extension not available")
|
|
|
|
|
|
def cutlass_grouped_nvfp4_gemm(
    x_fp4,
    x_sf,
    weights,
    weight_sf,
    topk_ids,
    topk_weights,
    alpha=1.0,
):
    """Per-expert grouped GEMM for MoE dispatch using CUTLASS NVFP4.

    For each expert, gather the tokens routed to it, run the block-scaled
    GEMM on that subset, then accumulate the result into the output rows of
    those tokens, scaled by the routing weights.

    Args:
        x_fp4: (num_tokens, K_half) int8 tensor of packed E2M1 activations
            (two FP4 values per byte, so K = 2 * K_half).
        x_sf: (num_tokens, sf_k) float8_e4m3fn block scales for ``x_fp4``.
        weights: (E_per_rank, K_half, N) int8 packed E2M1 expert weights,
            laid out column-major for CUTLASS. N is read from ``shape[2]``.
        weight_sf: (E_per_rank, sf_k, N) float8_e4m3fn block scales for
            ``weights``, laid out column-major for CUTLASS.
        topk_ids: (num_tokens, NUM_TOPK) integer expert ids per token. Ids
            outside ``range(weights.shape[0])`` match no local expert and
            contribute nothing.
        topk_weights: (num_tokens, NUM_TOPK) float32 routing weights.
        alpha: fp32 scalar forwarded to every expert GEMM
            (D = alpha * A @ B).

    Returns:
        (num_tokens, N) bfloat16 tensor on ``x_fp4.device``. Rows of tokens
        not routed to any local expert stay zero.

    Raises:
        RuntimeError: If any expert GEMM produces NaN or Inf, or (from the
            underlying wrapper) if the CUTLASS extension is unavailable.
    """
    num_tokens, K_half = x_fp4.shape[0], x_fp4.shape[1]
    K = K_half * 2  # Actual K dimension (2 FP4 per byte)
    num_experts = weights.shape[0]
    N = weights.shape[2]  # Output dimension: weights are (E, K_half, N)
    num_topk = topk_ids.shape[1]

    if MEGA_MOE_DEBUG:
        print(f"[cutlass_grouped_gemm] tokens={num_tokens} K={K} N={N} "
              f"experts={num_experts} topk={num_topk}")

    output = torch.zeros(num_tokens, N, dtype=torch.bfloat16, device=x_fp4.device)

    for e in range(num_experts):
        # Tokens that selected this expert in at least one top-k slot.
        expert_mask = topk_ids == e  # (num_tokens, num_topk)
        token_indices = expert_mask.any(dim=1).nonzero(as_tuple=True)[0]
        if token_indices.numel() == 0:
            continue

        # Gather this expert's activations and select its weights.
        expert_x = x_fp4[token_indices]  # (M_expert, K_half)
        expert_x_sf = x_sf[token_indices]  # (M_expert, sf_k)
        expert_w = weights[e]  # (K_half, N), column-major for CUTLASS
        expert_w_sf = weight_sf[e]  # (sf_k, N), column-major for CUTLASS
        M_expert = token_indices.shape[0]

        # Weights are passed as stored; no transpose or copy is made here.
        expert_out = cutlass_nvfp4_blockscaled_gemm(
            expert_x, expert_x_sf,
            expert_w, expert_w_sf,
            M_expert, N, K,
            alpha=alpha,
        )  # (M_expert, N) bfloat16

        # Surface asynchronous CUDA errors at the expert that caused them.
        torch.cuda.current_stream().synchronize()

        # Hard-fail on NaN/Inf — silently skipping bad experts was hiding bugs.
        if not torch.isfinite(expert_out).all():
            raise RuntimeError(
                _nvfp4_nonfinite_report(
                    e, num_experts, M_expert, N, K,
                    expert_x, expert_x_sf, expert_w_sf,
                )
            )

        # Scatter back with routing weights, fully on-device. masked_fill
        # (rather than multiplying by the mask) keeps NaN/Inf weights in slots
        # belonging to other experts from leaking in. If a token lists this
        # expert in several slots, those weights are summed before scaling.
        gate = topk_weights.masked_fill(~expert_mask, 0.0).sum(dim=1)[token_indices]
        contribution = gate.to(expert_out.device).unsqueeze(1) * expert_out
        # token_indices are unique, so each output row is updated exactly once.
        output.index_add_(
            0,
            token_indices.to(output.device),
            contribution.to(output.dtype),
        )

    return output


def _nvfp4_nonfinite_report(
    expert, num_experts, M_expert, N, K, expert_x, expert_x_sf, expert_w_sf
):
    """Build the diagnostic message for an expert GEMM that emitted NaN/Inf."""
    x_sf32 = expert_x_sf.to(torch.float32)
    w_sf32 = expert_w_sf.to(torch.float32)
    return (
        f"expert {expert} of {num_experts}: GEMM emitted NaN/Inf. "
        f"M={M_expert} N={N} K={K} | "
        f"x abs range [{expert_x.view(torch.int8).abs().max().item()}], "
        f"x_sf range [{x_sf32.min().item():.4e}, "
        f"{x_sf32.max().item():.4e}], "
        f"w_sf range [{w_sf32.min().item():.4e}, "
        f"{w_sf32.max().item():.4e}], "
        f"x_sf nan_frac={torch.isnan(x_sf32).float().mean().item():.4f}, "
        f"w_sf nan_frac={torch.isnan(w_sf32).float().mean().item():.4f}"
    )
|