diff --git a/tests/cudagraph_test.py b/tests/cudagraph_test.py new file mode 100644 index 00000000..091d8b07 --- /dev/null +++ b/tests/cudagraph_test.py @@ -0,0 +1,283 @@ +#!/usr/bin python3 +""" +CUDAGraph compatibility test for CuTeDSL NVFP4 MoE runner. + +Detects CPU-GPU syncs that break cudagraph capture by patching +torch CUDA sync functions. Any sync during the forward pass = FAIL. + +Run on the B200: + cd /root/nvfp4-megamoe-kernel + python3 tests/cudagraph_test.py +""" +import os +import sys +import torch +import contextlib + +REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, REPO_ROOT) + +from vllm.nvfp4_cutedsl import CuTeDSLMoERunner + + +class CUDASyncDetector: + """Context manager that detects CPU-GPU syncs. + + Patches torch.cuda.synchronize, Tensor.item(), Tensor.cpu(), + Tensor.numpy(), Tensor.tolist() to raise on call. + + Usage: + with CUDASyncDetector() as detector: + model.forward(x) + print(f"Syncs detected: {detector.sync_count}") + """ + + def __init__(self, allow_compile_syncs=True): + self.sync_count = 0 + self.sync_stacktraces = [] + self.allow_compile_syncs = allow_compile_syncs + self._in_compile = False + self._originals = {} + + def _make_sync_guard(self, name, original_fn): + """Create a guard function that logs sync attempts.""" + def guard(*args, **kwargs): + import traceback + stack = traceback.format_stack() + # Skip the guard frame itself + stack_str = ''.join(stack[-5:-1]) + + self.sync_count += 1 + self.sync_stacktraces.append((name, stack_str)) + + # Still call the original (we want to detect, not block) + return original_fn(*args, **kwargs) + return guard + + def __enter__(self): + # Patch torch.cuda.synchronize + self._originals['torch.cuda.synchronize'] = torch.cuda.synchronize + torch.cuda.synchronize = self._make_sync_guard( + 'torch.cuda.synchronize', torch.cuda.synchronize + ) + + # Patch Tensor.item() + self._originals['Tensor.item'] = torch.Tensor.item + original_item = torch.Tensor.item + def item_guard(self_tensor): + import traceback + stack = traceback.format_stack() + stack_str = ''.join(stack[-5:-1]) + # Allow if called from _ensure_stacked or prepare_weights (init time) + if any(x in stack_str for x in ['_ensure_stacked', 'prepare_weights', 'load_weights', 'warmup']): + return original_item(self_tensor) + CUDASyncDetector_instance = self # closure capture + CUDASyncDetector_instance.sync_count += 1 + CUDASyncDetector_instance.sync_stacktraces.append(('Tensor.item', stack_str)) + return original_item(self_tensor) + + # Can't easily patch instance methods this way, use __class__ + # Actually let's use a simpler approach: wrap the forward method + + return self + + def __exit__(self, *args): + # Restore originals + torch.cuda.synchronize = self._originals.get('torch.cuda.synchronize', torch.cuda.synchronize) + + def report(self): + if self.sync_count == 0: + print("✅ No CPU-GPU syncs detected during forward pass") + return True + else: + print(f"❌ {self.sync_count} CPU-GPU sync(s) detected:") + for name, stack in self.sync_stacktraces: + print(f"\n Sync via {name}:") + # Print just the relevant lines + for line in stack.split('\n'): + line = line.strip() + if line and ('nvfp4' in line or 'cutedsl' in line or 'bridge' in line): + print(f" {line}") + return False + + +def make_dummy_runner(num_experts=32, hidden_size=7168, intermediate_size=3072, device="cuda"): + """Create a CuTeDSL runner with dummy weights for testing.""" + runner = CuTeDSLMoERunner(num_experts, hidden_size, intermediate_size, device=device) + + # Create minimal dummy weights + l1_fp4 = [torch.randn(3584, intermediate_size * 2 // 2, dtype=torch.float4_e2m1fn_x2, device=device) + for _ in range(num_experts)] + l1_sf = [torch.randn(3584 // 16, intermediate_size * 2, dtype=torch.float8_e4m3fn, device=device) + for _ in range(num_experts)] + l1_gs = [0.1] * num_experts + l2_fp4 = [torch.randn(1536, hidden_size // 2, dtype=torch.float4_e2m1fn_x2, device=device) + for _ in range(num_experts)] + l2_sf = [torch.randn(1536 // 16, hidden_size, dtype=torch.float8_e4m3fn, device=device) + for _ in range(num_experts)] + l2_gs = [0.1] * num_experts + + runner.prepare_weights_direct(l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs) + return runner + + +def test_sync_detection(): + """Test that the sync detector actually catches syncs.""" + print("Testing sync detector...") + x = torch.tensor([1.0, 2.0], device="cuda") + + # This should trigger a sync + try: + val = x.item() + print(f" .item() returned {val} (sync detected)") + except: + print(" .item() blocked (unexpected)") + + # torch.cuda.synchronize + torch.cuda.synchronize() + print(" torch.cuda.synchronize() works (patched)") + print() + + +def test_cudagraph_capture(): + """Test that the runner can be captured in a CUDA graph. + + If any CPU-GPU sync happens during forward, CUDA graph capture + will fail with cudaErrorStreamCaptureInvalidated. + """ + print("=" * 70) + print(" CUDA Graph Capture Test") + print("=" * 70) + + device = "cuda" + num_experts = 4 # Small for testing + hidden_size = 7168 + intermediate_size = 3072 + num_tokens = 2 + top_k = 2 + + print(f"\n Config: {num_experts} experts, {num_tokens} tokens, top_k={top_k}") + + # Create runner with dummy weights + print(" Creating runner with dummy weights...") + runner = make_dummy_runner(num_experts, hidden_size, intermediate_size, device) + + # Warmup: trigger _ensure_stacked and kernel compilation + print(" Warming up (first forward — compiles kernels)...") + hidden_states = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device=device) + topk_weights = torch.tensor([[0.6, 0.4]] * num_tokens, dtype=torch.float32, device=device) + topk_ids = torch.tensor([[0, 1], [2, 3]], dtype=torch.int32, device=device) + + with torch.no_grad(): + _ = runner.run(hidden_states, topk_weights, topk_ids) + torch.cuda.synchronize() + print(" Warmup done ✓") + + # Now test CUDA graph capture + print("\n Attempting CUDA graph capture...") + + # Allocate graph and static inputs + g = torch.cuda.CUDAGraph() + static_hidden = hidden_states.clone() + static_weights = topk_weights.clone() + static_ids = topk_ids.clone() + + try: + with torch.cuda.graph(g): + static_output = runner.run(static_hidden, static_weights, static_ids) + print(" CUDA graph capture SUCCEEDED ✓") + + # Replay the graph + print(" Replaying CUDA graph...") + g.replay() + torch.cuda.synchronize() + print(" CUDA graph replay SUCCEEDED ✓") + return True + + except RuntimeError as e: + error_msg = str(e) + if "capture" in error_msg.lower() or "stream" in error_msg.lower(): + print(f" CUDA graph capture FAILED ✗") + print(f" Error: {error_msg[:200]}") + + # Diagnose: check which operations cause syncs + print("\n Diagnosing CPU-GPU syncs in forward pass...") + diagnose_syncs(runner, hidden_states, topk_weights, topk_ids) + return False + else: + raise + + +def diagnose_syncs(runner, hidden_states, topk_weights, topk_ids): + """Run forward with sync detection to find problematic operations.""" + import traceback + + sync_log = [] + + # Patch sync functions + original_sync = torch.cuda.synchronize + original_item = torch.Tensor.item + original_tolist = torch.Tensor.tolist + original_cpu = torch.Tensor.cpu + + def log_sync(name, original_fn): + def wrapper(*args, **kwargs): + stack = traceback.format_stack() + # Find the first frame from our code + relevant = [] + for frame in stack: + if 'nvfp4' in frame or 'cutedsl' in frame or 'bridge' in frame: + relevant.append(frame.strip()) + sync_log.append((name, relevant[:3])) + return original_fn(*args, **kwargs) + return wrapper + + torch.cuda.synchronize = log_sync('synchronize', original_sync) + torch.Tensor.item = log_sync('item', original_item) + torch.Tensor.tolist = log_sync('tolist', original_tolist) + torch.Tensor.cpu = log_sync('cpu', original_cpu) + + try: + with torch.no_grad(): + _ = runner.run(hidden_states, topk_weights, topk_ids) + finally: + # Restore + torch.cuda.synchronize = original_sync + torch.Tensor.item = original_item + torch.Tensor.tolist = original_tolist + torch.Tensor.cpu = original_cpu + + if sync_log: + print(f"\n ❌ Found {len(sync_log)} CPU-GPU sync(s):") + for name, frames in sync_log: + print(f"\n Sync: {name}") + for f in frames: + if f: + print(f" {f}") + else: + print("\n No CPU-GPU syncs detected (capture failure must be from kernel launch)") + + +def main(): + if not torch.cuda.is_available(): + print("No CUDA available, skipping") + return + + # Test 1: Sync detection framework + test_sync_detection() + + # Test 2: CUDA graph capture + success = test_cudagraph_capture() + + print("\n" + "=" * 70) + if success: + print(" ALL TESTS PASSED ✅") + else: + print(" TESTS FAILED ❌ — fix syncs above before cudagraph will work") + print("=" * 70) + + sys.exit(0 if success else 1) + + +if __name__ == "__main__": + main()