diff --git a/docs/features/reasoning_outputs.md b/docs/features/reasoning_outputs.md
index cd66863a1..4a9f279e0 100644
--- a/docs/features/reasoning_outputs.md
+++ b/docs/features/reasoning_outputs.md
@@ -240,6 +240,81 @@ response = client.chat.completions.create(
 )
 ```
 
+## Thinking Budget Control
+
+Some models, such as [Qwen3](https://qwen.readthedocs.io/en/latest/getting_started/quickstart.html#thinking-budget), [DeepSeek](https://www.alibabacloud.com/help/en/model-studio/deep-thinking), and [Nemotron 3 Nano](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16), support a thinking budget that limits the maximum number of tokens used for reasoning.
+
+Token counting starts from `think_start_str`. Once the reasoning token count reaches the configured `thinking_token_budget`, vLLM forces the model to produce `think_end_str`, terminating the reasoning block.
+
+To use this feature:
+
+- `--reasoning-parser` enables reasoning extraction.
+- `--reasoning-config` defines the reasoning boundary strings (`think_start_str`, `think_end_str`).
+- `thinking_token_budget` (a sampling parameter) sets the per-request reasoning token limit.
+
+If `thinking_token_budget` is not specified, no explicit reasoning limit is applied beyond normal generation constraints such as `max_tokens`.
+
+`--reasoning-config` accepts a JSON object corresponding to
+[ReasoningConfig][vllm.config.ReasoningConfig] with the following fields:
+
+| Field             | Type          | Description                                       |
+|-------------------|---------------|---------------------------------------------------|
+| `think_start_str` | `str \| null` | String that marks the start of reasoning content  |
+| `think_end_str`   | `str \| null` | String that marks the end of reasoning content    |
+
+!!! note
+    `think_end_str` can include a transition phrase before the think end token. For example, setting `think_end_str` to `"I have to give the solution based on the thinking directly now.</think>"` instructs the model to emit that phrase before closing the reasoning block when the budget is exhausted, which makes the termination read more naturally.
+
+### Online Serving
+
+```bash
+vllm serve Qwen/Qwen3-0.6B \
+  --reasoning-parser qwen3 \
+  --reasoning-config '{"think_start_str": "<think>", "think_end_str": "I have to give the solution based on the thinking directly now.</think>"}'
+```
+
+Then make a request with `thinking_token_budget` to limit the reasoning tokens:
+
+```bash
+curl http://localhost:8000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "Qwen/Qwen3-0.6B",
+    "messages": [
+      { "role": "user", "content": "9.11 and 9.8, which is greater?" }
+    ],
+    "thinking_token_budget": 10
+  }'
+```
+
+### Offline Inference
+
+```python
+from vllm import LLM, SamplingParams
+from vllm.config import ReasoningConfig
+
+llm = LLM(
+    model="Qwen/Qwen3-0.6B",
+    reasoning_config=ReasoningConfig(
+        think_start_str="<think>",
+        think_end_str="I have to give the solution based on the thinking directly now.</think>",
+    ),
+)
+
+sampling_params = SamplingParams(thinking_token_budget=10)
+
+messages = [
+    {"role": "user", "content": "9.11 and 9.8, which is greater?"},
+]
+
+outputs = llm.chat(messages, sampling_params=sampling_params)
+
+for output in outputs:
+    print("text:", output.outputs[0].text)
+```
+
 ## Limitations
 
 - The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
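A minimal client-side sketch of the behaviour documented above (not part of the patch): it assumes the `vllm serve` command from the Online Serving example is running on `localhost:8000`, and compares the reasoning returned with and without a budget.

```python
# Illustrative only: assumes the server from the Online Serving example above
# is running locally with the qwen3 reasoning parser enabled.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]

capped = client.chat.completions.create(
    model="Qwen/Qwen3-0.6B",
    messages=messages,
    extra_body={"thinking_token_budget": 10},  # merged into the request JSON by the SDK
)
uncapped = client.chat.completions.create(
    model="Qwen/Qwen3-0.6B",
    messages=messages,
)

# With the budget, the reasoning block should be cut off after roughly 10 tokens
# and end with the configured transition phrase.
print("capped reasoning:  ", capped.choices[0].message.reasoning_content)
print("uncapped reasoning:", uncapped.choices[0].message.reasoning_content)
```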
diff --git a/tests/v1/entrypoints/openai/test_thinking_token_budget.py b/tests/v1/entrypoints/openai/test_thinking_token_budget.py
new file mode 100644
index 000000000..f574b07b6
--- /dev/null
+++ b/tests/v1/entrypoints/openai/test_thinking_token_budget.py
@@ -0,0 +1,87 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""E2E tests for thinking_token_budget with reasoning models."""
+
+import openai
+import pytest
+import pytest_asyncio
+
+from tests.utils import RemoteOpenAIServer
+
+MODEL_NAME = "Qwen/Qwen3-0.6B"
+MESSAGES = [{"role": "user", "content": "What is 1+1? Be concise."}]
+THINK_BUDGET = 5
+
+
+@pytest.fixture(scope="module")
+def server():
+    args = [
+        "--reasoning-parser",
+        "qwen3",
+        "--reasoning-config",
+        '{"think_start_str": "<think>", "think_end_str": "</think>"}',
+        "--max-model-len",
+        "2048",
+        "--enforce-eager",
+        "--no-async-scheduling",
+    ]
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        yield remote_server
+
+
+@pytest_asyncio.fixture
+async def client(server):
+    async with server.get_async_client() as async_client:
+        yield async_client
+
+
+@pytest.mark.asyncio
+async def test_thinking_token_budget_mixed_requests(client: openai.AsyncOpenAI):
+    """Test that mixed requests (some with thinking_token_budget, some without)
+    complete successfully without errors."""
+
+    response_with_budget = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=MESSAGES,
+        max_tokens=100,
+        extra_body={"thinking_token_budget": THINK_BUDGET},
+    )
+    response_without_budget = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=MESSAGES,
+        max_tokens=100,
+    )
+
+    msg_with = response_with_budget.choices[0].message
+    msg_without = response_without_budget.choices[0].message
+
+    assert msg_with.content or getattr(msg_with, "reasoning", None)
+    assert msg_without.content or getattr(msg_without, "reasoning", None)
+
+
+@pytest.mark.asyncio
+async def test_thinking_token_budget_limits_reasoning(client: openai.AsyncOpenAI):
+    """Test that thinking_token_budget limits the number of reasoning tokens.
+
+    In streaming mode each reasoning delta corresponds to one token, so
+    counting non-empty reasoning deltas gives the exact token count.
+ """ + + reasoning_token_count = 0 + stream = await client.chat.completions.create( + model=MODEL_NAME, + messages=MESSAGES, + max_tokens=100, + stream=True, + extra_body={"thinking_token_budget": THINK_BUDGET}, + ) + async for chunk in stream: + delta = chunk.choices[0].delta + if getattr(delta, "reasoning", None): + reasoning_token_count += 1 + + assert reasoning_token_count == THINK_BUDGET, ( + f"reasoning tokens ({reasoning_token_count}) != " + f"thinking_token_budget ({THINK_BUDGET})" + ) diff --git a/tests/v1/logits_processors/test_correctness.py b/tests/v1/logits_processors/test_correctness.py index dac7ffed6..792168877 100644 --- a/tests/v1/logits_processors/test_correctness.py +++ b/tests/v1/logits_processors/test_correctness.py @@ -30,6 +30,7 @@ from vllm.v1.sample.logits_processor import ( MinPLogitsProcessor, MinTokensLogitsProcessor, MoveDirectionality, + ThinkingTokenBudgetLogitsProcessor, build_logitsprocs, ) from vllm.v1.sample.metadata import SamplingMetadata @@ -47,6 +48,11 @@ MIN_TOKENS_LEN_THRESHOLD = 5 REQS_PER_LOGITPROC = 50 STR_NO_LOGITPROC = "none" +# ThinkingTokenBudgetLogitsProcessor testing constants +THINKING_TOKEN_BUDGET = 5 +THINK_START_TOKEN_ID = 999 +THINK_END_TOKEN_ID = 998 + # LogitsProcessor subclass or "none" LogitprocType: TypeAlias = type[LogitsProcessor] | str @@ -67,9 +73,24 @@ class LogitsProcsRequestParams: self.workload_index = workload_index self.logitproc_type = logitproc_type # Number of output tokens is randomly 0 or twice the min-tokens - # threshold which will be used in testing. Output token values - # don't matter *for these tests* so use 0 as a dummy value - self.out_tokens = [0] * (MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2)) + # threshold which will be used in testing. + # Generate diverse random tokens for all processors (more realistic) + num_tokens = MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2) + if num_tokens > 0: + # Use diverse random tokens + self.out_tokens = [random.randint(1, 950) for _ in range(num_tokens)] + # Set first token for ThinkingTokenBudget testing + is_thinking_processor = ( + logitproc_type is ThinkingTokenBudgetLogitsProcessor + or ( + hasattr(logitproc_type, "__name__") + and logitproc_type.__name__ == "ThinkingTokenBudgetLogitsProcessor" + ) + ) + if is_thinking_processor: + self.out_tokens[0] = THINK_START_TOKEN_ID + else: + self.out_tokens = [] self.prompt_tokens = [] self.params = _sampling_params_from_logitproc(logitproc_type) @@ -79,6 +100,13 @@ class LogitsProcsRequestParams: return f"MyClass({summ})" +class MockReasoningConfig: + """Mock reasoning config for testing ThinkingTokenBudgetLogitsProcessor.""" + + think_start_token_ids = [THINK_START_TOKEN_ID] + think_end_token_ids = [THINK_END_TOKEN_ID] + + def _generate_fake_sampling_metadata( num_output_tokens: int, batch_size: int, @@ -97,8 +125,12 @@ def _generate_fake_sampling_metadata( 0, vocab_size, size=np.random.randint(1, MAX_NUM_PROMPT_TOKENS) ).tolist() ) + + vllm_config = VllmConfig() + vllm_config.reasoning_config = MockReasoningConfig() + logitsprocs = build_logitsprocs( - vllm_config=VllmConfig(), + vllm_config=vllm_config, device=device, is_pin_memory=PIN_MEMORY_AVAILABLE, is_pooling_model=False, @@ -403,6 +435,127 @@ def _min_tokens_validate( ) +def _thinking_budget_params(kwargs: dict) -> None: + """Set SamplingParams kwargs for thinking token budget tests""" + kwargs["thinking_token_budget"] = THINKING_TOKEN_BUDGET + + +def _thinking_budget_validate( + test_fakes: LogitsprocsTestFakes, + persistent_batch: 
list[LogitsProcsRequestParams], + logits_new: torch.Tensor, + batch_index: int, + request_params: LogitsProcsRequestParams, + step_idx: int, +) -> None: + """Validate thinking token budget processor behavior""" + # Get the ThinkingTokenBudgetLogitsProcessor instance + tb_processor: ThinkingTokenBudgetLogitsProcessor = next( + test_fakes.get_logitsprocs_by_cls(ThinkingTokenBudgetLogitsProcessor) + ) + + # Get current request state + state = tb_processor._state.get(batch_index) + params = request_params.params + + # Validate thinking token budget configuration + if hasattr(params, "thinking_token_budget") and params.thinking_token_budget: + # State should exist for requests with thinking_token_budget + if state is None: + _raise_error_invalid( + msg_suffix=( + f"Expected state for batch {batch_index} " + f"with thinking_token_budget={params.thinking_token_budget}" + ), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx, + ) + + # Validate budget matches what was set + expected_budget = params.thinking_token_budget + actual_budget = state["thinking_token_budget"] + + if actual_budget != expected_budget: + _raise_error_invalid( + msg_suffix=( + f"Budget mismatch: expected {expected_budget}, got {actual_budget}" + ), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx, + ) + + # Check if we're in thinking mode and validate token counting + output_tokens = request_params.out_tokens + + # Find if thinking has started in output tokens + thinking_started = False + start_tokens = tb_processor.think_start_token_ids + + if len(start_tokens) > 0: + for i in range(len(output_tokens) - len(start_tokens) + 1): + if output_tokens[i : i + len(start_tokens)] == start_tokens: + thinking_started = True + break + + if thinking_started: + # If budget is exceeded, validate end token forcing + think_count = state["think_count"] + budget = state["thinking_token_budget"] + + if think_count >= budget: + if not state["in_end"]: + _raise_error_invalid( + msg_suffix=( + f"Budget exceeded ({think_count} >= " + f"{budget}) but not " + "forcing end tokens" + ), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx, + ) + + # Validate that only end tokens are allowed + end_tokens = tb_processor.think_end_token_ids + if len(end_tokens) > 0: + expected_end_token_id = end_tokens[ + min(state["end_count"], len(end_tokens) - 1) + ] + + # Check logits masking + batch_logits = logits_new[batch_index] + for token_id in range(len(batch_logits)): + logit_value = batch_logits[token_id] + + if token_id == expected_end_token_id: + # End token should not be masked + if logit_value == -float("inf"): + _raise_error_invalid( + msg_suffix=( + f"End token {token_id} should not be " + "masked but is" + ), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx, + ) + else: + # All other tokens should be masked when forcing end + if logit_value != -float("inf"): + _raise_error_invalid( + msg_suffix=( + f"Token {token_id} should be masked " + f"when forcing end tokens, but " + f"logit={logit_value}" + ), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx, + ) + + def _none_validate( test_fakes: LogitsprocsTestFakes, persistent_batch: list[LogitsProcsRequestParams], @@ -449,20 +602,30 @@ logitsprocs_test_mapping = { MinTokensLogitsProcessor: LogitsprocTestHelpers( gen_request_fxn=_min_tokens_params, eval_fxn=_min_tokens_validate ), + ThinkingTokenBudgetLogitsProcessor: LogitsprocTestHelpers( + 
        gen_request_fxn=_thinking_budget_params, eval_fxn=_thinking_budget_validate
+    ),
 }
 
 
 def _get_test_cases() -> list[list[str]]:
     """Each test case is a set of logitsprocs"""
     logitsprocs_types = list(logitsprocs_test_mapping.keys())
+
+    # Isolate ThinkingTokenBudgetLogitsProcessor from all other processors
+    # so that its forced logits modifications do not interfere with them
+    thinking_processor = ThinkingTokenBudgetLogitsProcessor
+    other_processors = [
+        p
+        for p in logitsprocs_types
+        if p != STR_NO_LOGITPROC and p != thinking_processor
+    ]
+
     return (
         [[STR_NO_LOGITPROC]]
-        + [
-            [logitproc_type, STR_NO_LOGITPROC]
-            for logitproc_type in logitsprocs_types
-            if logitproc_type != STR_NO_LOGITPROC
-        ]
-        + [logitsprocs_types]
+        + [[logitproc_type, STR_NO_LOGITPROC] for logitproc_type in other_processors]
+        + [other_processors]
+        + [[thinking_processor]]
     )
diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py
index 452fb0466..d5a3e9bfd 100644
--- a/vllm/config/__init__.py
+++ b/vllm/config/__init__.py
@@ -33,6 +33,7 @@ from vllm.config.offload import (
 from vllm.config.parallel import EPLBConfig, ParallelConfig
 from vllm.config.pooler import PoolerConfig
 from vllm.config.profiler import ProfilerConfig
+from vllm.config.reasoning import ReasoningConfig
 from vllm.config.scheduler import SchedulerConfig
 from vllm.config.speculative import SpeculativeConfig
 from vllm.config.speech_to_text import SpeechToTextConfig
@@ -101,6 +102,8 @@ __all__ = [
     "ParallelConfig",
     # From vllm.config.pooler
     "PoolerConfig",
+    # From vllm.config.reasoning
+    "ReasoningConfig",
     # From vllm.config.scheduler
     "SchedulerConfig",
     # From vllm.config.speculative
diff --git a/vllm/config/reasoning.py b/vllm/config/reasoning.py
new file mode 100644
index 000000000..872e05580
--- /dev/null
+++ b/vllm/config/reasoning.py
@@ -0,0 +1,72 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from dataclasses import field
+
+from vllm.config.model import ModelConfig
+from vllm.config.utils import config
+from vllm.tokenizers import cached_tokenizer_from_config
+
+
+@config
+class ReasoningConfig:
+    """Configuration for reasoning models.
+
+    Set `think_start_str` and `think_end_str` to the strings that delimit
+    the reasoning block (e.g. `"<think>"` and `"</think>"`). The
+    corresponding token IDs are derived automatically via
+    `initialize_token_ids` and are not intended to be set directly.
+    """
+
+    # NOTE: These parameters are temporary, the intent is to derive them
+    # automatically from the reasoning parser in a future version.
+    think_start_str: str = ""
+    """String that indicates the start of reasoning content."""
+    think_end_str: str = ""
+    """String that indicates the end of reasoning content."""
+
+    _think_start_token_ids: list[int] | None = field(
+        default=None, init=False, repr=False
+    )
+    """Private backing field for `think_start_token_ids`. Set by
+    `initialize_token_ids`. Not intended to be configured directly."""
+    _think_end_token_ids: list[int] | None = field(default=None, init=False, repr=False)
+    """Private backing field for `think_end_token_ids`. Set by
+    `initialize_token_ids`. Not intended to be configured directly."""
+
+    @property
+    def think_start_token_ids(self) -> list[int] | None:
+        """Token IDs derived from `think_start_str`. Set automatically by
+        `initialize_token_ids`. Not intended to be configured directly."""
+        return self._think_start_token_ids
+
+    @property
+    def think_end_token_ids(self) -> list[int] | None:
+        """Token IDs derived from `think_end_str`.
Set automatically by + `initialize_token_ids`. Not intended to be configured directly.""" + return self._think_end_token_ids + + def initialize_token_ids(self, model_config: ModelConfig) -> None: + """Initialize reasoning token IDs from strings using the tokenizer.""" + if ( + self._think_start_token_ids is not None + and self._think_end_token_ids is not None + ): + return + + tokenizer = cached_tokenizer_from_config(model_config=model_config) + + self._think_start_token_ids = tokenizer.encode( + self.think_start_str, add_special_tokens=False + ) + self._think_end_token_ids = tokenizer.encode( + self.think_end_str, add_special_tokens=False + ) + + if not self._think_start_token_ids or not self._think_end_token_ids: + raise ValueError( + f"ReasoningConfig: failed to tokenize reasoning strings: " + f"think_start_str='{self.think_start_str}', " + f"think_end_str='{self.think_end_str}'. " + "Ensure the strings are valid tokens in the model's vocabulary." + ) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index a8088de63..c7449840d 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -40,6 +40,7 @@ from .observability import ObservabilityConfig from .offload import OffloadConfig from .parallel import ParallelConfig from .profiler import ProfilerConfig +from .reasoning import ReasoningConfig from .scheduler import SchedulerConfig from .speculative import EagleModelTypes, NgramGPUTypes, SpeculativeConfig from .structured_outputs import StructuredOutputsConfig @@ -302,6 +303,8 @@ class VllmConfig: # type: ignore[misc] """The configurations for event publishing.""" ec_transfer_config: ECTransferConfig | None = None """The configurations for distributed EC cache transfer.""" + reasoning_config: ReasoningConfig | None = None + """The configurations for reasoning model.""" # some opaque config, only used to provide additional information # for the hash computation, mainly used for testing, debugging or out of # tree config registration. 
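For intuition about what `ReasoningConfig.initialize_token_ids` (added above) produces, here is a standalone sketch of the string-to-token-ID derivation. It uses the Hugging Face tokenizer directly rather than vLLM's cached tokenizer helper, and the printed IDs are model-dependent.

```python
# Standalone illustration of deriving think-delimiter token IDs from strings,
# mirroring what ReasoningConfig.initialize_token_ids does via the model's tokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

think_start_str = "<think>"
think_end_str = "I have to give the solution based on the thinking directly now.</think>"

think_start_token_ids = tokenizer.encode(think_start_str, add_special_tokens=False)
think_end_token_ids = tokenizer.encode(think_end_str, add_special_tokens=False)

# For Qwen3-style tokenizers, "<think>" usually maps to a single added token,
# while the end string expands to the transition phrase plus the "</think>" token.
print(think_start_token_ids)
print(think_end_token_ids)
```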
@@ -1143,6 +1146,9 @@ class VllmConfig: # type: ignore[misc] if not self.instance_id: self.instance_id = random_uuid()[:5] + if self.reasoning_config is not None and self.model_config is not None: + self.reasoning_config.initialize_token_ids(self.model_config) + # Hybrid KV cache manager (HMA) runtime rules: # - Explicit enable (--no-disable-kv-cache-manager): error if runtime # disables it diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index e0d5236bc..b7276a345 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -53,6 +53,7 @@ from vllm.config import ( PoolerConfig, PrefetchOffloadConfig, ProfilerConfig, + ReasoningConfig, SchedulerConfig, SpeculativeConfig, StructuredOutputsConfig, @@ -581,6 +582,7 @@ class EngineArgs: kv_events_config: KVEventsConfig | None = None ec_transfer_config: ECTransferConfig | None = None + reasoning_config: ReasoningConfig = get_field(VllmConfig, "reasoning_config") generation_config: str = ModelConfig.generation_config enable_sleep_mode: bool = ModelConfig.enable_sleep_mode @@ -1297,6 +1299,7 @@ class EngineArgs: vllm_group.add_argument( "--attention-config", "-ac", **vllm_kwargs["attention_config"] ) + vllm_group.add_argument("--reasoning-config", **vllm_kwargs["reasoning_config"]) vllm_group.add_argument("--kernel-config", **vllm_kwargs["kernel_config"]) vllm_group.add_argument( "--additional-config", **vllm_kwargs["additional_config"] @@ -1958,6 +1961,7 @@ class EngineArgs: kv_transfer_config=self.kv_transfer_config, kv_events_config=self.kv_events_config, ec_transfer_config=self.ec_transfer_config, + reasoning_config=self.reasoning_config, profiler_config=self.profiler_config, additional_config=self.additional_config, optimization_level=self.optimization_level, diff --git a/vllm/entrypoints/openai/chat_completion/protocol.py b/vllm/entrypoints/openai/chat_completion/protocol.py index 61763a3b6..09ef448f4 100644 --- a/vllm/entrypoints/openai/chat_completion/protocol.py +++ b/vllm/entrypoints/openai/chat_completion/protocol.py @@ -180,6 +180,7 @@ class ChatCompletionRequest(OpenAIBaseModel): | None ) = "none" reasoning_effort: Literal["none", "low", "medium", "high"] | None = None + thinking_token_budget: int | None = None include_reasoning: bool = True parallel_tool_calls: bool | None = True @@ -515,6 +516,7 @@ class ChatCompletionRequest(OpenAIBaseModel): structured_outputs=self.structured_outputs, logit_bias=self.logit_bias, bad_words=self.bad_words, + thinking_token_budget=self.thinking_token_budget, allowed_token_ids=self.allowed_token_ids, extra_args=extra_args or None, skip_clone=True, # Created fresh per request, safe to skip clone diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index f7a2e8b3f..dc3e1d49c 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -281,6 +281,8 @@ class SamplingParams( _bad_words_token_ids: list[list[int]] | None = None skip_reading_prefix_cache: bool | None = None + thinking_token_budget: int | None = None + """Maximum number of tokens allowed for thinking operations.""" repetition_detection: RepetitionDetectionParams | None = None """Parameters for detecting repetitive N-gram patterns in output tokens. 
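Taken together, the wiring above makes the budget a plain optional per-request field that travels from the HTTP request into `SamplingParams`. A short sketch of both ends (illustrative only, assuming this patch is applied):

```python
# Offline / programmatic use: set the field directly on SamplingParams.
from vllm import SamplingParams

params = SamplingParams(max_tokens=100, thinking_token_budget=10)
assert params.thinking_token_budget == 10

# Online use: the same value is sent as a top-level field of the chat
# completions request body; ChatCompletionRequest.to_sampling_params()
# copies it onto SamplingParams.thinking_token_budget.
request_body = {
    "model": "Qwen/Qwen3-0.6B",
    "messages": [{"role": "user", "content": "9.11 and 9.8, which is greater?"}],
    "thinking_token_budget": 10,
}
```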
@@ -304,6 +306,7 @@ class SamplingParams( stop: str | list[str] | None = None, stop_token_ids: list[int] | None = None, bad_words: list[str] | None = None, + thinking_token_budget: int | None = None, include_stop_str_in_output: bool = False, ignore_eos: bool = False, max_tokens: int | None = 16, @@ -344,6 +347,7 @@ class SamplingParams( stop=stop, stop_token_ids=stop_token_ids, bad_words=bad_words, + thinking_token_budget=thinking_token_budget, include_stop_str_in_output=include_stop_str_in_output, ignore_eos=ignore_eos, max_tokens=max_tokens, @@ -858,6 +862,7 @@ class SamplingParams( f"stop={self.stop}, " f"stop_token_ids={self.stop_token_ids}, " f"bad_words={self.bad_words}, " + f"thinking_token_budget={self.thinking_token_budget}, " f"include_stop_str_in_output={self.include_stop_str_in_output}, " f"ignore_eos={self.ignore_eos}, " f"max_tokens={self.max_tokens}, " diff --git a/vllm/v1/engine/input_processor.py b/vllm/v1/engine/input_processor.py index aab560544..b77b9277a 100644 --- a/vllm/v1/engine/input_processor.py +++ b/vllm/v1/engine/input_processor.py @@ -99,6 +99,16 @@ class InputProcessor: self.structured_outputs_config, self.tokenizer, ) + + if ( + params.thinking_token_budget is not None + and self.vllm_config.reasoning_config is None + ): + raise ValueError( + "thinking_token_budget is set but reasoning_config is " + "not configured. Please set --reasoning-config to use " + "thinking_token_budget." + ) elif isinstance(params, PoolingParams): supported_pooling_tasks = [ task for task in supported_tasks if task in POOLING_TASKS diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py index 2cb89e1ea..fb4a046fc 100644 --- a/vllm/v1/sample/logits_processor/__init__.py +++ b/vllm/v1/sample/logits_processor/__init__.py @@ -18,6 +18,7 @@ from vllm.v1.sample.logits_processor.builtin import ( LogitBiasLogitsProcessor, MinPLogitsProcessor, MinTokensLogitsProcessor, + ThinkingTokenBudgetLogitsProcessor, process_dict_updates, ) from vllm.v1.sample.logits_processor.interface import ( @@ -50,6 +51,7 @@ BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [ MinTokensLogitsProcessor, LogitBiasLogitsProcessor, MinPLogitsProcessor, + ThinkingTokenBudgetLogitsProcessor, ] @@ -354,4 +356,5 @@ __all__ = [ "STR_POOLING_REJECTS_LOGITSPROCS", "LOGITSPROCS_GROUP", "AdapterLogitsProcessor", + "ThinkingTokenBudgetLogitsProcessor", ] diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index 11a52711d..c92f33402 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Callable, Sequence -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar import numpy as np import torch @@ -291,6 +291,263 @@ class MinTokensLogitsProcessor(LogitsProcessor): return logits +class ThinkingTokenBudgetLogitsProcessor(LogitsProcessor): + """Limits the number of tokens allowed inside a 'thinking' section.""" + + def __init__( + self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool + ): + reasoning_config = vllm_config.reasoning_config + max_num_reqs = vllm_config.scheduler_config.max_num_seqs + + # Check if thinking is enabled + self.is_enabled = reasoning_config is not None + + self.think_start_token_ids = getattr( + reasoning_config, "think_start_token_ids", [] + ) + 
self.think_end_token_ids = getattr(reasoning_config, "think_end_token_ids", []) + + self.pin_memory = is_pin_memory + self.device = device + # Per-request state tracking for thinking token management + # Key: request_index, Value: state dict containing: + # "in_think": bool - currently in thinking mode + # "in_end": bool - currently forcing end tokens output + # "check_count_down": int - steps remaining until next think + # start/end token parsing + # "think_count": int - number of thinking tokens generated + # "end_count": int - number of end tokens forced so far + # "thinking_token_budget": int - max allowed thinking tokens + # "output_tok_ids": list[int] - generated output tokens + # "prev_output_length": int - previous output length for + # incremental processing + self._state: dict[int, dict[str, Any]] = {} + + # Preallocate reusable tensors + self.mask = torch.zeros(max_num_reqs, dtype=torch.bool, device=device) + self.force_token_ids = torch.full( + (max_num_reqs,), -1, dtype=torch.long, device=device + ) + + @staticmethod + def _find_last_sequence_index(target_list: list[int], token_ids: list[int]) -> int: + """ + Returns the index of the last occurrence of token_ids in target_list. + + Args: + target_list (list[int]): The list of token IDs. + token_ids (list[int]): The sequence of token IDs to find. + """ + if not token_ids: + return -1 + for i in range(len(target_list) - len(token_ids), -1, -1): + if target_list[i : i + len(token_ids)] == token_ids: + return i + return -1 + + def _init_state_entry( + self, prompt_tok_ids: list[int] | None, thinking_token_budget: int + ) -> dict[str, Any]: + """Initializes the tracking state for a given sequence index.""" + if prompt_tok_ids is None: + last_start = -1 + last_end = -1 + in_think = False + think_count = 0 + else: + last_start = self._find_last_sequence_index( + prompt_tok_ids, self.think_start_token_ids + ) + last_end = self._find_last_sequence_index( + prompt_tok_ids, self.think_end_token_ids + ) + in_think = last_start > last_end + if in_think: + think_count = len(prompt_tok_ids) - ( + last_start + len(self.think_start_token_ids) + ) + else: + think_count = 0 + + return { + "in_think": in_think, # Currently in thinking mode + "in_end": in_think and thinking_token_budget == 0, + "check_count_down": thinking_token_budget, + "think_count": think_count, # Number of tokens in thinking section + "end_count": 0, # Number of end tokens forced so far + "prompt_tok_ids": prompt_tok_ids, + "output_tok_ids": [], + "thinking_token_budget": thinking_token_budget, + "prev_output_length": 0, + # Track previous output length for incremental updates + } + + def _update_think_state(self, state: dict[str, Any]): + """Updates the state based on newly generated output tokens.""" + if not state.get("in_end", False) and state.get("check_count_down", 0) > 0: + state["check_count_down"] -= 1 + return + + output = state.get("output_tok_ids", []) + if not output: + return + + # Track previous output length for incremental processing + prev_length = state.get("prev_output_length", 0) + current_length = len(output) + + if current_length <= prev_length: + return + + # Process only newly added tokens + new_tokens = output[prev_length:] + state["prev_output_length"] = current_length + + # Check if new tokens contain think start or end sequences + start_len = len(self.think_start_token_ids) + end_len = len(self.think_end_token_ids) + + # Look for think sequences in recent tokens (including boundary) + # Check overlapping regions where sequences might span 
boundaries + check_start_idx = max(0, prev_length - max(start_len, end_len) + 1) + recent_tokens = output[check_start_idx:] + + # Find any think start/end sequences in recent tokens + recent_start_pos = self._find_last_sequence_index( + recent_tokens, self.think_start_token_ids + ) + recent_end_pos = self._find_last_sequence_index( + recent_tokens, self.think_end_token_ids + ) + + # Update state based on recent sequences + if not state["in_end"]: + if recent_start_pos >= 0 and recent_end_pos >= 0: + if recent_start_pos > recent_end_pos: + # Case: ......... - entering think mode + absolute_start_pos = check_start_idx + recent_start_pos + new_think_count = current_length - (absolute_start_pos + start_len) + state["in_think"] = True + state["think_count"] = new_think_count + else: + # Case: ......... - exiting think mode + state["in_think"] = False + state["think_count"] = 0 + elif recent_start_pos >= 0: + # Found think start - entering think mode + absolute_start_pos = check_start_idx + recent_start_pos + new_think_count = current_length - (absolute_start_pos + start_len) + state["in_think"] = True + state["think_count"] = new_think_count + elif recent_end_pos >= 0: + # Found think end - exiting think mode + state["in_think"] = False + state["think_count"] = 0 + elif state["in_think"]: + # Continue thinking mode, increment count by new tokens + state["think_count"] += len(new_tokens) + + # Set countdown based on current state + if state["in_think"]: + remaining_budget = max( + 0, state["thinking_token_budget"] - state["think_count"] + ) + state["check_count_down"] = max(0, remaining_budget - 1) + else: + state["check_count_down"] = state["thinking_token_budget"] + + # Check if need to transition to end mode + if ( + state["in_think"] + and state["think_count"] >= state["thinking_token_budget"] + ): + state["in_think"] = False + state["in_end"] = True + state["end_count"] = 0 + state["check_count_down"] = state["thinking_token_budget"] + else: + # In end mode + state["end_count"] += 1 + if state["end_count"] >= len(self.think_end_token_ids): + state.update( + { + "in_end": False, + "end_count": 0, + "check_count_down": state["thinking_token_budget"], + } + ) + + def is_argmax_invariant(self) -> bool: + """This logits processor can change the outcome of + greedy sampling by forcing that the thinking section + ends after a certain number of tokens.""" + return False + + def update_state(self, batch_update: BatchUpdate | None): + if not self.is_enabled: + return + if batch_update: + for index, params, prompt_tok_ids, output_tok_ids in batch_update.added: + thinking_token_budget = params.thinking_token_budget + + if thinking_token_budget is not None: + self._state[index] = self._init_state_entry( + prompt_tok_ids, thinking_token_budget + ) + self._state[index]["output_tok_ids"] = output_tok_ids + else: + # Remove state if no thinking budget + self._state.pop(index, None) + + for index in batch_update.removed: + self._state.pop(index, {}) + + for i1, i2, direction in batch_update.moved: + if direction == MoveDirectionality.SWAP: + state1 = self._state.pop(i1, None) + state2 = self._state.pop(i2, None) + if state1 is not None: + self._state[i2] = state1 + if state2 is not None: + self._state[i1] = state2 + else: + state = self._state.pop(i1, None) + if state is not None: + self._state[i2] = state + + for state in self._state.values(): + self._update_think_state(state) + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + if not self.is_enabled or not self._state: + return logits + + 
        batch_size = logits.size(0)
+        self.mask[:batch_size] = False
+
+        for i in range(batch_size):
+            state = self._state.get(i)
+            if state and state["in_end"]:
+                self.mask[i] = True
+                self.force_token_ids[i] = self.think_end_token_ids[state["end_count"]]
+
+        # Check on the CPU first to avoid a GPU sync
+        has_active_thinking = any(
+            state.get("in_end", False) for state in self._state.values()
+        )
+
+        if has_active_thinking:
+            current_mask = self.mask[:batch_size]
+            active_indices = current_mask.nonzero(as_tuple=False).view(-1)
+            if len(active_indices) > 0:
+                force_tokens = self.force_token_ids[active_indices]
+                # Mask every token except the forced end-of-thinking token so
+                # that decoding can only emit the think-end sequence
+                logits[active_indices] = float("-inf")
+                logits[active_indices, force_tokens] = 0.0
+
+        return logits
+
+
 def process_dict_updates(
     req_entries: dict[int, T],
     batch_update: BatchUpdate | None,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 7c6248b37..6465ca654 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -629,7 +629,10 @@ class GPUModelRunner(
             ),
             # We currently don't know whether a particular custom logits processor
             # uses output token ids so we set this conservatively.
-            logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
+            # ThinkingTokenBudgetLogitsProcessor also needs output token ids to
+            # correctly track think start/end token sequences in async scheduling.
+            logitsprocs_need_output_token_ids=bool(custom_logitsprocs)
+            or self.vllm_config.reasoning_config is not None,
             is_pooling_model=self.is_pooling_model,
             cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size,
         )
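Finally, a self-contained toy example (not part of the patch, with made-up token IDs) of the logits-forcing step that `ThinkingTokenBudgetLogitsProcessor.apply` performs for requests in the `in_end` state: every logit except the next think-end token is masked, so decoding must emit the think-end sequence.

```python
# Toy illustration of forcing the think-end sequence once the budget is exhausted.
import torch

vocab_size = 8
think_end_token_ids = [6, 7]          # hypothetical "</think>" token IDs
logits = torch.randn(3, vocab_size)   # batch of 3 requests

# request row -> state, mirroring the processor's per-request bookkeeping
states = {
    0: {"in_end": True, "end_count": 0},  # budget just exhausted: force token 6
    2: {"in_end": True, "end_count": 1},  # token 6 already emitted: force token 7
}

for row, state in states.items():
    forced = think_end_token_ids[state["end_count"]]
    logits[row] = float("-inf")   # mask everything ...
    logits[row, forced] = 0.0     # ... except the forced end-of-thinking token

print(logits.argmax(dim=-1))  # rows 0 and 2 now decode to tokens 6 and 7
```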