diff --git a/vllm/v1/worker/gpu/sample/logprob.py b/vllm/v1/worker/gpu/sample/logprob.py index 466786766..4317cad9c 100644 --- a/vllm/v1/worker/gpu/sample/logprob.py +++ b/vllm/v1/worker/gpu/sample/logprob.py @@ -68,7 +68,7 @@ def _ranks_kernel( for i in range(0, vocab_size, BLOCK_SIZE): block = i + tl.arange(0, BLOCK_SIZE) logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf")) - n += tl.sum((logits > x).to(tl.int32)) + n += tl.sum((logits >= x).to(tl.int32)) tl.store(output_ptr + req_idx, n)