feat: full NVFP4 MoE pipeline (L1→SiLU→L2→scatter)
cutedsl/moe_pipeline.py: complete pipeline - stage_activation: BF16 → NVFP4 (keeps data in FP4) - L1 GEMM: NVFP4 × NVFP4 → BF16 (gate+up) - SiLU(gate) * up: BF16 (only nonlinear, can't avoid) - Re-quantize: BF16 → NVFP4 (back to native) - L2 GEMM: NVFP4 × NVFP4 → BF16 (down_proj) - Scatter with routing weights → BF16 output layertest.py: now tests the FULL MoE pipeline against BF16 reference. NVFP4-native: both GEMMs use float4_e2m1fn_x2 for A and B, float8_e4m3fn for block scales, float32 for global scales. BF16 only for SiLU activation and final scatter.
This commit is contained in:
255
cutedsl/moe_pipeline.py
Normal file
255
cutedsl/moe_pipeline.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""
|
||||
Full NVFP4 MoE pipeline using CuTeDSL ScaledGroupedGemmKernel.
|
||||
|
||||
Data flow (NVFP4-native, BF16 only where required):
|
||||
1. BF16 hidden_states → quantize to NVFP4 (stage_activation)
|
||||
2. L1 GEMM: NVFP4 × NVFP4 → BF16 output (gate+up)
|
||||
3. SiLU(gate) * up → BF16 activated (nonlinear requires BF16)
|
||||
4. Re-quantize activated → NVFP4 (stage_activation)
|
||||
5. L2 GEMM: NVFP4 × NVFP4 → BF16 output (down_proj)
|
||||
6. Scatter with routing weights → BF16 output
|
||||
|
||||
Both GEMMs are fully NVFP4: A in float4_e2m1fn_x2, B in float4_e2m1fn_x2,
|
||||
block scales in float8_e4m3fn, global scales in float32.
|
||||
"""
|
||||
import torch
|
||||
|
||||
from cutedsl.bridge import (
|
||||
quantize_to_nvfp4,
|
||||
quantize_weight_to_nvfp4,
|
||||
assemble_scales_2d_side,
|
||||
assemble_scales_3d_side,
|
||||
make_b_k_major,
|
||||
compute_expert_offsets,
|
||||
run_nvfp4_grouped_gemm,
|
||||
)
|
||||
|
||||
|
||||
def stage_activation(x_bf16):
|
||||
"""Quantize BF16 activation to NVFP4.
|
||||
|
||||
This is the NVFP4-native equivalent of the old stage_activation.
|
||||
Keeps data in FP4 as long as possible — only leaves NVFP4 for nonlinear ops.
|
||||
|
||||
Returns (x_fp4, x_sf, global_scale) where:
|
||||
x_fp4: float4_e2m1fn_x2 (native PyTorch FP4)
|
||||
x_sf: float8_e4m3fn block scales
|
||||
global_scale: float32 scalar
|
||||
"""
|
||||
return quantize_to_nvfp4(x_bf16)
|
||||
|
||||
|
||||
def quantize_weight(w_bf16):
|
||||
"""Quantize BF16 weight to NVFP4.
|
||||
|
||||
Weight is (K, N) where K is the input/hidden dim (packed dimension).
|
||||
|
||||
Returns (w_fp4, w_sf, global_scale).
|
||||
"""
|
||||
return quantize_weight_to_nvfp4(w_bf16)
|
||||
|
||||
|
||||
def prepare_nvfp4_moe_weights(nvfp4_tensors, layer_idx, expert_indices):
|
||||
"""Load NVFP4 checkpoint weights and prepare for the grouped GEMM.
|
||||
|
||||
Dequantizes checkpoint NVFP4 → BF16 → re-quantizes to our native format.
|
||||
This round-trip ensures our FP4 packing convention matches the kernel.
|
||||
|
||||
Future optimization: load checkpoint FP4 bytes directly into
|
||||
float4_e2m1fn_x2 tensors without the BF16 round-trip.
|
||||
|
||||
Returns dict with l1 and l2 weight info per expert.
|
||||
"""
|
||||
from tests.layertest import dequantize_nvfp4_weight, DEVICE
|
||||
|
||||
l1_weights = [] # gate+up fused, (K, N) = (hidden, intermediate)
|
||||
l2_weights = [] # down, (K, N) = (intermediate, hidden)
|
||||
|
||||
for e in expert_indices:
|
||||
# L1: gate + up
|
||||
gate_w_bf16 = dequantize_nvfp4_weight(
|
||||
nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight"].to(DEVICE),
|
||||
nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale"].to(DEVICE),
|
||||
nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale_2"].item(),
|
||||
)
|
||||
up_w_bf16 = dequantize_nvfp4_weight(
|
||||
nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight"].to(DEVICE),
|
||||
nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale"].to(DEVICE),
|
||||
nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale_2"].item(),
|
||||
)
|
||||
|
||||
# Fuse gate+up: (6144, 7168) → transpose to (7168, 6144) for weight quantization
|
||||
fused_l1 = torch.cat([gate_w_bf16, up_w_bf16], dim=0) # (6144, 7168)
|
||||
l1_w_bf16 = fused_l1.T # (7168, 6144) — K=7168, N=6144
|
||||
l1_weights.append(l1_w_bf16)
|
||||
|
||||
# L2: down
|
||||
down_w_key = f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight"
|
||||
if down_w_key in nvfp4_tensors:
|
||||
down_w_bf16 = dequantize_nvfp4_weight(
|
||||
nvfp4_tensors[down_w_key].to(DEVICE),
|
||||
nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale"].to(DEVICE),
|
||||
nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale_2"].item(),
|
||||
)
|
||||
# down_proj is (7168, 3072) → transpose to (3072, 7168) for K=intermediate
|
||||
l2_w_bf16 = down_w_bf16.T # (3072, 7168) — K=3072, N=7168
|
||||
else:
|
||||
# Expert 211 has no down_proj
|
||||
l2_w_bf16 = torch.zeros(3072, 7168, dtype=torch.bfloat16, device=DEVICE)
|
||||
|
||||
l2_weights.append(l2_w_bf16)
|
||||
|
||||
# Quantize all weights to NVFP4
|
||||
l1_fp4, l1_sf, l1_gs = [], [], []
|
||||
l2_fp4, l2_sf, l2_gs = [], [], []
|
||||
|
||||
for l1_w, l2_w in zip(l1_weights, l2_weights):
|
||||
w_fp4, w_sf, w_gs = quantize_weight(l1_w)
|
||||
l1_fp4.append(w_fp4)
|
||||
l1_sf.append(w_sf)
|
||||
l1_gs.append(w_gs)
|
||||
|
||||
w_fp4, w_sf, w_gs = quantize_weight(l2_w)
|
||||
l2_fp4.append(w_fp4)
|
||||
l2_sf.append(w_sf)
|
||||
l2_gs.append(w_gs)
|
||||
|
||||
return {
|
||||
'l1_fp4': l1_fp4, 'l1_sf': l1_sf, 'l1_gs': l1_gs,
|
||||
'l2_fp4': l2_fp4, 'l2_sf': l2_sf, 'l2_gs': l2_gs,
|
||||
}
|
||||
|
||||
|
||||
def run_nvfp4_moe(
|
||||
hidden_states, # (num_tokens, hidden_size) BF16
|
||||
expert_ids, # (num_tokens, top_k) int32
|
||||
expert_weights, # (num_tokens, top_k) float32
|
||||
weights, # dict from prepare_nvfp4_moe_weights
|
||||
expert_indices, # list of expert IDs
|
||||
):
|
||||
"""Run the full NVFP4 MoE forward pass.
|
||||
|
||||
NVFP4-native pipeline:
|
||||
1. Quantize activation → NVFP4
|
||||
2. L1 GEMM (NVFP4 × NVFP4 → BF16)
|
||||
3. SiLU(gate) * up (BF16 — nonlinear requires BF16)
|
||||
4. Re-quantize → NVFP4
|
||||
5. L2 GEMM (NVFP4 × NVFP4 → BF16)
|
||||
6. Scatter with routing weights → BF16
|
||||
|
||||
Returns: (num_tokens, hidden_size) BF16
|
||||
"""
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
top_k = expert_ids.shape[1]
|
||||
device = hidden_states.device
|
||||
|
||||
# ── Build slot-based routing ──
|
||||
expert_token_lists = {e: [] for e in expert_indices}
|
||||
for t in range(num_tokens):
|
||||
for k in range(top_k):
|
||||
e = expert_ids[t, k].item()
|
||||
if e in expert_token_lists:
|
||||
expert_token_lists[e].append(t)
|
||||
|
||||
tokens_per_expert = [len(expert_token_lists[e]) for e in expert_indices]
|
||||
num_experts = len(expert_indices)
|
||||
|
||||
# Slot-major activation: [expert0_tokens | expert1_tokens | ...]
|
||||
slot_hidden = torch.cat([
|
||||
hidden_states[expert_token_lists[e]] for e in expert_indices
|
||||
], dim=0) if any(tpe > 0 for tpe in tokens_per_expert) else torch.zeros(0, hidden_size, dtype=torch.bfloat16, device=device)
|
||||
|
||||
num_slots = slot_hidden.shape[0]
|
||||
if num_slots == 0:
|
||||
return torch.zeros(num_tokens, hidden_size, dtype=torch.bfloat16, device=device)
|
||||
|
||||
expert_offsets = compute_expert_offsets(tokens_per_expert, num_experts)
|
||||
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
# L1: gate + up projection (NVFP4 × NVFP4 → BF16)
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
|
||||
# Quantize activation to NVFP4
|
||||
x_fp4, x_sf, x_igs = stage_activation(slot_hidden)
|
||||
|
||||
# Stack L1 weights and convert to K-major
|
||||
l1_mat_b = make_b_k_major(torch.stack(weights['l1_fp4']))
|
||||
|
||||
# Assemble scales
|
||||
x_sf_parts = []
|
||||
offset = 0
|
||||
for tpe in tokens_per_expert:
|
||||
x_sf_parts.append(x_sf[offset:offset+tpe])
|
||||
offset += tpe
|
||||
l1_scale_a = assemble_scales_2d_side(x_sf_parts)
|
||||
l1_scale_b = assemble_scales_3d_side(weights['l1_sf'])
|
||||
|
||||
# Global scales: alpha = igs * weight_gs for each expert
|
||||
l1_global_scale_a = torch.tensor([x_igs] * num_experts, dtype=torch.float32, device=device)
|
||||
l1_global_scale_b = torch.tensor(weights['l1_gs'], dtype=torch.float32, device=device)
|
||||
|
||||
# Run L1 GEMM
|
||||
l1_out = run_nvfp4_grouped_gemm(
|
||||
mat_a=x_fp4, mat_b=l1_mat_b,
|
||||
scale_a=l1_scale_a, scale_b=l1_scale_b,
|
||||
expert_offsets=expert_offsets,
|
||||
global_scale_a=l1_global_scale_a, global_scale_b=l1_global_scale_b,
|
||||
) # (num_slots, intermediate) BF16
|
||||
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
# SiLU(gate) * up (BF16 — nonlinear requires BF16)
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
intermediate = l1_out.shape[1]
|
||||
half = intermediate // 2 # 3072
|
||||
gate = l1_out[:, :half]
|
||||
up = l1_out[:, half:]
|
||||
activated = torch.nn.functional.silu(gate) * up # (num_slots, half) BF16
|
||||
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
# L2: down projection (NVFP4 × NVFP4 → BF16)
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
|
||||
# Re-quantize activated → NVFP4
|
||||
l2_x_fp4, l2_x_sf, l2_x_igs = stage_activation(activated)
|
||||
|
||||
# Stack L2 weights
|
||||
l2_mat_b = make_b_k_major(torch.stack(weights['l2_fp4']))
|
||||
|
||||
# Assemble L2 scales
|
||||
l2_sf_parts = []
|
||||
offset = 0
|
||||
for tpe in tokens_per_expert:
|
||||
l2_sf_parts.append(l2_x_sf[offset:offset+tpe])
|
||||
offset += tpe
|
||||
l2_scale_a = assemble_scales_2d_side(l2_sf_parts)
|
||||
l2_scale_b = assemble_scales_3d_side(weights['l2_sf'])
|
||||
|
||||
# Global scales
|
||||
l2_global_scale_a = torch.tensor([l2_x_igs] * num_experts, dtype=torch.float32, device=device)
|
||||
l2_global_scale_b = torch.tensor(weights['l2_gs'], dtype=torch.float32, device=device)
|
||||
|
||||
# Run L2 GEMM
|
||||
l2_out = run_nvfp4_grouped_gemm(
|
||||
mat_a=l2_x_fp4, mat_b=l2_mat_b,
|
||||
scale_a=l2_scale_a, scale_b=l2_scale_b,
|
||||
expert_offsets=expert_offsets,
|
||||
global_scale_a=l2_global_scale_a, global_scale_b=l2_global_scale_b,
|
||||
) # (num_slots, hidden_size) BF16
|
||||
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
# Scatter with routing weights → final output
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
y = torch.zeros(num_tokens, hidden_size, dtype=torch.bfloat16, device=device)
|
||||
|
||||
slot_idx = 0
|
||||
for e in expert_indices:
|
||||
for t in expert_token_lists[e]:
|
||||
# Find which top-k slot this is for this token
|
||||
for k in range(top_k):
|
||||
if expert_ids[t, k].item() == e:
|
||||
w = expert_weights[t, k].item()
|
||||
y[t] += w * l2_out[slot_idx]
|
||||
break
|
||||
slot_idx += 1
|
||||
|
||||
return y
|
||||
@@ -29,6 +29,12 @@ from cutedsl.bridge import (
|
||||
run_nvfp4_grouped_gemm,
|
||||
)
|
||||
|
||||
from cutedsl.moe_pipeline import (
|
||||
stage_activation,
|
||||
prepare_nvfp4_moe_weights,
|
||||
run_nvfp4_moe,
|
||||
)
|
||||
|
||||
# ── Constants ──────────────────────────────────────────────────────────
|
||||
|
||||
NVFP4_MODEL_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
|
||||
@@ -231,76 +237,56 @@ def main():
|
||||
print("=" * 70)
|
||||
|
||||
nvfp4_tensors = load_layer_tensors(NVFP4_MODEL_DIR, LAYER_IDX)
|
||||
expert_keys = [k for k in sorted(nvfp4_tensors.keys()) if 'experts.0.' in k and LAYER_IDX == 0]
|
||||
expert_keys = [k for k in sorted(nvfp4_tensors.keys()) if 'experts.0.' in k]
|
||||
print(f" {len(nvfp4_tensors)} tensors loaded")
|
||||
for key in expert_keys[:5]:
|
||||
for key in expert_keys[:3]:
|
||||
t = nvfp4_tensors[key]
|
||||
print(f" {key}: dtype={t.dtype} shape={tuple(t.shape)}")
|
||||
|
||||
# ── Prepare NVFP4 weights ──
|
||||
print("
|
||||
Preparing NVFP4 weights (dequant → re-quant)...")
|
||||
weights = prepare_nvfp4_moe_weights(nvfp4_tensors, LAYER_IDX, expert_indices)
|
||||
print(f" L1: {len(weights['l1_fp4'])} experts, shape {weights['l1_fp4'][0].shape}")
|
||||
print(f" L2: {len(weights['l2_fp4'])} experts, shape {weights['l2_fp4'][0].shape}")
|
||||
|
||||
# ── Dequantize → BF16 reference ──
|
||||
print("\n Dequantizing NVFP4 → BF16...")
|
||||
print("
|
||||
Dequantizing NVFP4 → BF16 reference...")
|
||||
nvfp4_experts_bf16 = dequantize_nvfp4_experts(nvfp4_tensors, LAYER_IDX, expert_indices)
|
||||
for e in expert_indices[:2]:
|
||||
for proj, w in nvfp4_experts_bf16[e].items():
|
||||
print(f" Expert {e} {proj}: shape={tuple(w.shape)} amax={w.abs().max():.4f}")
|
||||
|
||||
# ── Create test input ──
|
||||
hidden_states = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device=DEVICE) * 2.0
|
||||
expert_ids = torch.tensor([[0, 1]] * num_tokens, dtype=torch.int32, device=DEVICE)
|
||||
expert_weights = torch.tensor([[0.6, 0.4]] * num_tokens, dtype=torch.float32, device=DEVICE)
|
||||
|
||||
# ── Build slot-based layout for grouped GEMM ──
|
||||
# The kernel expects activation laid out as [expert_0_tokens | expert_1_tokens | ...]
|
||||
# Each token can appear in multiple experts (top-k routing)
|
||||
num_slots = num_tokens * top_k
|
||||
slot_expert = expert_ids.flatten() # (num_slots,)
|
||||
|
||||
# Build per-expert token lists
|
||||
expert_token_lists = {e: [] for e in expert_indices}
|
||||
for t in range(num_tokens):
|
||||
for k in range(top_k):
|
||||
e = expert_ids[t, k].item()
|
||||
expert_token_lists[e].append(t)
|
||||
|
||||
tokens_per_expert = [len(expert_token_lists[e]) for e in expert_indices]
|
||||
|
||||
# Build slot-major activation: concat tokens for each expert
|
||||
slot_hidden = torch.cat([
|
||||
hidden_states[expert_token_lists[e]] for e in expert_indices
|
||||
], dim=0) # (num_slots, hidden_size)
|
||||
|
||||
expert_offsets = compute_expert_offsets(tokens_per_expert, len(expert_indices))
|
||||
|
||||
# ── BF16 L1 reference (slot-major, matching kernel output) ──
|
||||
print("\n Running BF16 L1 reference...")
|
||||
ref_l1_parts = []
|
||||
for e in expert_indices:
|
||||
for t in expert_token_lists[e]:
|
||||
gate = hidden_states[t] @ nvfp4_experts_bf16[e]["gate_proj"].T
|
||||
up = hidden_states[t] @ nvfp4_experts_bf16[e]["up_proj"].T
|
||||
ref_l1_parts.append(torch.cat([gate, up]))
|
||||
ref_l1 = torch.cat(ref_l1_parts, dim=0) # (num_slots, 6144)
|
||||
print(f" BF16 L1 ref: amax={ref_l1.abs().max():.4f} mean={ref_l1.float().mean():.6f}")
|
||||
# ── BF16 full MoE reference ──
|
||||
print("
|
||||
Running BF16 MoE reference...")
|
||||
ref_output = moe_forward_bf16(hidden_states, nvfp4_experts_bf16, expert_ids, expert_weights)
|
||||
print(f" BF16 ref: amax={ref_output.abs().max():.4f} mean={ref_output.float().mean():.6f}")
|
||||
|
||||
del nvfp4_experts_bf16
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# ── CuTeDSL NVFP4 L1 kernel ──
|
||||
print("\n Running CuTeDSL NVFP4 L1 kernel (first run compiles, ~1-2 min)...")
|
||||
kernel_l1 = moe_forward_nvfp4_l1_only(slot_hidden, nvfp4_tensors, LAYER_IDX, expert_indices, tokens_per_expert)
|
||||
print(f" Kernel L1: amax={kernel_l1.abs().max():.4f} mean={kernel_l1.float().mean():.6f}")
|
||||
# ── CuTeDSL NVFP4 full MoE pipeline ──
|
||||
print("
|
||||
Running CuTeDSL NVFP4 MoE pipeline (first run compiles, ~1-2 min)...")
|
||||
kernel_output = run_nvfp4_moe(
|
||||
hidden_states, expert_ids, expert_weights,
|
||||
weights, expert_indices,
|
||||
)
|
||||
print(f" Kernel: amax={kernel_output.abs().max():.4f} mean={kernel_output.float().mean():.6f}")
|
||||
|
||||
# ── Compare ──
|
||||
ref_flat = ref_l1.flatten()
|
||||
kernel_flat = kernel_l1.flatten()
|
||||
|
||||
cosine = torch.nn.functional.cosine_similarity(
|
||||
kernel_flat.unsqueeze(0).float(),
|
||||
ref_flat.unsqueeze(0).float(),
|
||||
kernel_output.flatten().unsqueeze(0).float(),
|
||||
ref_output.flatten().unsqueeze(0).float(),
|
||||
).item()
|
||||
mse = (kernel_flat.float() - ref_flat.float()).pow(2).mean().item()
|
||||
mse = (kernel_output.float() - ref_output.float()).pow(2).mean().item()
|
||||
|
||||
print(f"\n{'=' * 70}")
|
||||
print(f"
|
||||
{'=' * 70}")
|
||||
print(f" RESULT: cosine={cosine:.6f} MSE={mse:.6e}")
|
||||
print(f"{'=' * 70}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user