From cfd2468c61abbf83b027b484f2ad353d9425ceee Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 1 Jun 2026 01:58:45 +0000 Subject: [PATCH] fix: decode loop also needs int32 token_ids for hash router --- single_shot_inference.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/single_shot_inference.py b/single_shot_inference.py index 7812a363..1f0c528d 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -715,8 +715,9 @@ def main(): print(f"\nDecoding (max {MAX_NEW_TOKENS} tokens)...") for step in range(MAX_NEW_TOKENS): t1 = time.time() - tid = torch.tensor([all_tokens[-1]], dtype=torch.long, device='cuda:0') - dec_pos = torch.tensor([len(all_tokens)-1], dtype=torch.long, device='cuda:0') + tid_int64 = torch.tensor([all_tokens[-1]], dtype=torch.long, device='cuda:0') + tid = tid_int64.to(torch.int32) # hash router needs int32 + dec_pos = torch.tensor([len(all_tokens)-1], dtype=torch.long, device='cuda:0) X = mHCLayer.init_state(embed(tid_int64)) for li in range(n_layers): gpu = li % NUM_GPUS