fix: decode loop also needs int32 token_ids for hash router

This commit is contained in:
2026-06-01 01:58:45 +00:00
parent 905623793b
commit cfd2468c61

View File

@@ -715,8 +715,9 @@ def main():
print(f"\nDecoding (max {MAX_NEW_TOKENS} tokens)...")
for step in range(MAX_NEW_TOKENS):
t1 = time.time()
tid = torch.tensor([all_tokens[-1]], dtype=torch.long, device='cuda:0')
dec_pos = torch.tensor([len(all_tokens)-1], dtype=torch.long, device='cuda:0')
tid_int64 = torch.tensor([all_tokens[-1]], dtype=torch.long, device='cuda:0')
tid = tid_int64.to(torch.int32) # hash router needs int32
dec_pos = torch.tensor([len(all_tokens)-1], dtype=torch.long, device='cuda:0)
X = mHCLayer.init_state(embed(tid_int64))
for li in range(n_layers):
gpu = li % NUM_GPUS