287 lines
9.9 KiB
Python
287 lines
9.9 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
CUDAGraph compatibility test for CuTeDSL NVFP4 MoE runner.
|
|
|
|
Detects CPU-GPU syncs that break cudagraph capture by patching
|
|
torch CUDA sync functions. Any sync during the forward pass = FAIL.
|
|
|
|
Run on the B200:
|
|
cd /root/nvfp4-megamoe-kernel
|
|
python3 tests/cudagraph_test.py
|
|
"""
|
|
import os
|
|
import sys
|
|
import torch
|
|
import contextlib
|
|
|
|
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
sys.path.insert(0, REPO_ROOT)
|
|
|
|
from vllm.nvfp4_cutedsl import CuTeDSLMoERunner
|
|
|
|
|
|
class CUDASyncDetector:
    """Context manager that detects CPU-GPU syncs.

    Patches torch.cuda.synchronize, Tensor.item(), Tensor.cpu(),
    Tensor.numpy(), Tensor.tolist() so that each call is recorded.
    The patched functions still call through to the originals: the
    goal is to detect syncs, not to block them. All patches are
    removed again when the ``with`` block exits.

    Tensor-method calls are only recorded for CUDA tensors, because the
    same calls on CPU tensors do not involve the GPU.

    Args:
        allow_compile_syncs: When true (the default), Tensor-method calls
            whose recent stack frames mention one of the init-time
            markers (``_INIT_TIME_MARKERS``) are not counted. When false,
            every Tensor-method call on a CUDA tensor is counted.

    Usage:
        with CUDASyncDetector() as detector:
            model.forward(x)
        print(f"Syncs detected: {detector.sync_count}")
    """

    # Tensor methods that are wrapped while the detector is active.
    _TENSOR_METHODS = ('item', 'cpu', 'numpy', 'tolist')
    # Substrings that mark a stack as init-time work rather than forward pass.
    _INIT_TIME_MARKERS = ('_ensure_stacked', 'prepare_weights', 'load_weights', 'warmup')

    def __init__(self, allow_compile_syncs=True):
        # Number of recorded sync calls.
        self.sync_count = 0
        # List of (name, formatted stack excerpt) tuples, one per recorded call.
        self.sync_stacktraces = []
        self.allow_compile_syncs = allow_compile_syncs
        self._in_compile = False
        # Original callables keyed by a readable name, filled in by __enter__.
        self._originals = {}

    @staticmethod
    def _caller_stack():
        """Return the four formatted frames directly above the calling guard."""
        import traceback
        # The last two entries are this helper and the guard that called it;
        # skip both so the excerpt ends at the code that triggered the sync.
        return ''.join(traceback.format_stack()[-6:-2])

    def _record(self, name, stack_str):
        """Count one sync and remember where it came from."""
        self.sync_count += 1
        self.sync_stacktraces.append((name, stack_str))

    def _make_sync_guard(self, name, original_fn):
        """Create a guard function that logs sync attempts.

        Every call is recorded under ``name`` and then forwarded to
        ``original_fn`` unchanged.
        """
        def guard(*args, **kwargs):
            self._record(name, self._caller_stack())
            # Still call the original (we want to detect, not block)
            return original_fn(*args, **kwargs)
        return guard

    def _make_tensor_guard(self, name, original_fn):
        """Create a guard for a Tensor method such as ``item`` or ``cpu``.

        The call is recorded only when the tensor lives on a CUDA device
        and, if ``allow_compile_syncs`` is set, the recent stack does not
        mention an init-time marker. The original method is always called.
        """
        def guard(self_tensor, *args, **kwargs):
            if self_tensor.is_cuda:
                stack_str = self._caller_stack()
                init_time = self.allow_compile_syncs and any(
                    marker in stack_str for marker in self._INIT_TIME_MARKERS
                )
                if not init_time:
                    self._record(name, stack_str)
            return original_fn(self_tensor, *args, **kwargs)
        return guard

    def __enter__(self):
        # Patch torch.cuda.synchronize
        self._originals['torch.cuda.synchronize'] = torch.cuda.synchronize
        torch.cuda.synchronize = self._make_sync_guard(
            'torch.cuda.synchronize', torch.cuda.synchronize
        )

        # Patch the Tensor methods that copy data back to the host.
        for method in self._TENSOR_METHODS:
            key = f'Tensor.{method}'
            original = getattr(torch.Tensor, method)
            self._originals[key] = original
            setattr(torch.Tensor, method, self._make_tensor_guard(key, original))

        return self

    def __exit__(self, *args):
        # Restore originals; pop so a second __exit__ is a harmless no-op.
        torch.cuda.synchronize = self._originals.pop(
            'torch.cuda.synchronize', torch.cuda.synchronize
        )
        for method in self._TENSOR_METHODS:
            original = self._originals.pop(f'Tensor.{method}', None)
            if original is not None:
                setattr(torch.Tensor, method, original)

    def report(self):
        """Print a summary of recorded syncs.

        Returns:
            True when no sync was recorded, False otherwise.
        """
        if self.sync_count == 0:
            print("✅ No CPU-GPU syncs detected during forward pass")
            return True

        print(f"❌ {self.sync_count} CPU-GPU sync(s) detected:")
        for name, stack in self.sync_stacktraces:
            print(f"\n Sync via {name}:")
            # Print just the relevant lines
            for line in stack.split('\n'):
                line = line.strip()
                if line and ('nvfp4' in line or 'cutedsl' in line or 'bridge' in line):
                    print(f" {line}")
        return False
|
|
|
|
|
|
def make_dummy_runner(num_experts=32, hidden_size=7168, intermediate_size=3072, device="cuda"):
    """Build a CuTeDSLMoERunner filled with random placeholder weights.

    Args:
        num_experts: Number of experts; one weight/scale tensor per expert
            is generated for each layer.
        hidden_size: Passed to the runner and used for the layer-2 shapes.
        intermediate_size: Passed to the runner and used for the layer-1 shapes.
        device: Device on which the runner and all tensors are created.

    Returns:
        The runner after ``prepare_weights_direct`` has been called on it.
    """
    # NOTE(review): 3584 and 1536 equal hidden_size // 2 and
    # intermediate_size // 2 at the default sizes, but they are fixed here
    # and do not follow the arguments - confirm whether that is intended.
    l1_rows = 3584
    l2_rows = 1536

    runner = CuTeDSLMoERunner(num_experts, hidden_size, intermediate_size, device=device)

    def rand_fp4(*shape):
        # Random bytes reinterpreted as packed FP4 values (uint8 -> float4 view).
        raw = torch.randint(0, 256, shape, dtype=torch.uint8, device=device)
        return raw.view(torch.float4_e2m1fn_x2)

    def rand_sf(*shape):
        # Random fp16 values converted to fp8 for use as scale factors.
        half = torch.rand(shape, dtype=torch.float16, device=device)
        return half.to(torch.float8_e4m3fn)

    experts = range(num_experts)

    # Layer 1 tensors first, then layer 2, so the RNG is consumed in a fixed order.
    l1_fp4 = [rand_fp4(l1_rows, intermediate_size) for _ in experts]
    l1_sf = [rand_sf(l1_rows // 16, intermediate_size * 2) for _ in experts]
    l1_gs = [0.1] * num_experts
    l2_fp4 = [rand_fp4(l2_rows, hidden_size // 2) for _ in experts]
    l2_sf = [rand_sf(l2_rows // 16, hidden_size) for _ in experts]
    l2_gs = [0.1] * num_experts

    runner.prepare_weights_direct(l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs)
    return runner
|
|
|
|
|
|
def test_sync_detection():
    """Test that the sync detector actually catches syncs.

    Runs a ``.item()`` read and a ``torch.cuda.synchronize()`` call inside
    a ``CUDASyncDetector`` context and prints how many syncs it recorded.
    Requires a CUDA device.
    """
    print("Testing sync detector...")
    x = torch.tensor([1.0, 2.0], device="cuda")

    with CUDASyncDetector() as detector:
        # This should trigger a sync. Read a single element: .item() raises
        # on a tensor with more than one element.
        try:
            val = x[0].item()
            print(f" .item() returned {val} (sync detected)")
        except RuntimeError:
            print(" .item() blocked (unexpected)")

        # torch.cuda.synchronize
        torch.cuda.synchronize()
        print(" torch.cuda.synchronize() works (patched)")

    print(f" detector recorded {detector.sync_count} sync(s)")
    print()
|
|
|
|
|
|
def test_cudagraph_capture():
    """Check that the runner's forward pass can be captured in a CUDA graph.

    A warmup forward is run first, then ``runner.run`` is captured with
    ``torch.cuda.graph`` and replayed once.

    Returns:
        True when capture and replay succeed. False when a RuntimeError
        mentioning "capture" or "stream" is raised; in that case
        ``diagnose_syncs`` is run to list the offending calls.

    Raises:
        RuntimeError: Any RuntimeError that does not mention "capture"
            or "stream" is re-raised unchanged.
    """
    rule = "=" * 70
    print(rule)
    print(" CUDA Graph Capture Test")
    print(rule)

    device = "cuda"
    num_experts = 4  # Small for testing
    hidden_size = 7168
    intermediate_size = 3072
    num_tokens = 2
    top_k = 2

    print(f"\n Config: {num_experts} experts, {num_tokens} tokens, top_k={top_k}")

    # Create runner with dummy weights
    print(" Creating runner with dummy weights...")
    runner = make_dummy_runner(num_experts, hidden_size, intermediate_size, device)

    # Warmup: trigger _ensure_stacked and kernel compilation
    print(" Warming up (first forward — compiles kernels)...")
    hidden_states = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device=device)
    topk_weights = torch.tensor([[0.6, 0.4]] * num_tokens, dtype=torch.float32, device=device)
    topk_ids = torch.tensor([[0, 1], [2, 3]], dtype=torch.int32, device=device)

    with torch.no_grad():
        _warmup_out = runner.run(hidden_states, topk_weights, topk_ids)
    torch.cuda.synchronize()
    print(" Warmup done ✓")

    # Now test CUDA graph capture
    print("\n Attempting CUDA graph capture...")

    # Allocate graph and static inputs
    graph = torch.cuda.CUDAGraph()
    static_hidden = hidden_states.clone()
    static_weights = topk_weights.clone()
    static_ids = topk_ids.clone()

    try:
        with torch.cuda.graph(graph):
            # Keep a reference to the captured output for the rest of the call.
            static_output = runner.run(static_hidden, static_weights, static_ids)
        print(" CUDA graph capture SUCCEEDED ✓")

        # Replay the graph
        print(" Replaying CUDA graph...")
        graph.replay()
        torch.cuda.synchronize()
        print(" CUDA graph replay SUCCEEDED ✓")
        return True
    except RuntimeError as err:
        error_msg = str(err)
        lowered = error_msg.lower()
        # Only capture/stream errors are treated as a test failure;
        # anything else is an unrelated problem and propagates.
        if "capture" not in lowered and "stream" not in lowered:
            raise

        print(" CUDA graph capture FAILED ✗")
        print(f" Error: {error_msg[:200]}")

        # Diagnose: check which operations cause syncs
        print("\n Diagnosing CPU-GPU syncs in forward pass...")
        diagnose_syncs(runner, hidden_states, topk_weights, topk_ids)
        return False
|
|
|
|
|
|
def diagnose_syncs(runner, hidden_states, topk_weights, topk_ids):
    """Run one forward pass with sync functions wrapped and print what was hit.

    ``torch.cuda.synchronize`` and ``Tensor.item/tolist/cpu`` are replaced
    by logging wrappers for the duration of ``runner.run`` and restored
    afterwards, even if the forward pass raises.

    Args:
        runner: Object exposing ``run(hidden_states, topk_weights, topk_ids)``.
        hidden_states: First argument forwarded to ``runner.run``.
        topk_weights: Second argument forwarded to ``runner.run``.
        topk_ids: Third argument forwarded to ``runner.run``.
    """
    import traceback

    sync_log = []
    markers = ('nvfp4', 'cutedsl', 'bridge')

    def log_sync(name, original_fn):
        def wrapper(*args, **kwargs):
            # Keep only stack entries that mention our own code.
            relevant = [
                entry.strip()
                for entry in traceback.format_stack()
                if any(marker in entry for marker in markers)
            ]
            sync_log.append((name, relevant[:3]))
            return original_fn(*args, **kwargs)
        return wrapper

    # (owner, attribute) pairs to wrap; originals are captured before any
    # patch is applied.
    targets = (
        (torch.cuda, 'synchronize'),
        (torch.Tensor, 'item'),
        (torch.Tensor, 'tolist'),
        (torch.Tensor, 'cpu'),
    )
    saved = [(owner, attr, getattr(owner, attr)) for owner, attr in targets]

    for owner, attr, original in saved:
        setattr(owner, attr, log_sync(attr, original))

    try:
        with torch.no_grad():
            _ = runner.run(hidden_states, topk_weights, topk_ids)
    finally:
        # Restore
        for owner, attr, original in saved:
            setattr(owner, attr, original)

    if not sync_log:
        print("\n No CPU-GPU syncs detected (capture failure must be from kernel launch)")
        return

    print(f"\n ❌ Found {len(sync_log)} CPU-GPU sync(s):")
    for name, frames in sync_log:
        print(f"\n Sync: {name}")
        for frame in frames:
            if frame:
                print(f" {frame}")
|
|
|
|
|
|
def main():
    """Run both checks and exit with a status reflecting the capture result.

    Returns early (without exiting) when CUDA is unavailable. Otherwise
    calls ``sys.exit`` with 0 if graph capture succeeded and 1 if not.
    """
    if not torch.cuda.is_available():
        print("No CUDA available, skipping")
        return

    # Test 1: Sync detection framework
    test_sync_detection()

    # Test 2: CUDA graph capture
    success = test_cudagraph_capture()

    rule = "=" * 70
    if success:
        verdict = " ALL TESTS PASSED ✅"
    else:
        verdict = " TESTS FAILED ❌ — fix syncs above before cudagraph will work"

    print("\n" + rule)
    print(verdict)
    print(rule)

    sys.exit(0 if success else 1)
|
|
|
|
|
|
# Run the checks only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|