Files
nvfp4-megamoe-kernel/cutedsl/moe_pipeline.py

268 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Full NVFP4 MoE pipeline using CuTeDSL ScaledGroupedGemmKernel.
Data flow (NVFP4-native, BF16 only where required):
1. BF16 hidden_states → quantize to NVFP4 (stage_activation)
2. L1 GEMM: NVFP4 × NVFP4 → BF16 output (gate+up)
3. SiLU(gate) * up → BF16 activated (nonlinear requires BF16)
4. Re-quantize activated → NVFP4 (stage_activation)
5. L2 GEMM: NVFP4 × NVFP4 → BF16 output (down_proj)
6. Scatter with routing weights → BF16 output
Both GEMMs are fully NVFP4: A in float4_e2m1fn_x2, B in float4_e2m1fn_x2,
block scales in float8_e4m3fn, global scales in float32.
"""
import torch
from cutedsl.bridge import (
quantize_to_nvfp4,
quantize_weight_to_nvfp4,
assemble_scales_2d_side,
assemble_scales_3d_side,
make_b_k_major,
compute_expert_offsets,
run_nvfp4_grouped_gemm,
)
def stage_activation(x_bf16):
    """Quantize a BF16 activation tensor to NVFP4.

    Thin forwarding wrapper around ``quantize_to_nvfp4``; the helper's result
    is returned unchanged. It is the NVFP4-native counterpart of the older
    ``stage_activation`` and is called before each grouped GEMM so that data
    only leaves NVFP4 for the nonlinear step.

    Returns:
        The ``(x_fp4, x_sf, global_scale)`` tuple produced by the helper.
        Per the original notes (the helper is not visible here):
        ``x_fp4`` is float4_e2m1fn_x2 (native PyTorch FP4), ``x_sf`` holds
        float8_e4m3fn block scales, and ``global_scale`` is a float32 scalar.
    """
    quantized = quantize_to_nvfp4(x_bf16)
    return quantized
def quantize_weight(w_bf16):
    """Quantize a BF16 weight matrix to NVFP4.

    Thin forwarding wrapper around ``quantize_weight_to_nvfp4``.

    Args:
        w_bf16: Weight laid out as (K, N), where K is the input/hidden
            dimension (the packed dimension), per the original notes.

    Returns:
        The ``(w_fp4, w_sf, global_scale)`` tuple produced by the helper.
    """
    packed = quantize_weight_to_nvfp4(w_bf16)
    return packed
def prepare_nvfp4_moe_weights(nvfp4_tensors, layer_idx, expert_indices):
    """Load NVFP4 checkpoint weights and prepare them for the grouped GEMM.

    Each expert's checkpoint weights are dequantized NVFP4 -> BF16 and then
    re-quantized with ``quantize_weight``. The round-trip exists so that the
    FP4 packing convention matches what the kernel expects.

    Future optimization: load checkpoint FP4 bytes directly into
    float4_e2m1fn_x2 tensors without the BF16 round-trip.

    Args:
        nvfp4_tensors: Mapping from checkpoint key to tensor. Keys follow
            ``layers.{layer}.mlp.experts.{e}.{proj}.weight`` plus the
            ``weight_scale`` / ``weight_scale_2`` companions.
        layer_idx: Layer index interpolated into the checkpoint keys.
        expert_indices: Iterable of expert IDs to load, in output order.

    Returns:
        Dict with keys ``l1_fp4``, ``l1_sf``, ``l1_gs`` (fused gate+up) and
        ``l2_fp4``, ``l2_sf``, ``l2_gs`` (down); each value is a list with one
        entry per expert, in ``expert_indices`` order.
    """
    # Function-scope import, as in the original: the test-harness module is
    # only needed when weights are actually prepared.
    from tests.layertest import dequantize_nvfp4_weight, DEVICE

    def _dequantize(prefix):
        # One projection = packed weight + block scales + scalar global scale.
        return dequantize_nvfp4_weight(
            nvfp4_tensors[f"{prefix}.weight"].to(DEVICE),
            nvfp4_tensors[f"{prefix}.weight_scale"].to(DEVICE),
            nvfp4_tensors[f"{prefix}.weight_scale_2"].item(),
        )

    l1_fp4, l1_sf, l1_gs = [], [], []
    l2_fp4, l2_sf, l2_gs = [], [], []

    for e in expert_indices:
        expert_prefix = f"layers.{layer_idx}.mlp.experts.{e}"

        # L1: gate + up, concatenated along dim 0 and transposed so that the
        # hidden dimension becomes K,
        # e.g. (6144, 7168) -> (7168, 6144) per the original notes.
        gate_w_bf16 = _dequantize(f"{expert_prefix}.gate_proj")
        up_w_bf16 = _dequantize(f"{expert_prefix}.up_proj")
        l1_w_bf16 = torch.cat([gate_w_bf16, up_w_bf16], dim=0).T

        # L2: down, transposed so that the intermediate dimension becomes K,
        # e.g. (7168, 3072) -> (3072, 7168) per the original notes.
        if f"{expert_prefix}.down_proj.weight" in nvfp4_tensors:
            l2_w_bf16 = _dequantize(f"{expert_prefix}.down_proj").T
        else:
            # Some experts ship without a down_proj (the original notes cite
            # expert 211). Substitute zeros whose shape is taken from the
            # gate weight instead of hard-coded model dimensions: the
            # transposed down_proj has the same (intermediate, hidden) shape
            # as gate_proj.
            intermediate_size, hidden_size = gate_w_bf16.shape
            l2_w_bf16 = torch.zeros(
                intermediate_size, hidden_size,
                dtype=torch.bfloat16, device=DEVICE,
            )

        # Quantize immediately so the BF16 copies of this expert can be freed
        # before the next expert is dequantized. The call order (L1 then L2,
        # expert by expert) matches the original two-pass version.
        w_fp4, w_sf, w_gs = quantize_weight(l1_w_bf16)
        l1_fp4.append(w_fp4)
        l1_sf.append(w_sf)
        l1_gs.append(w_gs)

        w_fp4, w_sf, w_gs = quantize_weight(l2_w_bf16)
        l2_fp4.append(w_fp4)
        l2_sf.append(w_sf)
        l2_gs.append(w_gs)

    return {
        'l1_fp4': l1_fp4, 'l1_sf': l1_sf, 'l1_gs': l1_gs,
        'l2_fp4': l2_fp4, 'l2_sf': l2_sf, 'l2_gs': l2_gs,
    }
def _slice_rows_per_expert(rows, tokens_per_expert):
    """Cut ``rows`` into consecutive chunks, one per expert, of the given lengths."""
    parts = []
    start = 0
    for count in tokens_per_expert:
        parts.append(rows[start:start + count])
        start += count
    return parts


def run_nvfp4_moe(
    hidden_states,      # (num_tokens, hidden_size) BF16
    expert_ids,         # (num_tokens, top_k) int32
    expert_weights,     # (num_tokens, top_k) float32
    weights,            # dict from prepare_nvfp4_moe_weights
    expert_indices,     # list of expert IDs
    swiglu_limit=None,  # Optional clamp for SiLU output
):
    """Run the full NVFP4 MoE forward pass.

    NVFP4-native pipeline:
      1. Quantize activation -> NVFP4
      2. L1 GEMM (NVFP4 x NVFP4 -> BF16)
      3. SiLU(gate) * up (BF16 -- nonlinear requires BF16)
      4. Re-quantize -> NVFP4
      5. L2 GEMM (NVFP4 x NVFP4 -> BF16)
      6. Scatter with routing weights -> BF16

    Args:
        hidden_states: (num_tokens, hidden_size) activations.
        expert_ids: (num_tokens, top_k) expert selected for each routing slot.
        expert_weights: (num_tokens, top_k) routing weight for each slot.
        weights: Dict with ``l1_fp4``/``l1_sf``/``l1_gs`` and
            ``l2_fp4``/``l2_sf``/``l2_gs`` lists, one entry per expert in
            ``expert_indices`` order.
        expert_indices: Expert IDs handled by this call. Routing slots that
            point at any other expert are ignored.
        swiglu_limit: If given, SiLU(gate) is clamped from above and ``up`` is
            clamped to ``[-swiglu_limit, swiglu_limit]`` before multiplying.

    Returns:
        (num_tokens, hidden_size) BF16 tensor; all zeros when no routing slot
        targets an expert in ``expert_indices``.
    """
    num_tokens, hidden_size = hidden_states.shape
    device = hidden_states.device

    # -- Build slot-based routing --
    # Record (token, k) for every routing slot owned by a local expert.
    # Keeping k lets the scatter step read the exact routing weight for each
    # slot, which matters if a token lists the same expert more than once.
    # A single tolist() replaces one .item() call per element.
    expert_slots = {e: [] for e in expert_indices}
    for t, token_experts in enumerate(expert_ids.tolist()):
        for k, e in enumerate(token_experts):
            slots = expert_slots.get(e)
            if slots is not None:
                slots.append((t, k))
    tokens_per_expert = [len(expert_slots[e]) for e in expert_indices]
    num_experts = len(expert_indices)

    if sum(tokens_per_expert) == 0:
        # Nothing routed to these experts: skip quantization and both GEMMs.
        return torch.zeros(num_tokens, hidden_size, dtype=torch.bfloat16, device=device)

    # Slot-major activation: [expert0_tokens | expert1_tokens | ...]
    slot_hidden = torch.cat([
        hidden_states[[t for t, _ in expert_slots[e]]] for e in expert_indices
    ], dim=0)
    expert_offsets = compute_expert_offsets(tokens_per_expert, num_experts)

    # ================================================================
    # L1: gate + up projection (NVFP4 x NVFP4 -> BF16)
    # ================================================================
    x_fp4, x_sf, x_igs = stage_activation(slot_hidden)
    # Stack per-expert L1 weights and convert to K-major.
    l1_mat_b = make_b_k_major(torch.stack(weights['l1_fp4']))
    # Activation block scales are split per expert before assembly.
    l1_scale_a = assemble_scales_2d_side(_slice_rows_per_expert(x_sf, tokens_per_expert))
    l1_scale_b = assemble_scales_3d_side(weights['l1_sf'])
    # The activation global scale is shared by all experts; weight global
    # scales are per expert.
    l1_global_scale_a = torch.tensor([x_igs] * num_experts, dtype=torch.float32, device=device)
    l1_global_scale_b = torch.tensor(weights['l1_gs'], dtype=torch.float32, device=device)
    print(f" L1 global_scale_a: {l1_global_scale_a.tolist()}", flush=True)
    print(f" L1 global_scale_b: {l1_global_scale_b.tolist()}", flush=True)
    print(f" alpha (a*b): {(l1_global_scale_a * l1_global_scale_b).tolist()}", flush=True)

    l1_out = run_nvfp4_grouped_gemm(
        mat_a=x_fp4, mat_b=l1_mat_b,
        scale_a=l1_scale_a, scale_b=l1_scale_b,
        expert_offsets=expert_offsets,
        global_scale_a=l1_global_scale_a, global_scale_b=l1_global_scale_b,
    )  # (num_slots, 2*intermediate) BF16
    print(f" L1 GEMM output: shape={l1_out.shape}, amax={l1_out.abs().amax().item():.4f}", flush=True)

    # ================================================================
    # SiLU(gate) * up (BF16 -- nonlinear requires BF16)
    # ================================================================
    # L1 output columns are the fused [gate | up] halves.
    intermediate_size = l1_out.shape[1] // 2
    gate = l1_out[:, :intermediate_size]
    up = l1_out[:, intermediate_size:]
    print(f" gate: shape={gate.shape}, amax={gate.abs().amax().item():.4f}", flush=True)
    print(f" up: shape={up.shape}, amax={up.abs().amax().item():.4f}", flush=True)
    gate_silu = torch.nn.functional.silu(gate)
    if swiglu_limit is not None:
        gate_silu = gate_silu.clamp(max=swiglu_limit)
        up = up.clamp(min=-swiglu_limit, max=swiglu_limit)
    activated = gate_silu * up  # (num_slots, intermediate) BF16
    print(f" After SiLU(gate)*up: shape={activated.shape}, amax={activated.abs().amax().item():.4f}", flush=True)

    # ================================================================
    # L2: down projection (NVFP4 x NVFP4 -> BF16)
    # ================================================================
    l2_x_fp4, l2_x_sf, l2_x_igs = stage_activation(activated)
    l2_mat_b = make_b_k_major(torch.stack(weights['l2_fp4']))
    l2_scale_a = assemble_scales_2d_side(_slice_rows_per_expert(l2_x_sf, tokens_per_expert))
    l2_scale_b = assemble_scales_3d_side(weights['l2_sf'])
    l2_global_scale_a = torch.tensor([l2_x_igs] * num_experts, dtype=torch.float32, device=device)
    l2_global_scale_b = torch.tensor(weights['l2_gs'], dtype=torch.float32, device=device)

    l2_out = run_nvfp4_grouped_gemm(
        mat_a=l2_x_fp4, mat_b=l2_mat_b,
        scale_a=l2_scale_a, scale_b=l2_scale_b,
        expert_offsets=expert_offsets,
        global_scale_a=l2_global_scale_a, global_scale_b=l2_global_scale_b,
    )  # (num_slots, hidden_size) BF16

    # ================================================================
    # Scatter with routing weights -> final output
    # ================================================================
    # Slots are visited in the same expert-major order used to build
    # slot_hidden, so slot_idx lines up with the rows of l2_out.
    routing_weights = expert_weights.tolist()
    y = torch.zeros(num_tokens, hidden_size, dtype=torch.bfloat16, device=device)
    slot_idx = 0
    for e in expert_indices:
        for t, k in expert_slots[e]:
            y[t] += routing_weights[t][k] * l2_out[slot_idx]
            slot_idx += 1
    return y