Add torch.compile + custom op integration test
This commit is contained in:
180
tests/test_compile_custom_op.py
Normal file
180
tests/test_compile_custom_op.py
Normal file
@@ -0,0 +1,180 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Test torch.compile + CuTeDSL NVFP4 runner via custom_op.
|
||||
|
||||
Critical test: does torch.compile (fullgraph mode) accept the
|
||||
nvfp4::moe_gemm custom op and produce a working compiled graph?
|
||||
|
||||
Run on the B200:
|
||||
docker run --rm --gpus all --entrypoint python3 \
|
||||
-v /root/nvfp4-megamoe-kernel:/root/nvfp4-megamoe-kernel \
|
||||
-v /root/nvidia-meeting:/root/nvidia-meeting:ro \
|
||||
nvfp4-megamoe-kernel-vllm:latest \
|
||||
/root/nvfp4-megamoe-kernel/tests/test_compile_custom_op.py
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import glob
|
||||
import torch
|
||||
from safetensors import safe_open
|
||||
|
||||
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, REPO_ROOT)
|
||||
|
||||
from cutedsl.runner import CuTeDSLMoERunner
|
||||
from cutedsl.custom_ops import register_runner, nvfp4_moe_gemm
|
||||
|
||||
NVFP4_MODEL_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
|
||||
DEVICE = "cuda"
|
||||
|
||||
|
||||
def find_shards(model_dir):
|
||||
index_path = os.path.join(model_dir, "model.safetensors.index.json")
|
||||
key_to_shard = {}
|
||||
if os.path.exists(index_path):
|
||||
with open(index_path) as f:
|
||||
index = json.load(f)
|
||||
for key, shard in index["weight_map"].items():
|
||||
key_to_shard[key] = os.path.join(model_dir, shard)
|
||||
else:
|
||||
for sf in glob.glob(os.path.join(model_dir, "*.safetensors")):
|
||||
with safe_open(sf, framework="pt") as f:
|
||||
for key in f.keys():
|
||||
key_to_shard[key] = sf
|
||||
return key_to_shard
|
||||
|
||||
|
||||
def load_layer_tensors(model_dir, layer_idx):
|
||||
key_to_shard = find_shards(model_dir)
|
||||
layer_prefix = f"layers.{layer_idx}."
|
||||
shard_to_keys = {}
|
||||
for key, shard in key_to_shard.items():
|
||||
norm_key = key.removeprefix("model.")
|
||||
if not norm_key.startswith(layer_prefix):
|
||||
continue
|
||||
shard_to_keys.setdefault(shard, []).append((key, norm_key))
|
||||
tensors = {}
|
||||
for shard, keys in shard_to_keys.items():
|
||||
with safe_open(shard, framework="pt") as f:
|
||||
for orig_key, norm_key in keys:
|
||||
tensors[norm_key] = f.get_tensor(orig_key)
|
||||
return tensors
|
||||
|
||||
|
||||
def prepare_nvfp4_weights_direct(nvfp4_tensors, layer_idx, expert_indices, intermediate_size):
|
||||
from cutedsl.bridge import quantize_activation_nvfp4, quantize_weight_to_nvfp4
|
||||
l1_fp4, l1_sf, l1_gs = [], [], []
|
||||
l2_fp4, l2_sf, l2_gs = [], [], []
|
||||
|
||||
for e in expert_indices:
|
||||
gate_w = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight"].to(DEVICE)
|
||||
up_w = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight"].to(DEVICE)
|
||||
gate_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale"].to(DEVICE)
|
||||
up_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale"].to(DEVICE)
|
||||
gate_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale_2"].item()
|
||||
up_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale_2"].item()
|
||||
|
||||
fused_w = torch.cat([gate_w, up_w], dim=0)
|
||||
fused_w_fp4 = fused_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
|
||||
fused_sf = torch.cat([gate_sf, up_sf], dim=0).permute(1, 0).contiguous()
|
||||
l1_max_gs = max(gate_gs, up_gs)
|
||||
if gate_gs != up_gs:
|
||||
fused_sf_f32 = fused_sf.float()
|
||||
fused_sf_f32[:, :intermediate_size] *= (gate_gs / l1_max_gs)
|
||||
fused_sf_f32[:, intermediate_size:] *= (up_gs / l1_max_gs)
|
||||
fused_sf = fused_sf_f32.to(torch.float8_e4m3fn)
|
||||
l1_fp4.append(fused_w_fp4)
|
||||
l1_sf.append(fused_sf)
|
||||
l1_gs.append(l1_max_gs)
|
||||
|
||||
down_key = f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight"
|
||||
if down_key in nvfp4_tensors:
|
||||
down_w = nvfp4_tensors[down_key].to(DEVICE)
|
||||
down_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale"].to(DEVICE)
|
||||
down_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale_2"].item()
|
||||
l2_fp4.append(down_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous())
|
||||
l2_sf.append(down_sf.permute(1, 0).contiguous())
|
||||
l2_gs.append(down_gs)
|
||||
|
||||
return {
|
||||
'l1_fp4': l1_fp4, 'l1_sf': l1_sf, 'l1_gs': l1_gs,
|
||||
'l2_fp4': l2_fp4, 'l2_sf': l2_sf, 'l2_gs': l2_gs,
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
torch.manual_seed(42)
|
||||
expert_indices = [0, 1, 2]
|
||||
hidden_size = 7168
|
||||
intermediate_size = 3072
|
||||
|
||||
print("=" * 70)
|
||||
print(" torch.compile + CuTeDSL Custom Op Test")
|
||||
print("=" * 70)
|
||||
|
||||
# Load weights
|
||||
nvfp4_tensors = load_layer_tensors(NVFP4_MODEL_DIR, 0)
|
||||
weights = prepare_nvfp4_weights_direct(nvfp4_tensors, 0, expert_indices, intermediate_size)
|
||||
|
||||
# Create runner
|
||||
runner = CuTeDSLMoERunner(
|
||||
num_experts=len(expert_indices),
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
max_num_tokens=8,
|
||||
top_k=2,
|
||||
device="cuda",
|
||||
)
|
||||
runner.prepare_weights_direct(
|
||||
weights['l1_fp4'], weights['l1_sf'], weights['l1_gs'],
|
||||
weights['l2_fp4'], weights['l2_sf'], weights['l2_gs'],
|
||||
)
|
||||
runner_id = register_runner(runner)
|
||||
|
||||
# Test input
|
||||
hidden_states = torch.randn(4, hidden_size, dtype=torch.bfloat16, device=DEVICE) * 2.0
|
||||
topk_ids = torch.tensor([[0, 1]] * 4, dtype=torch.int32, device=DEVICE)
|
||||
topk_weights = torch.tensor([[0.6, 0.4]] * 4, dtype=torch.float32, device=DEVICE)
|
||||
|
||||
# 1. Eager mode (baseline)
|
||||
print("\n[1/2] Running eager mode (baseline)...")
|
||||
runner._ensure_stacked()
|
||||
eager_out = nvfp4_moe_gemm(hidden_states, topk_weights, topk_ids, runner_id, hidden_size)
|
||||
print(f" Eager output: amax={eager_out.abs().max():.4f} mean={eager_out.float().mean():.6f}")
|
||||
|
||||
# 2. torch.compile fullgraph
|
||||
print("\n[2/2] Running torch.compile(fullgraph=True)...")
|
||||
try:
|
||||
@torch.compile(fullgraph=True)
|
||||
def compiled_fn(hs, tw, ti):
|
||||
return nvfp4_moe_gemm(hs, tw, ti, runner_id, hidden_size)
|
||||
|
||||
compiled_out = compiled_fn(hidden_states, topk_weights, topk_ids)
|
||||
print(f" Compiled output: amax={compiled_out.abs().max():.4f} mean={compiled_out.float().mean():.6f}")
|
||||
|
||||
# Compare
|
||||
if eager_out.shape == compiled_out.shape:
|
||||
cos = torch.nn.functional.cosine_similarity(
|
||||
eager_out.flatten().unsqueeze(0).float(),
|
||||
compiled_out.flatten().unsqueeze(0).float(),
|
||||
).item()
|
||||
print(f"\n Eager vs Compiled: cosine={cos:.6f}")
|
||||
if cos > 0.99:
|
||||
print(" ✅ torch.compile produces matching output!")
|
||||
else:
|
||||
print(f" ⚠️ Cosine {cos:.4f} < 0.99 — check for numerical issues")
|
||||
else:
|
||||
print(f" ❌ Shape mismatch: eager={eager_out.shape} compiled={compiled_out.shape}")
|
||||
except Exception as e:
|
||||
print(f" ❌ torch.compile FAILED: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print(" Test complete ✅")
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user