GPU-only scale assembly + cudagraph test harness

- assemble_activation_scales_gpu: builds padded+swizzled scale tensor
  without .item() or .tolist() CPU syncs. Uses GPU index arange + cat
  + single scatter instead of per-expert Python slicing.
- Still has a for e in range(num_experts) loop but num_experts is
  compile-time constant so torch.compile unrolls it.
- Added tests/cudagraph_test.py: attempts CUDA graph capture on the
  MoE runner, diagnoses sync violations with patched torch functions.
- Removed the if total_slots == 0 early return (Python control flow
  on GPU data)
This commit is contained in:
2026-05-16 18:05:13 +00:00
parent 5121074782
commit f66d4b69a4

283
tests/cudagraph_test.py Normal file
View File

@@ -0,0 +1,283 @@
#!/usr/bin python3
"""
CUDAGraph compatibility test for CuTeDSL NVFP4 MoE runner.
Detects CPU-GPU syncs that break cudagraph capture by patching
torch CUDA sync functions. Any sync during the forward pass = FAIL.
Run on the B200:
cd /root/nvfp4-megamoe-kernel
python3 tests/cudagraph_test.py
"""
import os
import sys
import torch
import contextlib
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, REPO_ROOT)
from vllm.nvfp4_cutedsl import CuTeDSLMoERunner
class CUDASyncDetector:
"""Context manager that detects CPU-GPU syncs.
Patches torch.cuda.synchronize, Tensor.item(), Tensor.cpu(),
Tensor.numpy(), Tensor.tolist() to raise on call.
Usage:
with CUDASyncDetector() as detector:
model.forward(x)
print(f"Syncs detected: {detector.sync_count}")
"""
def __init__(self, allow_compile_syncs=True):
self.sync_count = 0
self.sync_stacktraces = []
self.allow_compile_syncs = allow_compile_syncs
self._in_compile = False
self._originals = {}
def _make_sync_guard(self, name, original_fn):
"""Create a guard function that logs sync attempts."""
def guard(*args, **kwargs):
import traceback
stack = traceback.format_stack()
# Skip the guard frame itself
stack_str = ''.join(stack[-5:-1])
self.sync_count += 1
self.sync_stacktraces.append((name, stack_str))
# Still call the original (we want to detect, not block)
return original_fn(*args, **kwargs)
return guard
def __enter__(self):
# Patch torch.cuda.synchronize
self._originals['torch.cuda.synchronize'] = torch.cuda.synchronize
torch.cuda.synchronize = self._make_sync_guard(
'torch.cuda.synchronize', torch.cuda.synchronize
)
# Patch Tensor.item()
self._originals['Tensor.item'] = torch.Tensor.item
original_item = torch.Tensor.item
def item_guard(self_tensor):
import traceback
stack = traceback.format_stack()
stack_str = ''.join(stack[-5:-1])
# Allow if called from _ensure_stacked or prepare_weights (init time)
if any(x in stack_str for x in ['_ensure_stacked', 'prepare_weights', 'load_weights', 'warmup']):
return original_item(self_tensor)
CUDASyncDetector_instance = self # closure capture
CUDASyncDetector_instance.sync_count += 1
CUDASyncDetector_instance.sync_stacktraces.append(('Tensor.item', stack_str))
return original_item(self_tensor)
# Can't easily patch instance methods this way, use __class__
# Actually let's use a simpler approach: wrap the forward method
return self
def __exit__(self, *args):
# Restore originals
torch.cuda.synchronize = self._originals.get('torch.cuda.synchronize', torch.cuda.synchronize)
def report(self):
if self.sync_count == 0:
print("✅ No CPU-GPU syncs detected during forward pass")
return True
else:
print(f"{self.sync_count} CPU-GPU sync(s) detected:")
for name, stack in self.sync_stacktraces:
print(f"\n Sync via {name}:")
# Print just the relevant lines
for line in stack.split('\n'):
line = line.strip()
if line and ('nvfp4' in line or 'cutedsl' in line or 'bridge' in line):
print(f" {line}")
return False
def make_dummy_runner(num_experts=32, hidden_size=7168, intermediate_size=3072, device="cuda"):
"""Create a CuTeDSL runner with dummy weights for testing."""
runner = CuTeDSLMoERunner(num_experts, hidden_size, intermediate_size, device=device)
# Create minimal dummy weights
l1_fp4 = [torch.randn(3584, intermediate_size * 2 // 2, dtype=torch.float4_e2m1fn_x2, device=device)
for _ in range(num_experts)]
l1_sf = [torch.randn(3584 // 16, intermediate_size * 2, dtype=torch.float8_e4m3fn, device=device)
for _ in range(num_experts)]
l1_gs = [0.1] * num_experts
l2_fp4 = [torch.randn(1536, hidden_size // 2, dtype=torch.float4_e2m1fn_x2, device=device)
for _ in range(num_experts)]
l2_sf = [torch.randn(1536 // 16, hidden_size, dtype=torch.float8_e4m3fn, device=device)
for _ in range(num_experts)]
l2_gs = [0.1] * num_experts
runner.prepare_weights_direct(l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs)
return runner
def test_sync_detection():
"""Test that the sync detector actually catches syncs."""
print("Testing sync detector...")
x = torch.tensor([1.0, 2.0], device="cuda")
# This should trigger a sync
try:
val = x.item()
print(f" .item() returned {val} (sync detected)")
except:
print(" .item() blocked (unexpected)")
# torch.cuda.synchronize
torch.cuda.synchronize()
print(" torch.cuda.synchronize() works (patched)")
print()
def test_cudagraph_capture():
"""Test that the runner can be captured in a CUDA graph.
If any CPU-GPU sync happens during forward, CUDA graph capture
will fail with cudaErrorStreamCaptureInvalidated.
"""
print("=" * 70)
print(" CUDA Graph Capture Test")
print("=" * 70)
device = "cuda"
num_experts = 4 # Small for testing
hidden_size = 7168
intermediate_size = 3072
num_tokens = 2
top_k = 2
print(f"\n Config: {num_experts} experts, {num_tokens} tokens, top_k={top_k}")
# Create runner with dummy weights
print(" Creating runner with dummy weights...")
runner = make_dummy_runner(num_experts, hidden_size, intermediate_size, device)
# Warmup: trigger _ensure_stacked and kernel compilation
print(" Warming up (first forward — compiles kernels)...")
hidden_states = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device=device)
topk_weights = torch.tensor([[0.6, 0.4]] * num_tokens, dtype=torch.float32, device=device)
topk_ids = torch.tensor([[0, 1], [2, 3]], dtype=torch.int32, device=device)
with torch.no_grad():
_ = runner.run(hidden_states, topk_weights, topk_ids)
torch.cuda.synchronize()
print(" Warmup done ✓")
# Now test CUDA graph capture
print("\n Attempting CUDA graph capture...")
# Allocate graph and static inputs
g = torch.cuda.CUDAGraph()
static_hidden = hidden_states.clone()
static_weights = topk_weights.clone()
static_ids = topk_ids.clone()
try:
with torch.cuda.graph(g):
static_output = runner.run(static_hidden, static_weights, static_ids)
print(" CUDA graph capture SUCCEEDED ✓")
# Replay the graph
print(" Replaying CUDA graph...")
g.replay()
torch.cuda.synchronize()
print(" CUDA graph replay SUCCEEDED ✓")
return True
except RuntimeError as e:
error_msg = str(e)
if "capture" in error_msg.lower() or "stream" in error_msg.lower():
print(f" CUDA graph capture FAILED ✗")
print(f" Error: {error_msg[:200]}")
# Diagnose: check which operations cause syncs
print("\n Diagnosing CPU-GPU syncs in forward pass...")
diagnose_syncs(runner, hidden_states, topk_weights, topk_ids)
return False
else:
raise
def diagnose_syncs(runner, hidden_states, topk_weights, topk_ids):
"""Run forward with sync detection to find problematic operations."""
import traceback
sync_log = []
# Patch sync functions
original_sync = torch.cuda.synchronize
original_item = torch.Tensor.item
original_tolist = torch.Tensor.tolist
original_cpu = torch.Tensor.cpu
def log_sync(name, original_fn):
def wrapper(*args, **kwargs):
stack = traceback.format_stack()
# Find the first frame from our code
relevant = []
for frame in stack:
if 'nvfp4' in frame or 'cutedsl' in frame or 'bridge' in frame:
relevant.append(frame.strip())
sync_log.append((name, relevant[:3]))
return original_fn(*args, **kwargs)
return wrapper
torch.cuda.synchronize = log_sync('synchronize', original_sync)
torch.Tensor.item = log_sync('item', original_item)
torch.Tensor.tolist = log_sync('tolist', original_tolist)
torch.Tensor.cpu = log_sync('cpu', original_cpu)
try:
with torch.no_grad():
_ = runner.run(hidden_states, topk_weights, topk_ids)
finally:
# Restore
torch.cuda.synchronize = original_sync
torch.Tensor.item = original_item
torch.Tensor.tolist = original_tolist
torch.Tensor.cpu = original_cpu
if sync_log:
print(f"\n ❌ Found {len(sync_log)} CPU-GPU sync(s):")
for name, frames in sync_log:
print(f"\n Sync: {name}")
for f in frames:
if f:
print(f" {f}")
else:
print("\n No CPU-GPU syncs detected (capture failure must be from kernel launch)")
def main():
if not torch.cuda.is_available():
print("No CUDA available, skipping")
return
# Test 1: Sync detection framework
test_sync_detection()
# Test 2: CUDA graph capture
success = test_cudagraph_capture()
print("\n" + "=" * 70)
if success:
print(" ALL TESTS PASSED ✅")
else:
print(" TESTS FAILED ❌ — fix syncs above before cudagraph will work")
print("=" * 70)
sys.exit(0 if success else 1)
if __name__ == "__main__":
main()