diff --git a/single_shot_inference.py b/single_shot_inference.py index ba747902..4ef98edf 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -1034,6 +1034,9 @@ def main(): recent_tokens=all_tokens[-256:], seed=SEED, ) + # Check for async CUDA errors from sampler + if step < 3: + torch.cuda.synchronize() next_id = sampled[0].item() all_tokens.append(next_id)