#!/usr/bin/env python3
"""
CUDAGraph compatibility test for CuTeDSL NVFP4 MoE runner.

Detects CPU-GPU syncs that break cudagraph capture by patching
torch CUDA sync functions. Any sync during the forward pass = FAIL.

Run on the B200:
    cd /root/nvfp4-megamoe-kernel
    python3 tests/cudagraph_test.py
"""

import os
import sys
import torch
import contextlib
import traceback

REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, REPO_ROOT)

from vllm.nvfp4_cutedsl import CuTeDSLMoERunner

# Substrings that identify stack frames belonging to this project when
# filtering tracebacks for display.
_PROJECT_FRAME_TAGS = ("nvfp4", "cutedsl", "bridge")


class CUDASyncDetector:
    """Context manager that detects CPU-GPU syncs.

    While active it wraps torch.cuda.synchronize and the Tensor methods
    item(), cpu(), numpy() and tolist(). Every wrapped call is still forwarded
    to the original implementation (the goal is to detect, not to block), but
    it is counted and a short stack excerpt is recorded.

    Tensor-method calls are only counted when the tensor lives on a CUDA
    device; the same calls on CPU tensors do not involve the GPU.

    Args:
        allow_compile_syncs: When True (default), Tensor-method syncs whose
            recent stack frames mention one of the init-time markers
            (``_ensure_stacked``, ``prepare_weights``, ``load_weights``,
            ``warmup``) are not counted.

    Attributes:
        sync_count: Number of syncs recorded so far.
        sync_stacktraces: List of ``(name, stack_excerpt)`` tuples, one per
            recorded sync.

    Usage:
        with CUDASyncDetector() as detector:
            model.forward(x)
        print(f"Syncs detected: {detector.sync_count}")
    """

    # Tensor methods wrapped while the detector is active.
    _TENSOR_METHODS = ("item", "cpu", "numpy", "tolist")

    # Stack substrings that mark a sync as init-time and therefore tolerated
    # when allow_compile_syncs is set.
    _INIT_TIME_MARKERS = (
        "_ensure_stacked",
        "prepare_weights",
        "load_weights",
        "warmup",
    )

    def __init__(self, allow_compile_syncs=True):
        self.sync_count = 0
        self.sync_stacktraces = []
        self.allow_compile_syncs = allow_compile_syncs
        # Kept for backward compatibility; nothing in this file reads it.
        self._in_compile = False
        self._originals = {}

    def _record(self, name, stack_str):
        """Count one sync and remember where it came from."""
        self.sync_count += 1
        self.sync_stacktraces.append((name, stack_str))

    def _make_sync_guard(self, name, original_fn):
        """Create a guard function that logs sync attempts.

        The returned wrapper records every call unconditionally and then
        delegates to ``original_fn``.
        """

        def guard(*args, **kwargs):
            # Last frame is the guard itself; keep the four frames above it.
            stack_str = "".join(traceback.format_stack()[-5:-1])
            self._record(name, stack_str)
            # Still call the original (we want to detect, not block)
            return original_fn(*args, **kwargs)

        return guard

    def _make_tensor_guard(self, method_name, original_fn):
        """Create a guard for a Tensor method.

        Unlike ``_make_sync_guard`` this ignores CPU tensors and, when
        ``allow_compile_syncs`` is set, calls coming from init-time code.
        """
        label = f"Tensor.{method_name}"
        markers = self._INIT_TIME_MARKERS

        def guard(tensor, *args, **kwargs):
            if getattr(tensor, "is_cuda", False):
                stack_str = "".join(traceback.format_stack()[-5:-1])
                init_time = self.allow_compile_syncs and any(
                    marker in stack_str for marker in markers
                )
                if not init_time:
                    self._record(label, stack_str)
            return original_fn(tensor, *args, **kwargs)

        return guard

    def _restore(self):
        """Put back every function that was patched by ``__enter__``."""
        originals = self._originals
        if "torch.cuda.synchronize" in originals:
            torch.cuda.synchronize = originals.pop("torch.cuda.synchronize")
        for method_name in self._TENSOR_METHODS:
            key = f"Tensor.{method_name}"
            if key in originals:
                setattr(torch.Tensor, method_name, originals.pop(key))

    def __enter__(self):
        try:
            # Patch torch.cuda.synchronize
            self._originals["torch.cuda.synchronize"] = torch.cuda.synchronize
            torch.cuda.synchronize = self._make_sync_guard(
                "torch.cuda.synchronize", torch.cuda.synchronize
            )
            # Patch the Tensor methods that pull data back to the host.
            for method_name in self._TENSOR_METHODS:
                original = getattr(torch.Tensor, method_name)
                self._originals[f"Tensor.{method_name}"] = original
                setattr(
                    torch.Tensor,
                    method_name,
                    self._make_tensor_guard(method_name, original),
                )
        except BaseException:
            # Never leave torch half-patched if setup fails midway.
            self._restore()
            raise
        return self

    def __exit__(self, *args):
        # Restore originals; returning None lets exceptions propagate.
        self._restore()

    def report(self):
        """Print a summary of recorded syncs.

        Returns:
            True when no sync was recorded, False otherwise.
        """
        if self.sync_count == 0:
            print("✅ No CPU-GPU syncs detected during forward pass")
            return True

        print(f"❌ {self.sync_count} CPU-GPU sync(s) detected:")
        for name, stack in self.sync_stacktraces:
            print(f"\n Sync via {name}:")
            # Print just the relevant lines
            for line in stack.split("\n"):
                line = line.strip()
                if line and any(tag in line for tag in _PROJECT_FRAME_TAGS):
                    print(f" {line}")
        return False


def make_dummy_runner(num_experts=32, hidden_size=7168, intermediate_size=3072, device="cuda"):
    """Create a CuTeDSL runner with dummy weights for testing.

    Args:
        num_experts: Number of experts to create weights for.
        hidden_size: Hidden dimension passed to the runner.
        intermediate_size: Intermediate dimension passed to the runner.
        device: Device on which the runner and the weights are created.

    Returns:
        A CuTeDSLMoERunner on which prepare_weights_direct() has been called
        with random weights and a global scale of 0.1 per expert.
    """
    runner = CuTeDSLMoERunner(num_experts, hidden_size, intermediate_size, device=device)

    # Create minimal dummy weights (uint8 → view as float4)
    def rand_fp4(*shape, device="cuda"):
        return torch.randint(0, 256, shape, dtype=torch.uint8, device=device).view(
            torch.float4_e2m1fn_x2
        )

    def rand_sf(*shape, device="cuda"):
        return torch.rand(shape, dtype=torch.float16, device=device).to(torch.float8_e4m3fn)

    # The original hard-coded 3584 and 1536, which equal hidden_size // 2 and
    # intermediate_size // 2 for the default sizes. Presumably this is two FP4
    # values packed per byte — TODO confirm against the runner's weight layout.
    l1_rows = hidden_size // 2
    l2_rows = intermediate_size // 2

    l1_fp4 = [rand_fp4(l1_rows, intermediate_size * 2, device=device) for _ in range(num_experts)]
    l1_sf = [rand_sf(l1_rows // 16, intermediate_size * 2, device=device) for _ in range(num_experts)]
    l1_gs = [0.1] * num_experts

    l2_fp4 = [rand_fp4(l2_rows, hidden_size, device=device) for _ in range(num_experts)]
    l2_sf = [rand_sf(l2_rows // 16, hidden_size, device=device) for _ in range(num_experts)]
    l2_gs = [0.1] * num_experts

    runner.prepare_weights_direct(l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs)
    return runner


def test_sync_detection():
    """Test that the sync detector actually catches syncs.

    Runs one ``.item()`` on a CUDA scalar and one ``torch.cuda.synchronize()``
    under a CUDASyncDetector and checks that each is counted.

    Returns:
        True when both syncs were recorded, False otherwise.
    """
    print("Testing sync detector...")

    # .item() only works on single-element tensors, so index out a scalar
    # before entering the detector.
    scalar = torch.tensor([1.0, 2.0], device="cuda")[0]

    # Disable the init-time allowlist so the self-test cannot be masked by a
    # marker that happens to appear in a file path.
    with CUDASyncDetector(allow_compile_syncs=False) as detector:
        val = scalar.item()
        item_caught = detector.sync_count == 1

        torch.cuda.synchronize()
        sync_caught = detector.sync_count == 2

    if item_caught:
        print(f" .item() returned {val} (sync detected)")
    else:
        print(f" .item() returned {val} (sync NOT detected)")

    if sync_caught:
        print(" torch.cuda.synchronize() works (patched)")
    else:
        print(" torch.cuda.synchronize() was NOT detected")

    print()
    return item_caught and sync_caught


def test_cudagraph_capture():
    """Test that the runner can be captured in a CUDA graph.

    If any CPU-GPU sync happens during forward, CUDA graph capture
    will fail with cudaErrorStreamCaptureInvalidated.

    Returns:
        True when capture and replay both succeed, False when a
        capture/stream related RuntimeError occurred (diagnostics are printed).

    Raises:
        RuntimeError: Re-raised when the error message mentions neither
            "capture" nor "stream".
    """
    print("=" * 70)
    print(" CUDA Graph Capture Test")
    print("=" * 70)

    device = "cuda"
    num_experts = 4  # Small for testing
    hidden_size = 7168
    intermediate_size = 3072
    num_tokens = 2
    top_k = 2

    print(f"\n Config: {num_experts} experts, {num_tokens} tokens, top_k={top_k}")

    # Create runner with dummy weights
    print(" Creating runner with dummy weights...")
    runner = make_dummy_runner(num_experts, hidden_size, intermediate_size, device)

    # Warmup: trigger _ensure_stacked and kernel compilation
    print(" Warming up (first forward — compiles kernels)...")
    hidden_states = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device=device)
    topk_weights = torch.tensor([[0.6, 0.4]] * num_tokens, dtype=torch.float32, device=device)
    topk_ids = torch.tensor([[0, 1], [2, 3]], dtype=torch.int32, device=device)

    with torch.no_grad():
        _ = runner.run(hidden_states, topk_weights, topk_ids)
    torch.cuda.synchronize()
    print(" Warmup done ✓")

    # Now test CUDA graph capture
    print("\n Attempting CUDA graph capture...")

    # Allocate graph and static inputs
    g = torch.cuda.CUDAGraph()
    static_hidden = hidden_states.clone()
    static_weights = topk_weights.clone()
    static_ids = topk_ids.clone()

    # Tracks which step was running so a failure is attributed correctly.
    stage = "capture"
    try:
        with torch.cuda.graph(g):
            runner.run(static_hidden, static_weights, static_ids)
        print(" CUDA graph capture SUCCEEDED ✓")

        # Replay the graph
        stage = "replay"
        print(" Replaying CUDA graph...")
        g.replay()
        torch.cuda.synchronize()
        print(" CUDA graph replay SUCCEEDED ✓")
        return True
    except RuntimeError as e:
        error_msg = str(e)
        lowered = error_msg.lower()
        if "capture" not in lowered and "stream" not in lowered:
            raise

        print(f" CUDA graph {stage} FAILED ✗")
        print(f" Error: {error_msg[:200]}")

        # Diagnose: check which operations cause syncs
        print("\n Diagnosing CPU-GPU syncs in forward pass...")
        diagnose_syncs(runner, hidden_states, topk_weights, topk_ids)
        return False


def diagnose_syncs(runner, hidden_states, topk_weights, topk_ids):
    """Run forward with sync detection to find problematic operations.

    Temporarily wraps torch.cuda.synchronize and Tensor.item/tolist/cpu, runs
    ``runner.run`` once under ``torch.no_grad()``, restores the originals and
    prints up to three project stack frames for every sync that was logged.
    Tensor-method calls on non-CUDA tensors are not logged.

    Args:
        runner: Object exposing ``run(hidden_states, topk_weights, topk_ids)``.
        hidden_states: Forwarded unchanged to ``runner.run``.
        topk_weights: Forwarded unchanged to ``runner.run``.
        topk_ids: Forwarded unchanged to ``runner.run``.
    """
    sync_log = []

    # Patch sync functions
    original_sync = torch.cuda.synchronize
    original_item = torch.Tensor.item
    original_tolist = torch.Tensor.tolist
    original_cpu = torch.Tensor.cpu

    def log_sync(name, original_fn):
        def wrapper(*args, **kwargs):
            # A Tensor method on a CPU tensor never touches the GPU.
            on_cpu_tensor = (
                bool(args)
                and isinstance(args[0], torch.Tensor)
                and not args[0].is_cuda
            )
            if not on_cpu_tensor:
                # Find the first frames from our code
                relevant = [
                    frame.strip()
                    for frame in traceback.format_stack()
                    if any(tag in frame for tag in _PROJECT_FRAME_TAGS)
                ]
                sync_log.append((name, relevant[:3]))
            return original_fn(*args, **kwargs)

        return wrapper

    try:
        torch.cuda.synchronize = log_sync("synchronize", original_sync)
        torch.Tensor.item = log_sync("item", original_item)
        torch.Tensor.tolist = log_sync("tolist", original_tolist)
        torch.Tensor.cpu = log_sync("cpu", original_cpu)

        with torch.no_grad():
            _ = runner.run(hidden_states, topk_weights, topk_ids)
    finally:
        # Restore
        torch.cuda.synchronize = original_sync
        torch.Tensor.item = original_item
        torch.Tensor.tolist = original_tolist
        torch.Tensor.cpu = original_cpu

    if sync_log:
        print(f"\n ❌ Found {len(sync_log)} CPU-GPU sync(s):")
        for name, frames in sync_log:
            print(f"\n Sync: {name}")
            for f in frames:
                if f:
                    print(f" {f}")
    else:
        print("\n No CPU-GPU syncs detected (capture failure must be from kernel launch)")


def main():
    """Run the detector self-test and the capture test; exit non-zero on failure.

    Returns without running anything (and without calling sys.exit) when CUDA
    is not available.
    """
    if not torch.cuda.is_available():
        print("No CUDA available, skipping")
        return

    # Test 1: Sync detection framework
    detector_ok = test_sync_detection()

    # Test 2: CUDA graph capture
    capture_ok = test_cudagraph_capture()

    success = detector_ok and capture_ok

    print("\n" + "=" * 70)
    if success:
        print(" ALL TESTS PASSED ✅")
    else:
        if not detector_ok:
            print(" TESTS FAILED ❌ — sync detector self-test failed")
        if not capture_ok:
            print(" TESTS FAILED ❌ — fix syncs above before cudagraph will work")
    print("=" * 70)

    sys.exit(0 if success else 1)


if __name__ == "__main__":
    main()