[Bugfix] dtype mismatch in ngram gpu propose (#37246)
Signed-off-by: PatchouliTaisa <patchychen@tencent.com>
Co-authored-by: PatchouliTaisa <patchychen@tencent.com>
This commit is contained in:
@@ -364,7 +364,9 @@ class NgramProposerGPU:
|
||||
)
|
||||
token_ids_gpu.scatter_(1, write_positions_long, tokens_to_scatter)
|
||||
|
||||
- num_tokens_tmp = num_tokens_no_spec + valid_sampled_tokens_count
+ num_tokens_tmp = (num_tokens_no_spec + valid_sampled_tokens_count).to(
+     torch.int32
+ )
|
||||
|
||||
# Compute validity masks.
|
||||
sampled_flags = valid_sampled_tokens_count > 0
|
||||
@@ -437,7 +439,7 @@ class NgramProposerGPU:
|
||||
)
|
||||
|
||||
# Count valid tokens per request.
|
||||
- valid_sampled_tokens_count = valid_mask.sum(dim=1)
+ valid_sampled_tokens_count = valid_mask.sum(dim=1).to(torch.int32)
|
||||
|
||||
# Rightmost valid index per row.
|
||||
last_valid_indices = valid_sampled_tokens_count - 1
|
||||
|
||||
Reference in New Issue
Block a user