From 17c1bdf3719d9d8fdf4f13cb1468e5ed5f70d021 Mon Sep 17 00:00:00 2001 From: PatchyTIS <58251192+PatchouliTIS@users.noreply.github.com> Date: Tue, 17 Mar 2026 13:19:55 +0800 Subject: [PATCH] [Bugfix] dtype mismatch in ngram gpu propose (#37246) Signed-off-by: PatchouliTaisa Co-authored-by: PatchouliTaisa --- vllm/v1/spec_decode/ngram_proposer_gpu.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/v1/spec_decode/ngram_proposer_gpu.py b/vllm/v1/spec_decode/ngram_proposer_gpu.py index 3ff841804..eb24a9c93 100644 --- a/vllm/v1/spec_decode/ngram_proposer_gpu.py +++ b/vllm/v1/spec_decode/ngram_proposer_gpu.py @@ -364,7 +364,9 @@ class NgramProposerGPU: ) token_ids_gpu.scatter_(1, write_positions_long, tokens_to_scatter) - num_tokens_tmp = num_tokens_no_spec + valid_sampled_tokens_count + num_tokens_tmp = (num_tokens_no_spec + valid_sampled_tokens_count).to( + torch.int32 + ) # Compute validity masks. sampled_flags = valid_sampled_tokens_count > 0 @@ -437,7 +439,7 @@ class NgramProposerGPU: ) # Count valid tokens per request. - valid_sampled_tokens_count = valid_mask.sum(dim=1) + valid_sampled_tokens_count = valid_mask.sum(dim=1).to(torch.int32) # Rightmost valid index per row. last_valid_indices = valid_sampled_tokens_count - 1