From 1528e079e2b2cf8a807e4dce86ef05540e16a430 Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Tue, 2 Dec 2025 13:25:52 -0800 Subject: [PATCH] [Perf] Avoid pageable HtoD transfer in MinTokensLogitsProcessor (#29826) Signed-off-by: jthomson04 --- vllm/v1/sample/logits_processor/builtin.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index 4ee7dc288..82743f72b 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -110,7 +110,7 @@ class MinPLogitsProcessor(LogitsProcessor): # Identify valid tokens using threshold comparison invalid_token_mask = probability_values < adjusted_min_p # Apply mask using boolean indexing - logits[invalid_token_mask] = -float("inf") + logits.masked_fill_(invalid_token_mask, -float("inf")) return logits @@ -178,6 +178,10 @@ class MinTokensLogitsProcessor(LogitsProcessor): self._device_tensor([], torch.int32), ) + self.neg_inf_tensor = torch.tensor( + -float("inf"), dtype=torch.float32, device=self.device + ) + def is_argmax_invariant(self) -> bool: """By censoring stop tokens, min-tokens can change the outcome of the argmax operation in greedy sampling.""" @@ -229,7 +233,7 @@ class MinTokensLogitsProcessor(LogitsProcessor): def apply(self, logits: torch.Tensor) -> torch.Tensor: if self.min_toks: # Inhibit EOS token for requests which have not reached min length - logits[self.logits_slice] = -float("inf") + logits.index_put_(self.logits_slice, self.neg_inf_tensor) return logits