test: signal alarm timeout for kernel hang
This commit is contained in:
@@ -96,23 +96,24 @@ def test_nvfp4_mega_moe():
|
||||
|
||||
# --- Run kernel ---
|
||||
y = torch.zeros(num_tokens, hidden, dtype=torch.bfloat16, device=device)
|
||||
print("Calling fp8_nvfp4_mega_moe...")
|
||||
print("Calling fp8_nvfp4_mega_moe...", flush=True)
|
||||
import signal
|
||||
timed_out = False
|
||||
def handler(signum, frame):
|
||||
nonlocal timed_out
|
||||
timed_out = True
|
||||
raise TimeoutError("Kernel timeout")
|
||||
signal.signal(signal.SIGALRM, handler)
|
||||
signal.alarm(15) # 15 second timeout
|
||||
try:
|
||||
fp8_nvfp4_mega_moe(y, l1_weights, l2_weights, symm_buffer)
|
||||
# Use a sync with a manual timeout
|
||||
done = torch.cuda.Event()
|
||||
done.record()
|
||||
import time
|
||||
start = time.time()
|
||||
while not done.query():
|
||||
if time.time() - start > 10:
|
||||
print("TIMEOUT: kernel did not complete in 10s")
|
||||
break
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
torch.cuda.synchronize()
|
||||
print(f"SUCCESS! y stats: min={y.min().item():.4f} max={y.max().item():.4f} mean={y.mean().item():.4f} nonzero={torch.count_nonzero(y).item()}")
|
||||
torch.cuda.synchronize()
|
||||
signal.alarm(0)
|
||||
print(f"SUCCESS! y stats: min={y.min().item():.4f} max={y.max().item():.4f} mean={y.mean().item():.4f} nonzero={torch.count_nonzero(y).item()}")
|
||||
except TimeoutError:
|
||||
print("TIMEOUT: kernel did not complete in 15s (GPU hang?)")
|
||||
except Exception as e:
|
||||
signal.alarm(0)
|
||||
print(f"FAILED: {e}")
|
||||
raise
|
||||
|
||||
|
||||
Reference in New Issue
Block a user