diff --git a/cutedsl/moe_pipeline.py b/cutedsl/moe_pipeline.py new file mode 100644 index 00000000..52bbf6a8 --- /dev/null +++ b/cutedsl/moe_pipeline.py @@ -0,0 +1,255 @@ +""" +Full NVFP4 MoE pipeline using CuTeDSL ScaledGroupedGemmKernel. + +Data flow (NVFP4-native, BF16 only where required): + 1. BF16 hidden_states → quantize to NVFP4 (stage_activation) + 2. L1 GEMM: NVFP4 × NVFP4 → BF16 output (gate+up) + 3. SiLU(gate) * up → BF16 activated (nonlinear requires BF16) + 4. Re-quantize activated → NVFP4 (stage_activation) + 5. L2 GEMM: NVFP4 × NVFP4 → BF16 output (down_proj) + 6. Scatter with routing weights → BF16 output + +Both GEMMs are fully NVFP4: A in float4_e2m1fn_x2, B in float4_e2m1fn_x2, +block scales in float8_e4m3fn, global scales in float32. +""" +import torch + +from cutedsl.bridge import ( + quantize_to_nvfp4, + quantize_weight_to_nvfp4, + assemble_scales_2d_side, + assemble_scales_3d_side, + make_b_k_major, + compute_expert_offsets, + run_nvfp4_grouped_gemm, +) + + +def stage_activation(x_bf16): + """Quantize BF16 activation to NVFP4. + + This is the NVFP4-native equivalent of the old stage_activation. + Keeps data in FP4 as long as possible — only leaves NVFP4 for nonlinear ops. + + Returns (x_fp4, x_sf, global_scale) where: + x_fp4: float4_e2m1fn_x2 (native PyTorch FP4) + x_sf: float8_e4m3fn block scales + global_scale: float32 scalar + """ + return quantize_to_nvfp4(x_bf16) + + +def quantize_weight(w_bf16): + """Quantize BF16 weight to NVFP4. + + Weight is (K, N) where K is the input/hidden dim (packed dimension). + + Returns (w_fp4, w_sf, global_scale). + """ + return quantize_weight_to_nvfp4(w_bf16) + + +def prepare_nvfp4_moe_weights(nvfp4_tensors, layer_idx, expert_indices): + """Load NVFP4 checkpoint weights and prepare for the grouped GEMM. + + Dequantizes checkpoint NVFP4 → BF16 → re-quantizes to our native format. + This round-trip ensures our FP4 packing convention matches the kernel. + + Future optimization: load checkpoint FP4 bytes directly into + float4_e2m1fn_x2 tensors without the BF16 round-trip. + + Returns dict with l1 and l2 weight info per expert. + """ + from tests.layertest import dequantize_nvfp4_weight, DEVICE + + l1_weights = [] # gate+up fused, (K, N) = (hidden, intermediate) + l2_weights = [] # down, (K, N) = (intermediate, hidden) + + for e in expert_indices: + # L1: gate + up + gate_w_bf16 = dequantize_nvfp4_weight( + nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight"].to(DEVICE), + nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale"].to(DEVICE), + nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale_2"].item(), + ) + up_w_bf16 = dequantize_nvfp4_weight( + nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight"].to(DEVICE), + nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale"].to(DEVICE), + nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale_2"].item(), + ) + + # Fuse gate+up: (6144, 7168) → transpose to (7168, 6144) for weight quantization + fused_l1 = torch.cat([gate_w_bf16, up_w_bf16], dim=0) # (6144, 7168) + l1_w_bf16 = fused_l1.T # (7168, 6144) — K=7168, N=6144 + l1_weights.append(l1_w_bf16) + + # L2: down + down_w_key = f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight" + if down_w_key in nvfp4_tensors: + down_w_bf16 = dequantize_nvfp4_weight( + nvfp4_tensors[down_w_key].to(DEVICE), + nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale"].to(DEVICE), + nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale_2"].item(), + ) + # down_proj is (7168, 3072) → transpose to (3072, 7168) for K=intermediate + l2_w_bf16 = down_w_bf16.T # (3072, 7168) — K=3072, N=7168 + else: + # Expert 211 has no down_proj + l2_w_bf16 = torch.zeros(3072, 7168, dtype=torch.bfloat16, device=DEVICE) + + l2_weights.append(l2_w_bf16) + + # Quantize all weights to NVFP4 + l1_fp4, l1_sf, l1_gs = [], [], [] + l2_fp4, l2_sf, l2_gs = [], [], [] + + for l1_w, l2_w in zip(l1_weights, l2_weights): + w_fp4, w_sf, w_gs = quantize_weight(l1_w) + l1_fp4.append(w_fp4) + l1_sf.append(w_sf) + l1_gs.append(w_gs) + + w_fp4, w_sf, w_gs = quantize_weight(l2_w) + l2_fp4.append(w_fp4) + l2_sf.append(w_sf) + l2_gs.append(w_gs) + + return { + 'l1_fp4': l1_fp4, 'l1_sf': l1_sf, 'l1_gs': l1_gs, + 'l2_fp4': l2_fp4, 'l2_sf': l2_sf, 'l2_gs': l2_gs, + } + + +def run_nvfp4_moe( + hidden_states, # (num_tokens, hidden_size) BF16 + expert_ids, # (num_tokens, top_k) int32 + expert_weights, # (num_tokens, top_k) float32 + weights, # dict from prepare_nvfp4_moe_weights + expert_indices, # list of expert IDs +): + """Run the full NVFP4 MoE forward pass. + + NVFP4-native pipeline: + 1. Quantize activation → NVFP4 + 2. L1 GEMM (NVFP4 × NVFP4 → BF16) + 3. SiLU(gate) * up (BF16 — nonlinear requires BF16) + 4. Re-quantize → NVFP4 + 5. L2 GEMM (NVFP4 × NVFP4 → BF16) + 6. Scatter with routing weights → BF16 + + Returns: (num_tokens, hidden_size) BF16 + """ + num_tokens, hidden_size = hidden_states.shape + top_k = expert_ids.shape[1] + device = hidden_states.device + + # ── Build slot-based routing ── + expert_token_lists = {e: [] for e in expert_indices} + for t in range(num_tokens): + for k in range(top_k): + e = expert_ids[t, k].item() + if e in expert_token_lists: + expert_token_lists[e].append(t) + + tokens_per_expert = [len(expert_token_lists[e]) for e in expert_indices] + num_experts = len(expert_indices) + + # Slot-major activation: [expert0_tokens | expert1_tokens | ...] + slot_hidden = torch.cat([ + hidden_states[expert_token_lists[e]] for e in expert_indices + ], dim=0) if any(tpe > 0 for tpe in tokens_per_expert) else torch.zeros(0, hidden_size, dtype=torch.bfloat16, device=device) + + num_slots = slot_hidden.shape[0] + if num_slots == 0: + return torch.zeros(num_tokens, hidden_size, dtype=torch.bfloat16, device=device) + + expert_offsets = compute_expert_offsets(tokens_per_expert, num_experts) + + # ════════════════════════════════════════════════════════════════ + # L1: gate + up projection (NVFP4 × NVFP4 → BF16) + # ════════════════════════════════════════════════════════════════ + + # Quantize activation to NVFP4 + x_fp4, x_sf, x_igs = stage_activation(slot_hidden) + + # Stack L1 weights and convert to K-major + l1_mat_b = make_b_k_major(torch.stack(weights['l1_fp4'])) + + # Assemble scales + x_sf_parts = [] + offset = 0 + for tpe in tokens_per_expert: + x_sf_parts.append(x_sf[offset:offset+tpe]) + offset += tpe + l1_scale_a = assemble_scales_2d_side(x_sf_parts) + l1_scale_b = assemble_scales_3d_side(weights['l1_sf']) + + # Global scales: alpha = igs * weight_gs for each expert + l1_global_scale_a = torch.tensor([x_igs] * num_experts, dtype=torch.float32, device=device) + l1_global_scale_b = torch.tensor(weights['l1_gs'], dtype=torch.float32, device=device) + + # Run L1 GEMM + l1_out = run_nvfp4_grouped_gemm( + mat_a=x_fp4, mat_b=l1_mat_b, + scale_a=l1_scale_a, scale_b=l1_scale_b, + expert_offsets=expert_offsets, + global_scale_a=l1_global_scale_a, global_scale_b=l1_global_scale_b, + ) # (num_slots, intermediate) BF16 + + # ════════════════════════════════════════════════════════════════ + # SiLU(gate) * up (BF16 — nonlinear requires BF16) + # ════════════════════════════════════════════════════════════════ + intermediate = l1_out.shape[1] + half = intermediate // 2 # 3072 + gate = l1_out[:, :half] + up = l1_out[:, half:] + activated = torch.nn.functional.silu(gate) * up # (num_slots, half) BF16 + + # ════════════════════════════════════════════════════════════════ + # L2: down projection (NVFP4 × NVFP4 → BF16) + # ════════════════════════════════════════════════════════════════ + + # Re-quantize activated → NVFP4 + l2_x_fp4, l2_x_sf, l2_x_igs = stage_activation(activated) + + # Stack L2 weights + l2_mat_b = make_b_k_major(torch.stack(weights['l2_fp4'])) + + # Assemble L2 scales + l2_sf_parts = [] + offset = 0 + for tpe in tokens_per_expert: + l2_sf_parts.append(l2_x_sf[offset:offset+tpe]) + offset += tpe + l2_scale_a = assemble_scales_2d_side(l2_sf_parts) + l2_scale_b = assemble_scales_3d_side(weights['l2_sf']) + + # Global scales + l2_global_scale_a = torch.tensor([l2_x_igs] * num_experts, dtype=torch.float32, device=device) + l2_global_scale_b = torch.tensor(weights['l2_gs'], dtype=torch.float32, device=device) + + # Run L2 GEMM + l2_out = run_nvfp4_grouped_gemm( + mat_a=l2_x_fp4, mat_b=l2_mat_b, + scale_a=l2_scale_a, scale_b=l2_scale_b, + expert_offsets=expert_offsets, + global_scale_a=l2_global_scale_a, global_scale_b=l2_global_scale_b, + ) # (num_slots, hidden_size) BF16 + + # ════════════════════════════════════════════════════════════════ + # Scatter with routing weights → final output + # ════════════════════════════════════════════════════════════════ + y = torch.zeros(num_tokens, hidden_size, dtype=torch.bfloat16, device=device) + + slot_idx = 0 + for e in expert_indices: + for t in expert_token_lists[e]: + # Find which top-k slot this is for this token + for k in range(top_k): + if expert_ids[t, k].item() == e: + w = expert_weights[t, k].item() + y[t] += w * l2_out[slot_idx] + break + slot_idx += 1 + + return y diff --git a/tests/layertest.py b/tests/layertest.py index 53932a54..716e8291 100644 --- a/tests/layertest.py +++ b/tests/layertest.py @@ -29,6 +29,12 @@ from cutedsl.bridge import ( run_nvfp4_grouped_gemm, ) +from cutedsl.moe_pipeline import ( + stage_activation, + prepare_nvfp4_moe_weights, + run_nvfp4_moe, +) + # ── Constants ────────────────────────────────────────────────────────── NVFP4_MODEL_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4" @@ -231,76 +237,56 @@ def main(): print("=" * 70) nvfp4_tensors = load_layer_tensors(NVFP4_MODEL_DIR, LAYER_IDX) - expert_keys = [k for k in sorted(nvfp4_tensors.keys()) if 'experts.0.' in k and LAYER_IDX == 0] + expert_keys = [k for k in sorted(nvfp4_tensors.keys()) if 'experts.0.' in k] print(f" {len(nvfp4_tensors)} tensors loaded") - for key in expert_keys[:5]: + for key in expert_keys[:3]: t = nvfp4_tensors[key] print(f" {key}: dtype={t.dtype} shape={tuple(t.shape)}") + # ── Prepare NVFP4 weights ── + print(" + Preparing NVFP4 weights (dequant → re-quant)...") + weights = prepare_nvfp4_moe_weights(nvfp4_tensors, LAYER_IDX, expert_indices) + print(f" L1: {len(weights['l1_fp4'])} experts, shape {weights['l1_fp4'][0].shape}") + print(f" L2: {len(weights['l2_fp4'])} experts, shape {weights['l2_fp4'][0].shape}") + # ── Dequantize → BF16 reference ── - print("\n Dequantizing NVFP4 → BF16...") + print(" + Dequantizing NVFP4 → BF16 reference...") nvfp4_experts_bf16 = dequantize_nvfp4_experts(nvfp4_tensors, LAYER_IDX, expert_indices) - for e in expert_indices[:2]: - for proj, w in nvfp4_experts_bf16[e].items(): - print(f" Expert {e} {proj}: shape={tuple(w.shape)} amax={w.abs().max():.4f}") # ── Create test input ── hidden_states = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device=DEVICE) * 2.0 expert_ids = torch.tensor([[0, 1]] * num_tokens, dtype=torch.int32, device=DEVICE) expert_weights = torch.tensor([[0.6, 0.4]] * num_tokens, dtype=torch.float32, device=DEVICE) - # ── Build slot-based layout for grouped GEMM ── - # The kernel expects activation laid out as [expert_0_tokens | expert_1_tokens | ...] - # Each token can appear in multiple experts (top-k routing) - num_slots = num_tokens * top_k - slot_expert = expert_ids.flatten() # (num_slots,) - - # Build per-expert token lists - expert_token_lists = {e: [] for e in expert_indices} - for t in range(num_tokens): - for k in range(top_k): - e = expert_ids[t, k].item() - expert_token_lists[e].append(t) - - tokens_per_expert = [len(expert_token_lists[e]) for e in expert_indices] - - # Build slot-major activation: concat tokens for each expert - slot_hidden = torch.cat([ - hidden_states[expert_token_lists[e]] for e in expert_indices - ], dim=0) # (num_slots, hidden_size) - - expert_offsets = compute_expert_offsets(tokens_per_expert, len(expert_indices)) - - # ── BF16 L1 reference (slot-major, matching kernel output) ── - print("\n Running BF16 L1 reference...") - ref_l1_parts = [] - for e in expert_indices: - for t in expert_token_lists[e]: - gate = hidden_states[t] @ nvfp4_experts_bf16[e]["gate_proj"].T - up = hidden_states[t] @ nvfp4_experts_bf16[e]["up_proj"].T - ref_l1_parts.append(torch.cat([gate, up])) - ref_l1 = torch.cat(ref_l1_parts, dim=0) # (num_slots, 6144) - print(f" BF16 L1 ref: amax={ref_l1.abs().max():.4f} mean={ref_l1.float().mean():.6f}") + # ── BF16 full MoE reference ── + print(" + Running BF16 MoE reference...") + ref_output = moe_forward_bf16(hidden_states, nvfp4_experts_bf16, expert_ids, expert_weights) + print(f" BF16 ref: amax={ref_output.abs().max():.4f} mean={ref_output.float().mean():.6f}") del nvfp4_experts_bf16 torch.cuda.empty_cache() - # ── CuTeDSL NVFP4 L1 kernel ── - print("\n Running CuTeDSL NVFP4 L1 kernel (first run compiles, ~1-2 min)...") - kernel_l1 = moe_forward_nvfp4_l1_only(slot_hidden, nvfp4_tensors, LAYER_IDX, expert_indices, tokens_per_expert) - print(f" Kernel L1: amax={kernel_l1.abs().max():.4f} mean={kernel_l1.float().mean():.6f}") + # ── CuTeDSL NVFP4 full MoE pipeline ── + print(" + Running CuTeDSL NVFP4 MoE pipeline (first run compiles, ~1-2 min)...") + kernel_output = run_nvfp4_moe( + hidden_states, expert_ids, expert_weights, + weights, expert_indices, + ) + print(f" Kernel: amax={kernel_output.abs().max():.4f} mean={kernel_output.float().mean():.6f}") # ── Compare ── - ref_flat = ref_l1.flatten() - kernel_flat = kernel_l1.flatten() - cosine = torch.nn.functional.cosine_similarity( - kernel_flat.unsqueeze(0).float(), - ref_flat.unsqueeze(0).float(), + kernel_output.flatten().unsqueeze(0).float(), + ref_output.flatten().unsqueeze(0).float(), ).item() - mse = (kernel_flat.float() - ref_flat.float()).pow(2).mean().item() + mse = (kernel_output.float() - ref_output.float()).pow(2).mean().item() - print(f"\n{'=' * 70}") + print(f" +{'=' * 70}") print(f" RESULT: cosine={cosine:.6f} MSE={mse:.6e}") print(f"{'=' * 70}")