test: signal alarm timeout for kernel hang

This commit is contained in:
2026-05-12 15:14:39 +00:00
parent fcd6de0a60
commit 307574bc91

View File

@@ -96,23 +96,24 @@ def test_nvfp4_mega_moe():
# --- Run kernel ---
y = torch.zeros(num_tokens, hidden, dtype=torch.bfloat16, device=device)
print("Calling fp8_nvfp4_mega_moe...")
print("Calling fp8_nvfp4_mega_moe...", flush=True)
import signal
timed_out = False
def handler(signum, frame):
nonlocal timed_out
timed_out = True
raise TimeoutError("Kernel timeout")
signal.signal(signal.SIGALRM, handler)
signal.alarm(15) # 15 second timeout
try:
fp8_nvfp4_mega_moe(y, l1_weights, l2_weights, symm_buffer)
# Use a sync with a manual timeout
done = torch.cuda.Event()
done.record()
import time
start = time.time()
while not done.query():
if time.time() - start > 10:
print("TIMEOUT: kernel did not complete in 10s")
break
time.sleep(0.1)
else:
torch.cuda.synchronize()
print(f"SUCCESS! y stats: min={y.min().item():.4f} max={y.max().item():.4f} mean={y.mean().item():.4f} nonzero={torch.count_nonzero(y).item()}")
torch.cuda.synchronize()
signal.alarm(0)
print(f"SUCCESS! y stats: min={y.min().item():.4f} max={y.max().item():.4f} mean={y.mean().item():.4f} nonzero={torch.count_nonzero(y).item()}")
except TimeoutError:
print("TIMEOUT: kernel did not complete in 15s (GPU hang?)")
except Exception as e:
signal.alarm(0)
print(f"FAILED: {e}")
raise