Add torch.cuda.synchronize() after the sampler call while step < 3 to catch async CUDA errors early
This commit is contained in:
@@ -1034,6 +1034,9 @@ def main():
|
||||
recent_tokens=all_tokens[-256:],
|
||||
seed=SEED,
|
||||
)
|
||||
# Check for async CUDA errors from sampler
|
||||
if step < 3:
|
||||
torch.cuda.synchronize()
|
||||
next_id = sampled[0].item()
|
||||
|
||||
all_tokens.append(next_id)
|
||||
|
||||
Reference in New Issue
Block a user