#!/usr/bin/env python3
# Provenance: tests/test_compile_custom_op.py as added by commit
# bba3bca4d31df3279ec8e9f9c194c09a730e4935
# ("Add torch.compile + custom op integration test").
r"""Test torch.compile + CuTeDSL NVFP4 runner via custom_op.

Critical test: does torch.compile (fullgraph mode) accept the
nvfp4::moe_gemm custom op and produce a working compiled graph?

The script exits with status 1 when compilation fails or when the compiled
output does not match the eager output, and with status 0 otherwise.

Run on the B200:
    docker run --rm --gpus all --entrypoint python3 \
        -v /root/nvfp4-megamoe-kernel:/root/nvfp4-megamoe-kernel \
        -v /root/nvidia-meeting:/root/nvidia-meeting:ro \
        nvfp4-megamoe-kernel-vllm:latest \
        /root/nvfp4-megamoe-kernel/tests/test_compile_custom_op.py
"""
import os
import sys
import json
import glob
import traceback
import torch
from safetensors import safe_open

# This file lives one directory below the repository root; put the root on
# sys.path so the `cutedsl` package can be imported when run as a script.
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, REPO_ROOT)

from cutedsl.runner import CuTeDSLMoERunner  # noqa: E402
from cutedsl.custom_ops import register_runner, nvfp4_moe_gemm  # noqa: E402

NVFP4_MODEL_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
DEVICE = "cuda"


def find_shards(model_dir: str) -> dict[str, str]:
    """Map every tensor key of a safetensors checkpoint to its shard path.

    If ``model.safetensors.index.json`` exists in *model_dir*, its
    ``weight_map`` is used. Otherwise every ``*.safetensors`` file in the
    directory is opened and its keys are enumerated.

    Args:
        model_dir: Directory containing the checkpoint.

    Returns:
        Dict from tensor key to the path of the shard file that holds it.
    """
    index_path = os.path.join(model_dir, "model.safetensors.index.json")
    if os.path.exists(index_path):
        with open(index_path, encoding="utf-8") as f:
            weight_map = json.load(f)["weight_map"]
        return {
            key: os.path.join(model_dir, shard)
            for key, shard in weight_map.items()
        }

    key_to_shard = {}
    # Sorted so that, if a key appears in several shards, the shard that wins
    # does not depend on the filesystem's listing order.
    for sf in sorted(glob.glob(os.path.join(model_dir, "*.safetensors"))):
        with safe_open(sf, framework="pt") as f:
            for key in f.keys():
                key_to_shard[key] = sf
    return key_to_shard


def load_layer_tensors(model_dir: str, layer_idx: int) -> dict[str, torch.Tensor]:
    """Load all tensors that belong to one layer of the checkpoint.

    Keys are normalised by stripping a leading ``"model."`` and selected when
    they start with ``"layers.<layer_idx>."``. Each shard file is opened once.

    Args:
        model_dir: Directory containing the checkpoint.
        layer_idx: Index of the layer to load.

    Returns:
        Dict from normalised key (without the ``"model."`` prefix) to tensor.
    """
    key_to_shard = find_shards(model_dir)
    layer_prefix = f"layers.{layer_idx}."

    # Group the wanted keys by shard so each file is opened only once.
    shard_to_keys = {}
    for key, shard in key_to_shard.items():
        norm_key = key.removeprefix("model.")
        if not norm_key.startswith(layer_prefix):
            continue
        shard_to_keys.setdefault(shard, []).append((key, norm_key))

    tensors = {}
    for shard, keys in shard_to_keys.items():
        with safe_open(shard, framework="pt") as f:
            for orig_key, norm_key in keys:
                tensors[norm_key] = f.get_tensor(orig_key)
    return tensors


def prepare_nvfp4_weights_direct(nvfp4_tensors, layer_idx, expert_indices, intermediate_size):
    """Build per-expert weight lists for ``prepare_weights_direct``.

    For each expert the gate and up projections are concatenated into one
    fused "L1" weight, and the down projection (when present) becomes the
    "L2" weight. Packed weights are reinterpreted as ``float4_e2m1fn_x2`` and
    both weights and block scales are transposed and made contiguous.

    Args:
        nvfp4_tensors: Dict from key (``layers.<L>.mlp.experts.<E>...``) to
            tensor, as returned by ``load_layer_tensors``.
        layer_idx: Layer index used to build the lookup keys.
        expert_indices: Iterable of expert indices to prepare.
        intermediate_size: Number of leading columns of the fused, transposed
            block-scale tensor that belong to the gate projection.

    Returns:
        Dict with keys ``l1_fp4``, ``l1_sf``, ``l1_gs``, ``l2_fp4``,
        ``l2_sf``, ``l2_gs``; each value is a list with one entry per
        prepared expert.

    Raises:
        KeyError: If a gate/up projection tensor (or a down projection's
            scale tensors) is missing from *nvfp4_tensors*.
    """
    l1_fp4, l1_sf, l1_gs = [], [], []
    l2_fp4, l2_sf, l2_gs = [], [], []

    for e in expert_indices:
        prefix = f"layers.{layer_idx}.mlp.experts.{e}"

        gate_w = nvfp4_tensors[f"{prefix}.gate_proj.weight"].to(DEVICE)
        up_w = nvfp4_tensors[f"{prefix}.up_proj.weight"].to(DEVICE)
        gate_sf = nvfp4_tensors[f"{prefix}.gate_proj.weight_scale"].to(DEVICE)
        up_sf = nvfp4_tensors[f"{prefix}.up_proj.weight_scale"].to(DEVICE)
        # weight_scale_2 is read as a Python scalar (one global scale each).
        gate_gs = nvfp4_tensors[f"{prefix}.gate_proj.weight_scale_2"].item()
        up_gs = nvfp4_tensors[f"{prefix}.up_proj.weight_scale_2"].item()

        fused_w = torch.cat([gate_w, up_w], dim=0)
        fused_w_fp4 = fused_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
        fused_sf = torch.cat([gate_sf, up_sf], dim=0).permute(1, 0).contiguous()

        # The fused tensor carries a single global scale (the larger of the
        # two). When they differ, the ratio is folded into the block scales
        # so each half still dequantises with its own effective scale.
        l1_max_gs = max(gate_gs, up_gs)
        if gate_gs != up_gs:
            fused_sf_f32 = fused_sf.float()
            # After the transpose the concatenation axis is dim 1, so the
            # first `intermediate_size` columns are the gate half.
            # NOTE(review): assumes gate_proj.weight_scale has exactly
            # `intermediate_size` rows — confirm against the checkpoint.
            fused_sf_f32[:, :intermediate_size] *= (gate_gs / l1_max_gs)
            fused_sf_f32[:, intermediate_size:] *= (up_gs / l1_max_gs)
            # Converting back to FP8 rounds the rescaled block scales.
            fused_sf = fused_sf_f32.to(torch.float8_e4m3fn)

        l1_fp4.append(fused_w_fp4)
        l1_sf.append(fused_sf)
        l1_gs.append(l1_max_gs)

        down_key = f"{prefix}.down_proj.weight"
        # NOTE(review): an expert without a down projection is skipped here,
        # which leaves the l2_* lists shorter than the l1_* lists — confirm
        # this is what the runner expects.
        if down_key in nvfp4_tensors:
            down_w = nvfp4_tensors[down_key].to(DEVICE)
            down_sf = nvfp4_tensors[f"{prefix}.down_proj.weight_scale"].to(DEVICE)
            down_gs = nvfp4_tensors[f"{prefix}.down_proj.weight_scale_2"].item()
            l2_fp4.append(down_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous())
            l2_sf.append(down_sf.permute(1, 0).contiguous())
            l2_gs.append(down_gs)

    return {
        'l1_fp4': l1_fp4, 'l1_sf': l1_sf, 'l1_gs': l1_gs,
        'l2_fp4': l2_fp4, 'l2_sf': l2_sf, 'l2_gs': l2_gs,
    }


def _outputs_match(eager_out, compiled_out, threshold=0.99):
    """Print the eager-vs-compiled comparison and return whether it passed.

    The outputs match when their shapes are equal and the cosine similarity
    of the flattened float32 tensors is strictly greater than *threshold*.
    A NaN similarity compares false and is therefore reported as a mismatch.
    """
    if eager_out.shape != compiled_out.shape:
        print(f" ❌ Shape mismatch: eager={eager_out.shape} compiled={compiled_out.shape}")
        return False

    cos = torch.nn.functional.cosine_similarity(
        eager_out.flatten().unsqueeze(0).float(),
        compiled_out.flatten().unsqueeze(0).float(),
    ).item()
    print(f"\n Eager vs Compiled: cosine={cos:.6f}")
    if cos > threshold:
        print(" ✅ torch.compile produces matching output!")
        return True
    print(f" ⚠️ Cosine {cos:.4f} < {threshold:g} — check for numerical issues")
    return False


def main():
    """Run the custom op eagerly and under torch.compile, then compare.

    Loads layer 0 of the checkpoint at ``NVFP4_MODEL_DIR``, prepares experts
    0-2, registers a ``CuTeDSLMoERunner`` and calls ``nvfp4_moe_gemm`` once
    eagerly and once through ``torch.compile(fullgraph=True)``.

    Exits the process with status 1 if compilation or execution raises, or
    if the compiled output does not match the eager output.
    """
    torch.manual_seed(42)
    expert_indices = [0, 1, 2]
    hidden_size = 7168
    intermediate_size = 3072

    print("=" * 70)
    print(" torch.compile + CuTeDSL Custom Op Test")
    print("=" * 70)

    # Load weights
    nvfp4_tensors = load_layer_tensors(NVFP4_MODEL_DIR, 0)
    weights = prepare_nvfp4_weights_direct(nvfp4_tensors, 0, expert_indices, intermediate_size)

    # Create runner
    runner = CuTeDSLMoERunner(
        num_experts=len(expert_indices),
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        max_num_tokens=8,
        top_k=2,
        device=DEVICE,
    )
    runner.prepare_weights_direct(
        weights['l1_fp4'], weights['l1_sf'], weights['l1_gs'],
        weights['l2_fp4'], weights['l2_sf'], weights['l2_gs'],
    )
    runner_id = register_runner(runner)

    # Test input: 4 tokens, each routed to experts 0 and 1.
    hidden_states = torch.randn(4, hidden_size, dtype=torch.bfloat16, device=DEVICE) * 2.0
    topk_ids = torch.tensor([[0, 1]] * 4, dtype=torch.int32, device=DEVICE)
    topk_weights = torch.tensor([[0.6, 0.4]] * 4, dtype=torch.float32, device=DEVICE)

    # 1. Eager mode (baseline)
    print("\n[1/2] Running eager mode (baseline)...")
    # NOTE(review): presumably finalises the runner's weights before the
    # first op call — confirm against CuTeDSLMoERunner.
    runner._ensure_stacked()
    eager_out = nvfp4_moe_gemm(hidden_states, topk_weights, topk_ids, runner_id, hidden_size)
    print(f" Eager output: amax={eager_out.abs().max():.4f} mean={eager_out.float().mean():.6f}")

    # 2. torch.compile fullgraph
    print("\n[2/2] Running torch.compile(fullgraph=True)...")
    passed = False
    try:
        @torch.compile(fullgraph=True)
        def compiled_fn(hs, tw, ti):
            return nvfp4_moe_gemm(hs, tw, ti, runner_id, hidden_size)

        compiled_out = compiled_fn(hidden_states, topk_weights, topk_ids)
        # Reading the result back stays inside the try so that errors raised
        # while materialising the output are reported as failures too.
        print(f" Compiled output: amax={compiled_out.abs().max():.4f} mean={compiled_out.float().mean():.6f}")

        # Compare
        passed = _outputs_match(eager_out, compiled_out)
    except Exception as e:
        print(f" ❌ torch.compile FAILED: {type(e).__name__}: {e}")
        traceback.print_exc()
        sys.exit(1)

    print("\n" + "=" * 70)
    if passed:
        print(" Test complete ✅")
    else:
        print(" Test FAILED ❌")
    print("=" * 70)

    # A mismatch must fail the run; otherwise a broken compiled graph would
    # be reported as success to whatever invokes this script.
    if not passed:
        sys.exit(1)


if __name__ == "__main__":
    main()