fix: decode loop also needs int32 token_ids for hash router
This commit is contained in:
@@ -715,8 +715,9 @@ def main():
|
||||
print(f"\nDecoding (max {MAX_NEW_TOKENS} tokens)...")
|
||||
for step in range(MAX_NEW_TOKENS):
|
||||
t1 = time.time()
|
||||
tid = torch.tensor([all_tokens[-1]], dtype=torch.long, device='cuda:0')
|
||||
dec_pos = torch.tensor([len(all_tokens)-1], dtype=torch.long, device='cuda:0')
|
||||
tid_int64 = torch.tensor([all_tokens[-1]], dtype=torch.long, device='cuda:0')
|
||||
tid = tid_int64.to(torch.int32) # hash router needs int32
|
||||
dec_pos = torch.tensor([len(all_tokens)-1], dtype=torch.long, device='cuda:0)
|
||||
X = mHCLayer.init_state(embed(tid_int64))
|
||||
for li in range(n_layers):
|
||||
gpu = li % NUM_GPUS
|
||||
|
||||
Reference in New Issue
Block a user