#!/usr/bin/env python3
"""Test torch.compile + CuTeDSL NVFP4 runner via custom_op.

Critical test: does torch.compile (fullgraph mode) accept the nvfp4::moe_gemm
custom op and produce a working compiled graph?

Run on the B200:
    docker run --rm --gpus all --entrypoint python3 \
        -v /root/nvfp4-megamoe-kernel:/root/nvfp4-megamoe-kernel \
        -v /root/nvidia-meeting:/root/nvidia-meeting:ro \
        nvfp4-megamoe-kernel-vllm:latest \
        /root/nvfp4-megamoe-kernel/tests/test_compile_custom_op.py
"""

import glob
import json
import os
import sys
import traceback

import torch
from safetensors import safe_open

# Make the repository root importable so the `cutedsl` package resolves when
# this file is executed directly as a script.
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, REPO_ROOT)

from cutedsl.runner import CuTeDSLMoERunner
from cutedsl.custom_ops import register_runner, nvfp4_moe_gemm

NVFP4_MODEL_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
DEVICE = "cuda"


def find_shards(model_dir):
    """Map every tensor key under *model_dir* to the shard file that holds it.

    If ``model.safetensors.index.json`` exists, its ``weight_map`` is used.
    Otherwise every ``*.safetensors`` file in the directory is opened and its
    keys are enumerated.

    Args:
        model_dir: Directory containing the safetensors checkpoint.

    Returns:
        dict mapping tensor key -> full path of the shard file.
    """
    index_path = os.path.join(model_dir, "model.safetensors.index.json")
    key_to_shard = {}
    if os.path.exists(index_path):
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(index_path, encoding="utf-8") as f:
            index = json.load(f)
        for key, shard in index["weight_map"].items():
            key_to_shard[key] = os.path.join(model_dir, shard)
    else:
        for sf in glob.glob(os.path.join(model_dir, "*.safetensors")):
            with safe_open(sf, framework="pt") as f:
                for key in f.keys():
                    key_to_shard[key] = sf
    return key_to_shard


def load_layer_tensors(model_dir, layer_idx):
    """Load every tensor belonging to one layer of the checkpoint.

    Keys are normalised by stripping a leading ``"model."`` and kept only if
    they then start with ``"layers.{layer_idx}."``.

    Args:
        model_dir: Directory containing the safetensors checkpoint.
        layer_idx: Index of the layer whose tensors should be loaded.

    Returns:
        dict mapping normalised key -> tensor.
    """
    key_to_shard = find_shards(model_dir)
    layer_prefix = f"layers.{layer_idx}."

    # Group the wanted keys by shard so each shard file is opened only once.
    shard_to_keys = {}
    for key, shard in key_to_shard.items():
        norm_key = key.removeprefix("model.")
        if not norm_key.startswith(layer_prefix):
            continue
        shard_to_keys.setdefault(shard, []).append((key, norm_key))

    tensors = {}
    for shard, keys in shard_to_keys.items():
        with safe_open(shard, framework="pt") as f:
            for orig_key, norm_key in keys:
                tensors[norm_key] = f.get_tensor(orig_key)
    return tensors


def prepare_nvfp4_weights_direct(nvfp4_tensors, layer_idx, expert_indices, intermediate_size):
    """Build per-expert fused gate/up (L1) and down (L2) weight lists.

    For each expert the gate and up projections are concatenated along dim 0,
    reinterpreted as ``torch.float4_e2m1fn_x2`` and transposed. When the gate
    and up global scales differ, the block scales of each half are rescaled so
    that a single global scale (the larger of the two) applies to the fused
    weight.

    Args:
        nvfp4_tensors: dict of normalised key -> tensor, as returned by
            ``load_layer_tensors``.
        layer_idx: Layer index used to build the tensor keys.
        expert_indices: Iterable of expert ids to prepare.
        intermediate_size: Column offset separating the gate half from the up
            half in the transposed fused scale tensor.

    Returns:
        dict with keys ``l1_fp4``, ``l1_sf``, ``l1_gs``, ``l2_fp4``, ``l2_sf``
        and ``l2_gs``, each a list with one entry per prepared expert.
    """
    # NOTE(review): neither imported name is used below. The import is kept in
    # case importing `cutedsl.bridge` has required side effects - confirm and
    # remove if it does not.
    from cutedsl.bridge import quantize_activation_nvfp4, quantize_weight_to_nvfp4  # noqa: F401

    l1_fp4, l1_sf, l1_gs = [], [], []
    l2_fp4, l2_sf, l2_gs = [], [], []

    for e in expert_indices:
        prefix = f"layers.{layer_idx}.mlp.experts.{e}."

        gate_w = nvfp4_tensors[prefix + "gate_proj.weight"].to(DEVICE)
        up_w = nvfp4_tensors[prefix + "up_proj.weight"].to(DEVICE)
        gate_sf = nvfp4_tensors[prefix + "gate_proj.weight_scale"].to(DEVICE)
        up_sf = nvfp4_tensors[prefix + "up_proj.weight_scale"].to(DEVICE)
        gate_gs = nvfp4_tensors[prefix + "gate_proj.weight_scale_2"].item()
        up_gs = nvfp4_tensors[prefix + "up_proj.weight_scale_2"].item()

        fused_w = torch.cat([gate_w, up_w], dim=0)
        fused_w_fp4 = fused_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
        fused_sf = torch.cat([gate_sf, up_sf], dim=0).permute(1, 0).contiguous()

        l1_max_gs = max(gate_gs, up_gs)
        if gate_gs != up_gs:
            # Fold the ratio of each half's global scale to the shared one
            # into that half's block scales, then return to FP8.
            fused_sf_f32 = fused_sf.float()
            fused_sf_f32[:, :intermediate_size] *= (gate_gs / l1_max_gs)
            fused_sf_f32[:, intermediate_size:] *= (up_gs / l1_max_gs)
            fused_sf = fused_sf_f32.to(torch.float8_e4m3fn)

        l1_fp4.append(fused_w_fp4)
        l1_sf.append(fused_sf)
        l1_gs.append(l1_max_gs)

        # NOTE(review): an expert without a down_proj weight is skipped here,
        # which leaves the l2_* lists shorter than (and misaligned with) the
        # l1_* lists - confirm this is intended.
        down_key = prefix + "down_proj.weight"
        if down_key in nvfp4_tensors:
            down_w = nvfp4_tensors[down_key].to(DEVICE)
            down_sf = nvfp4_tensors[prefix + "down_proj.weight_scale"].to(DEVICE)
            down_gs = nvfp4_tensors[prefix + "down_proj.weight_scale_2"].item()
            l2_fp4.append(down_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous())
            l2_sf.append(down_sf.permute(1, 0).contiguous())
            l2_gs.append(down_gs)

    return {
        'l1_fp4': l1_fp4, 'l1_sf': l1_sf, 'l1_gs': l1_gs,
        'l2_fp4': l2_fp4, 'l2_sf': l2_sf, 'l2_gs': l2_gs,
    }


def _outputs_match(eager_out, compiled_out):
    """Print the eager-vs-compiled comparison and return True if they agree.

    Agreement means identical shapes and a cosine similarity above 0.99
    between the flattened float32 outputs. A NaN cosine counts as a mismatch.
    """
    if eager_out.shape != compiled_out.shape:
        print(f" ❌ Shape mismatch: eager={eager_out.shape} compiled={compiled_out.shape}")
        return False

    cos = torch.nn.functional.cosine_similarity(
        eager_out.flatten().unsqueeze(0).float(),
        compiled_out.flatten().unsqueeze(0).float(),
    ).item()
    print(f"\n Eager vs Compiled: cosine={cos:.6f}")
    if cos > 0.99:
        print(" ✅ torch.compile produces matching output!")
        return True

    print(f" ⚠️ Cosine {cos:.4f} < 0.99 — check for numerical issues")
    return False


def main():
    """Run the custom op eagerly and under torch.compile, then compare.

    Exits with status 1 if compilation or execution raises, or if the compiled
    output does not match the eager baseline.
    """
    torch.manual_seed(42)
    expert_indices = [0, 1, 2]
    hidden_size = 7168
    intermediate_size = 3072

    print("=" * 70)
    print(" torch.compile + CuTeDSL Custom Op Test")
    print("=" * 70)

    # Load weights
    nvfp4_tensors = load_layer_tensors(NVFP4_MODEL_DIR, 0)
    weights = prepare_nvfp4_weights_direct(nvfp4_tensors, 0, expert_indices, intermediate_size)

    # Create runner
    runner = CuTeDSLMoERunner(
        num_experts=len(expert_indices),
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        max_num_tokens=8,
        top_k=2,
        device=DEVICE,
    )
    runner.prepare_weights_direct(
        weights['l1_fp4'], weights['l1_sf'], weights['l1_gs'],
        weights['l2_fp4'], weights['l2_sf'], weights['l2_gs'],
    )
    runner_id = register_runner(runner)

    # Test input
    hidden_states = torch.randn(4, hidden_size, dtype=torch.bfloat16, device=DEVICE) * 2.0
    topk_ids = torch.tensor([[0, 1]] * 4, dtype=torch.int32, device=DEVICE)
    topk_weights = torch.tensor([[0.6, 0.4]] * 4, dtype=torch.float32, device=DEVICE)

    # 0. Warmup: compute activation global scales
    print("\n[0] Computing activation global scales (warmup)...")
    runner.compute_activation_global_scales(hidden_states, topk_weights, topk_ids)
    print(f" L1 gs: {runner._l1_activation_global_scale:.6f}")
    print(f" L2 gs: {runner._l2_activation_global_scale:.6f}")

    # 1. Eager mode (baseline)
    print("\n[1/2] Running eager mode (baseline)...")
    runner._ensure_stacked()
    eager_out = nvfp4_moe_gemm(hidden_states, topk_weights, topk_ids, runner_id, hidden_size)
    print(f" Eager output: amax={eager_out.abs().max():.4f} mean={eager_out.float().mean():.6f}")

    # 2. torch.compile fullgraph
    print("\n[2/2] Running torch.compile(fullgraph=True)...")
    try:
        @torch.compile(fullgraph=True)
        def compiled_fn(hs, tw, ti):
            return nvfp4_moe_gemm(hs, tw, ti, runner_id, hidden_size)

        compiled_out = compiled_fn(hidden_states, topk_weights, topk_ids)
        print(f" Compiled output: amax={compiled_out.abs().max():.4f} mean={compiled_out.float().mean():.6f}")

        # Compare
        matched = _outputs_match(eager_out, compiled_out)
    except Exception as e:
        # Top-level boundary of a test script: report whatever went wrong
        # (compile error, graph break, kernel failure) and fail the run.
        print(f" ❌ torch.compile FAILED: {type(e).__name__}: {e}")
        traceback.print_exc()
        sys.exit(1)

    if not matched:
        # A mismatch must fail the run rather than fall through to the
        # success banner with exit status 0.
        print("\n" + "=" * 70)
        print(" Test FAILED ❌")
        print("=" * 70)
        sys.exit(1)

    print("\n" + "=" * 70)
    print(" Test complete ✅")
    print("=" * 70)


if __name__ == "__main__":
    main()