From d940607629b03602f34ba4dd75c747162b01aedd Mon Sep 17 00:00:00 2001 From: Yiliu Dong <91178480+qianlihuang@users.noreply.github.com> Date: Fri, 27 Feb 2026 01:31:28 +0800 Subject: [PATCH] [Core] Support `min_tokens` with speculative decoding (#32642) Signed-off-by: qianlihuang Co-authored-by: qianlihuang --- tests/v1/e2e/test_async_scheduling.py | 3 +- .../logits_processors/test_custom_offline.py | 7 ++- vllm/sampling_params.py | 4 +- vllm/v1/sample/logits_processor/__init__.py | 7 +-- vllm/v1/sample/logits_processor/builtin.py | 54 +++++++++++++++++++ vllm/v1/sample/logits_processor/state.py | 4 +- vllm/v1/sample/rejection_sampler.py | 7 +++ 7 files changed, 75 insertions(+), 11 deletions(-) diff --git a/tests/v1/e2e/test_async_scheduling.py b/tests/v1/e2e/test_async_scheduling.py index 393c8dbee..042e95386 100644 --- a/tests/v1/e2e/test_async_scheduling.py +++ b/tests/v1/e2e/test_async_scheduling.py @@ -32,8 +32,7 @@ example_prompts = [first_prompt, "In one word, the capital of France is "] + [ default_params = dict( temperature=0.0, # greedy max_tokens=30, - # spec decoding currently doesn't support min_tokens - # min_tokens=28, + min_tokens=28, ) diff --git a/tests/v1/logits_processors/test_custom_offline.py b/tests/v1/logits_processors/test_custom_offline.py index 59317e918..29ec72186 100644 --- a/tests/v1/logits_processors/test_custom_offline.py +++ b/tests/v1/logits_processors/test_custom_offline.py @@ -276,9 +276,12 @@ def test_rejects_custom_logitsprocs( monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "fork") llm = LLM(**llm_kwargs) - # Require that no logitsprocs have been loaded + # Require that no custom logitsprocs have been loaded + # (built-in processors may exist: MinTokensLogitsProcessor, + # LogitBiasLogitsProcessor, MinPLogitsProcessor) worker = llm.llm_engine.model_executor.driver_worker.worker - assert sum([1 for _ in worker.model_runner.input_batch.logitsprocs.all]) == 0 + for proc in worker.model_runner.input_batch.logitsprocs.all: + assert not isinstance(proc, DummyLogitsProcessor) return if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN: diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 4e5885b65..2f015339e 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -678,9 +678,9 @@ class SamplingParams( return # Some sampling parameters are not yet compatible with spec decoding. - if self.min_tokens > 1 or self.min_p > _SAMPLING_EPS or self.logit_bias: + if self.min_p > _SAMPLING_EPS or self.logit_bias: raise ValueError( - "The min_tokens, min_p, and logit_bias sampling parameters " + "The min_p and logit_bias sampling parameters " "are not yet supported with speculative decoding." ) diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py index f7b70645f..693f7b125 100644 --- a/vllm/v1/sample/logits_processor/__init__.py +++ b/vllm/v1/sample/logits_processor/__init__.py @@ -202,10 +202,11 @@ def build_logitsprocs( if custom_logitsprocs: raise ValueError(STR_SPEC_DEC_REJECTS_LOGITSPROCS) logger.warning( - "min_p, logit_bias, and min_tokens parameters won't currently work " - "with speculative decoding enabled." + "min_p and logit_bias parameters won't work with speculative decoding." 
+        )
+        return LogitsProcessors(
+            [MinTokensLogitsProcessor(vllm_config, device, is_pin_memory)]
         )
-        return LogitsProcessors()
 
     custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs)
     return LogitsProcessors(
diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py
index 82743f72b..11a52711d 100644
--- a/vllm/v1/sample/logits_processor/builtin.py
+++ b/vllm/v1/sample/logits_processor/builtin.py
@@ -3,6 +3,7 @@
 from collections.abc import Callable, Sequence
 from typing import TYPE_CHECKING, TypeVar
 
+import numpy as np
 import torch
 
 from vllm import SamplingParams
@@ -236,6 +237,59 @@ class MinTokensLogitsProcessor(LogitsProcessor):
             logits.index_put_(self.logits_slice, self.neg_inf_tensor)
         return logits
 
+    def apply_with_spec_decode(
+        self,
+        logits: torch.Tensor,
+        num_draft_tokens: list[int],
+    ) -> torch.Tensor:
+        """Spec-decode variant of ``apply()``: one logits row per draft token.
+        ``min_tokens`` takes priority over ``stop_token_ids``/EOS, which stay
+        masked until the request reaches ``min_tokens``. Example:
+        ``num_draft_tokens = [2, 3, 1]`` -> ``logits`` shape ``[6, V]``,
+        ``cumsum = [0, 2, 5, 6]`` -> rows 0-1, 2-4, 5 belong to requests 0-2.
+        """
+        if not self.min_toks:
+            return logits
+
+        num_draft_arr = np.array(num_draft_tokens, dtype=np.int64)
+        cumsum = np.concatenate([[0], np.cumsum(num_draft_arr)])
+
+        entries = [
+            (req_idx, min_tok, len(out_tok_ids), list(stop_tok_ids))
+            for req_idx, (min_tok, out_tok_ids, stop_tok_ids) in self.min_toks.items()
+            if stop_tok_ids
+        ]
+
+        if not entries:
+            return logits
+
+        all_rows: list[np.ndarray] = []  # row indices to mask
+        all_toks: list[np.ndarray] = []  # stop-token ids at those rows
+
+        for req_idx, min_tok, current_len, stop_toks in entries:
+            remaining = min_tok - current_len
+            # How many leading draft positions still need stop-token masking.
+            n_mask = int(min(max(remaining, 0), num_draft_arr[req_idx]))
+
+            if n_mask > 0:
+                offset = cumsum[req_idx]
+                row_indices = np.arange(offset, offset + n_mask, dtype=np.int64)
+                n_stop = len(stop_toks)
+                all_rows.append(np.repeat(row_indices, n_stop))
+                all_toks.append(np.tile(stop_toks, n_mask))
+
+        if all_rows:
+            rows_arr = np.concatenate(all_rows)
+            toks_arr = np.concatenate(all_toks)
+            # (row_indices, token_indices) for index_put_ to set -inf.
+ logits_slice = ( + torch.from_numpy(rows_arr).to(self.device, non_blocking=True), + torch.from_numpy(toks_arr).to(self.device, non_blocking=True), + ) + logits.index_put_(logits_slice, self.neg_inf_tensor) + + return logits + def process_dict_updates( req_entries: dict[int, T], diff --git a/vllm/v1/sample/logits_processor/state.py b/vllm/v1/sample/logits_processor/state.py index c15219da5..41cbba8df 100644 --- a/vllm/v1/sample/logits_processor/state.py +++ b/vllm/v1/sample/logits_processor/state.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Iterator +from collections.abc import Iterable, Iterator from itertools import chain from typing import TYPE_CHECKING @@ -148,7 +148,7 @@ class BatchUpdateBuilder: class LogitsProcessors: """Encapsulates initialized logitsproc objects.""" - def __init__(self, logitsprocs: Iterator["LogitsProcessor"] | None = None) -> None: + def __init__(self, logitsprocs: Iterable["LogitsProcessor"] | None = None) -> None: self.argmax_invariant: list[LogitsProcessor] = [] self.non_argmax_invariant: list[LogitsProcessor] = [] if logitsprocs: diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 1efceba38..278d421eb 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -10,6 +10,7 @@ import torch.nn as nn from vllm.logger import init_logger from vllm.triton_utils import tl, triton from vllm.v1.outputs import LogprobsLists, LogprobsTensors, SamplerOutput +from vllm.v1.sample.logits_processor.builtin import MinTokensLogitsProcessor from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts from vllm.v1.sample.ops.penalties import apply_all_penalties @@ -292,6 +293,12 @@ class RejectionSampler(nn.Module): logits, bad_words_token_ids, output_token_ids, metadata.num_draft_tokens ) + for processor in sampling_metadata.logitsprocs.non_argmax_invariant: + if isinstance(processor, MinTokensLogitsProcessor): + logits = processor.apply_with_spec_decode( + logits, metadata.num_draft_tokens + ) + return logits @staticmethod
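
Usage example (editorial note, not part of the diff): a minimal offline sketch
exercising the new path. It assumes a vLLM build with this patch applied; the
model name and the ngram speculative_config values are illustrative
placeholders in the style of the vLLM docs, not anything this change
prescribes.

    # Sketch: min_tokens combined with speculative decoding.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
        speculative_config={
            "method": "ngram",              # draft via prompt n-gram lookup
            "num_speculative_tokens": 3,
            "prompt_lookup_max": 4,
        },
    )

    # Previously min_tokens > 1 raised a ValueError when speculative decoding
    # was enabled; now stop tokens / EOS are masked (-inf) at every draft
    # position until each request has produced at least min_tokens tokens.
    params = SamplingParams(temperature=0.0, max_tokens=30, min_tokens=28)

    for out in llm.generate(["The capital of France is"], params):
        n = len(out.outputs[0].token_ids)
        assert n >= 28, f"expected >= 28 tokens, got {n}"

With greedy sampling this mirrors the updated default_params in
tests/v1/e2e/test_async_scheduling.py (max_tokens=30, min_tokens=28), where
the previously commented-out min_tokens line is re-enabled by this patch.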