Add torch.cuda.synchronize() after the sampler for the first three steps (step < 3) to catch asynchronous CUDA errors early

This commit is contained in:
2026-06-01 22:32:40 +00:00
parent 5493a8727e
commit 6f4bbc997a

View File

@@ -1034,6 +1034,9 @@ def main():
recent_tokens=all_tokens[-256:],
seed=SEED,
)
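# Synchronize during the first three steps so asynchronous CUDA errors from the
# sampler surface here instead of at a later synchronization point.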
if step < 3:
torch.cuda.synchronize()
next_id = sampled[0].item()
all_tokens.append(next_id)