diff --git a/docs/features/reasoning_outputs.md b/docs/features/reasoning_outputs.md
index cd66863a1..4a9f279e0 100644
--- a/docs/features/reasoning_outputs.md
+++ b/docs/features/reasoning_outputs.md
@@ -240,6 +240,81 @@ response = client.chat.completions.create(
)
```
+## Thinking Budget Control
+
+Some models, such as [Qwen3](https://qwen.readthedocs.io/en/latest/getting_started/quickstart.html#thinking-budget), [DeepSeek](https://www.alibabacloud.com/help/en/model-studio/deep-thinking), and [Nemotron 3 Nano](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16), support a thinking budget that limits the maximum number of tokens used for reasoning.
+
+Token counting starts once `think_start_str` has been generated (or immediately, if the prompt already ends inside an open reasoning block). When the reasoning token count reaches the configured `thinking_token_budget`, vLLM forces the model to emit `think_end_str`, terminating the reasoning block.
+
+To use this feature:
+
+- `--reasoning-parser` enables reasoning extraction.
+- `--reasoning-config` defines the reasoning boundary tokens (e.g., `think_start_str`, `think_end_str`).
+- `thinking_token_budget` (a sampling parameter) sets the per-request reasoning token limit.
+
+If `thinking_token_budget` is not specified, no explicit reasoning limit is applied beyond normal generation constraints such as `max_tokens`.
+
+`--reasoning-config` accepts a JSON object corresponding to
+[ReasoningConfig][vllm.config.ReasoningConfig] with the following fields:
+
+| Field | Type | Description |
+|-------------------|----------------|--------------------------------------------------|
+| `think_start_str` | `str \| null` | String that marks the start of reasoning content |
+| `think_end_str` | `str \| null` | String that marks the end of reasoning content |
+
+!!! note
+    `think_end_str` can include a transition phrase before the think end token. For example, setting `think_end_str` to `"I have to give the solution based on the thinking directly now.</think>"` makes vLLM force that phrase, followed by the think end token, once the budget is exhausted, so the reasoning section terminates more naturally.
+
+### Online Serving
+
+```bash
+vllm serve Qwen/Qwen3-0.6B \
+ --reasoning-parser qwen3 \
+    --reasoning-config '{"think_start_str": "<think>", "think_end_str": "I have to give the solution based on the thinking directly now.</think>"}'
+```
+
+Then make a request with `thinking_token_budget` to limit the reasoning tokens:
+
+```bash
+curl http://localhost:8000/v1/chat/completions \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "Qwen/Qwen3-0.6B",
+ "messages": [
+ { "role": "user", "content": "9.11 and 9.8, which is greater?" }
+ ],
+    "thinking_token_budget": 10
+ }'
+```
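+
+The same request can be made with the OpenAI Python client. Because `thinking_token_budget` is a vLLM-specific extension, it is passed through `extra_body`. This is a minimal sketch; adjust the base URL and API key to your deployment:
+
+```python
+from openai import OpenAI
+
+# Assumes the vLLM server started above is listening on localhost:8000.
+client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
+
+response = client.chat.completions.create(
+    model="Qwen/Qwen3-0.6B",
+    messages=[
+        {"role": "user", "content": "9.11 and 9.8, which is greater?"},
+    ],
+    # vLLM-specific sampling parameter, so it goes through extra_body.
+    extra_body={"thinking_token_budget": 10},
+)
+
+print("reasoning:", response.choices[0].message.reasoning_content)
+print("content:", response.choices[0].message.content)
+```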
+
+### Offline Inference
+
+```python
+from vllm import LLM, SamplingParams
+from vllm.config import ReasoningConfig
+
+llm = LLM(
+ model="Qwen/Qwen3-0.6B",
+ reasoning_config=ReasoningConfig(
+        think_start_str="<think>",
+        think_end_str="I have to give the solution based on the thinking directly now.</think>",
+ ),
+)
+
+sampling_params = SamplingParams(thinking_token_budget=10)
+
+messages = [
+ {"role": "user", "content": "9.11 and 9.8, which is greater?"},
+]
+
+outputs = llm.chat(messages, sampling_params=sampling_params)
+
+for output in outputs:
+ print("text:", output.outputs[0].text)
+```
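+
+Offline outputs contain the reasoning and the final answer in a single string, so one way to see where the budget cut in (a sketch based on the configuration above) is to split the raw text on the configured `think_end_str`:
+
+```python
+end_marker = "I have to give the solution based on the thinking directly now.</think>"
+
+for output in outputs:
+    text = output.outputs[0].text
+    if end_marker in text:
+        # The forced transition phrase marks where the budget was exhausted.
+        reasoning, answer = text.split(end_marker, 1)
+        print("truncated reasoning:", reasoning.strip())
+        print("answer:", answer.strip())
+```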
+
## Limitations
- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
diff --git a/tests/v1/entrypoints/openai/test_thinking_token_budget.py b/tests/v1/entrypoints/openai/test_thinking_token_budget.py
new file mode 100644
index 000000000..f574b07b6
--- /dev/null
+++ b/tests/v1/entrypoints/openai/test_thinking_token_budget.py
@@ -0,0 +1,87 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""E2E tests for thinking_token_budget with reasoning models."""
+
+import openai
+import pytest
+import pytest_asyncio
+
+from tests.utils import RemoteOpenAIServer
+
+MODEL_NAME = "Qwen/Qwen3-0.6B"
+MESSAGES = [{"role": "user", "content": "What is 1+1? Be concise."}]
+THINK_BUDGET = 5
+
+
+@pytest.fixture(scope="module")
+def server():
+ args = [
+ "--reasoning-parser",
+ "qwen3",
+ "--reasoning-config",
+        '{"think_start_str": "<think>", "think_end_str": "</think>"}',
+ "--max-model-len",
+ "2048",
+ "--enforce-eager",
+ "--no-async-scheduling",
+ ]
+ with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+ yield remote_server
+
+
+@pytest_asyncio.fixture
+async def client(server):
+ async with server.get_async_client() as async_client:
+ yield async_client
+
+
+@pytest.mark.asyncio
+async def test_thinking_token_budget_mixed_requests(client: openai.AsyncOpenAI):
+ """Test that mixed requests (some with thinking_token_budget, some without)
+ complete successfully without errors."""
+
+ response_with_budget = await client.chat.completions.create(
+ model=MODEL_NAME,
+ messages=MESSAGES,
+ max_tokens=100,
+ extra_body={"thinking_token_budget": THINK_BUDGET},
+ )
+ response_without_budget = await client.chat.completions.create(
+ model=MODEL_NAME,
+ messages=MESSAGES,
+ max_tokens=100,
+ )
+
+ msg_with = response_with_budget.choices[0].message
+ msg_without = response_without_budget.choices[0].message
+
+ assert msg_with.content or getattr(msg_with, "reasoning", None)
+ assert msg_without.content or getattr(msg_without, "reasoning", None)
+
+
+@pytest.mark.asyncio
+async def test_thinking_token_budget_limits_reasoning(client: openai.AsyncOpenAI):
+ """Test that thinking_token_budget limits the number of reasoning tokens.
+
+ In streaming mode each reasoning delta corresponds to one token, so
+    counting non-empty reasoning deltas gives the exact token count.
+ """
+
+ reasoning_token_count = 0
+ stream = await client.chat.completions.create(
+ model=MODEL_NAME,
+ messages=MESSAGES,
+ max_tokens=100,
+ stream=True,
+ extra_body={"thinking_token_budget": THINK_BUDGET},
+ )
+ async for chunk in stream:
+ delta = chunk.choices[0].delta
+ if getattr(delta, "reasoning", None):
+ reasoning_token_count += 1
+
+ assert reasoning_token_count == THINK_BUDGET, (
+ f"reasoning tokens ({reasoning_token_count}) != "
+ f"thinking_token_budget ({THINK_BUDGET})"
+ )
diff --git a/tests/v1/logits_processors/test_correctness.py b/tests/v1/logits_processors/test_correctness.py
index dac7ffed6..792168877 100644
--- a/tests/v1/logits_processors/test_correctness.py
+++ b/tests/v1/logits_processors/test_correctness.py
@@ -30,6 +30,7 @@ from vllm.v1.sample.logits_processor import (
MinPLogitsProcessor,
MinTokensLogitsProcessor,
MoveDirectionality,
+ ThinkingTokenBudgetLogitsProcessor,
build_logitsprocs,
)
from vllm.v1.sample.metadata import SamplingMetadata
@@ -47,6 +48,11 @@ MIN_TOKENS_LEN_THRESHOLD = 5
REQS_PER_LOGITPROC = 50
STR_NO_LOGITPROC = "none"
+# ThinkingTokenBudgetLogitsProcessor testing constants
+THINKING_TOKEN_BUDGET = 5
+THINK_START_TOKEN_ID = 999
+THINK_END_TOKEN_ID = 998
+
# LogitsProcessor subclass or "none"
LogitprocType: TypeAlias = type[LogitsProcessor] | str
@@ -67,9 +73,24 @@ class LogitsProcsRequestParams:
self.workload_index = workload_index
self.logitproc_type = logitproc_type
# Number of output tokens is randomly 0 or twice the min-tokens
- # threshold which will be used in testing. Output token values
- # don't matter *for these tests* so use 0 as a dummy value
- self.out_tokens = [0] * (MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2))
+ # threshold which will be used in testing.
+ # Generate diverse random tokens for all processors (more realistic)
+ num_tokens = MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2)
+ if num_tokens > 0:
+ # Use diverse random tokens
+ self.out_tokens = [random.randint(1, 950) for _ in range(num_tokens)]
+ # Set first token for ThinkingTokenBudget testing
+ is_thinking_processor = (
+ logitproc_type is ThinkingTokenBudgetLogitsProcessor
+ or (
+ hasattr(logitproc_type, "__name__")
+ and logitproc_type.__name__ == "ThinkingTokenBudgetLogitsProcessor"
+ )
+ )
+ if is_thinking_processor:
+ self.out_tokens[0] = THINK_START_TOKEN_ID
+ else:
+ self.out_tokens = []
self.prompt_tokens = []
self.params = _sampling_params_from_logitproc(logitproc_type)
@@ -79,6 +100,13 @@ class LogitsProcsRequestParams:
return f"MyClass({summ})"
+class MockReasoningConfig:
+ """Mock reasoning config for testing ThinkingTokenBudgetLogitsProcessor."""
+
+ think_start_token_ids = [THINK_START_TOKEN_ID]
+ think_end_token_ids = [THINK_END_TOKEN_ID]
+
+
def _generate_fake_sampling_metadata(
num_output_tokens: int,
batch_size: int,
@@ -97,8 +125,12 @@ def _generate_fake_sampling_metadata(
0, vocab_size, size=np.random.randint(1, MAX_NUM_PROMPT_TOKENS)
).tolist()
)
+
+ vllm_config = VllmConfig()
+ vllm_config.reasoning_config = MockReasoningConfig()
+
logitsprocs = build_logitsprocs(
- vllm_config=VllmConfig(),
+ vllm_config=vllm_config,
device=device,
is_pin_memory=PIN_MEMORY_AVAILABLE,
is_pooling_model=False,
@@ -403,6 +435,127 @@ def _min_tokens_validate(
)
+def _thinking_budget_params(kwargs: dict) -> None:
+ """Set SamplingParams kwargs for thinking token budget tests"""
+ kwargs["thinking_token_budget"] = THINKING_TOKEN_BUDGET
+
+
+def _thinking_budget_validate(
+ test_fakes: LogitsprocsTestFakes,
+ persistent_batch: list[LogitsProcsRequestParams],
+ logits_new: torch.Tensor,
+ batch_index: int,
+ request_params: LogitsProcsRequestParams,
+ step_idx: int,
+) -> None:
+ """Validate thinking token budget processor behavior"""
+ # Get the ThinkingTokenBudgetLogitsProcessor instance
+ tb_processor: ThinkingTokenBudgetLogitsProcessor = next(
+ test_fakes.get_logitsprocs_by_cls(ThinkingTokenBudgetLogitsProcessor)
+ )
+
+ # Get current request state
+ state = tb_processor._state.get(batch_index)
+ params = request_params.params
+
+ # Validate thinking token budget configuration
+ if hasattr(params, "thinking_token_budget") and params.thinking_token_budget:
+ # State should exist for requests with thinking_token_budget
+ if state is None:
+ _raise_error_invalid(
+ msg_suffix=(
+ f"Expected state for batch {batch_index} "
+ f"with thinking_token_budget={params.thinking_token_budget}"
+ ),
+ batch_index=batch_index,
+ request_params=request_params,
+ step_idx=step_idx,
+ )
+
+ # Validate budget matches what was set
+ expected_budget = params.thinking_token_budget
+ actual_budget = state["thinking_token_budget"]
+
+ if actual_budget != expected_budget:
+ _raise_error_invalid(
+ msg_suffix=(
+ f"Budget mismatch: expected {expected_budget}, got {actual_budget}"
+ ),
+ batch_index=batch_index,
+ request_params=request_params,
+ step_idx=step_idx,
+ )
+
+ # Check if we're in thinking mode and validate token counting
+ output_tokens = request_params.out_tokens
+
+ # Find if thinking has started in output tokens
+ thinking_started = False
+ start_tokens = tb_processor.think_start_token_ids
+
+ if len(start_tokens) > 0:
+ for i in range(len(output_tokens) - len(start_tokens) + 1):
+ if output_tokens[i : i + len(start_tokens)] == start_tokens:
+ thinking_started = True
+ break
+
+ if thinking_started:
+ # If budget is exceeded, validate end token forcing
+ think_count = state["think_count"]
+ budget = state["thinking_token_budget"]
+
+ if think_count >= budget:
+ if not state["in_end"]:
+ _raise_error_invalid(
+ msg_suffix=(
+ f"Budget exceeded ({think_count} >= "
+ f"{budget}) but not "
+ "forcing end tokens"
+ ),
+ batch_index=batch_index,
+ request_params=request_params,
+ step_idx=step_idx,
+ )
+
+ # Validate that only end tokens are allowed
+ end_tokens = tb_processor.think_end_token_ids
+ if len(end_tokens) > 0:
+ expected_end_token_id = end_tokens[
+ min(state["end_count"], len(end_tokens) - 1)
+ ]
+
+ # Check logits masking
+ batch_logits = logits_new[batch_index]
+ for token_id in range(len(batch_logits)):
+ logit_value = batch_logits[token_id]
+
+ if token_id == expected_end_token_id:
+ # End token should not be masked
+ if logit_value == -float("inf"):
+ _raise_error_invalid(
+ msg_suffix=(
+ f"End token {token_id} should not be "
+ "masked but is"
+ ),
+ batch_index=batch_index,
+ request_params=request_params,
+ step_idx=step_idx,
+ )
+ else:
+ # All other tokens should be masked when forcing end
+ if logit_value != -float("inf"):
+ _raise_error_invalid(
+ msg_suffix=(
+ f"Token {token_id} should be masked "
+ f"when forcing end tokens, but "
+ f"logit={logit_value}"
+ ),
+ batch_index=batch_index,
+ request_params=request_params,
+ step_idx=step_idx,
+ )
+
+
def _none_validate(
test_fakes: LogitsprocsTestFakes,
persistent_batch: list[LogitsProcsRequestParams],
@@ -449,20 +602,30 @@ logitsprocs_test_mapping = {
MinTokensLogitsProcessor: LogitsprocTestHelpers(
gen_request_fxn=_min_tokens_params, eval_fxn=_min_tokens_validate
),
+ ThinkingTokenBudgetLogitsProcessor: LogitsprocTestHelpers(
+ gen_request_fxn=_thinking_budget_params, eval_fxn=_thinking_budget_validate
+ ),
}
def _get_test_cases() -> list[list[str]]:
"""Each test case is a set of logitsprocs"""
logitsprocs_types = list(logitsprocs_test_mapping.keys())
+
+    # Isolate ThinkingTokenBudgetLogitsProcessor from the other processors so
+    # its logits modifications do not interfere with their expected outputs
+ thinking_processor = ThinkingTokenBudgetLogitsProcessor
+ other_processors = [
+ p
+ for p in logitsprocs_types
+ if p != STR_NO_LOGITPROC and p != thinking_processor
+ ]
+
return (
[[STR_NO_LOGITPROC]]
- + [
- [logitproc_type, STR_NO_LOGITPROC]
- for logitproc_type in logitsprocs_types
- if logitproc_type != STR_NO_LOGITPROC
- ]
- + [logitsprocs_types]
+ + [[logitproc_type, STR_NO_LOGITPROC] for logitproc_type in other_processors]
+ + [other_processors]
+ + [[thinking_processor]]
)
diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py
index 452fb0466..d5a3e9bfd 100644
--- a/vllm/config/__init__.py
+++ b/vllm/config/__init__.py
@@ -33,6 +33,7 @@ from vllm.config.offload import (
from vllm.config.parallel import EPLBConfig, ParallelConfig
from vllm.config.pooler import PoolerConfig
from vllm.config.profiler import ProfilerConfig
+from vllm.config.reasoning import ReasoningConfig
from vllm.config.scheduler import SchedulerConfig
from vllm.config.speculative import SpeculativeConfig
from vllm.config.speech_to_text import SpeechToTextConfig
@@ -101,6 +102,8 @@ __all__ = [
"ParallelConfig",
# From vllm.config.pooler
"PoolerConfig",
+ # From vllm.config.reasoning
+ "ReasoningConfig",
# From vllm.config.scheduler
"SchedulerConfig",
# From vllm.config.speculative
diff --git a/vllm/config/reasoning.py b/vllm/config/reasoning.py
new file mode 100644
index 000000000..872e05580
--- /dev/null
+++ b/vllm/config/reasoning.py
@@ -0,0 +1,72 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from dataclasses import field
+
+from vllm.config.model import ModelConfig
+from vllm.config.utils import config
+from vllm.tokenizers import cached_tokenizer_from_config
+
+
+@config
+class ReasoningConfig:
+ """Configuration for reasoning models.
+
+ Set `think_start_str` and `think_end_str` to the strings that delimit
+    the reasoning block (e.g. `"<think>"` and `"</think>"`). The
+ corresponding token IDs are derived automatically via
+ `initialize_token_ids` and are not intended to be set directly.
+ """
+
+    # NOTE: These parameters are temporary; the intent is to derive them
+ # automatically from the reasoning parser in a future version.
+ think_start_str: str = ""
+ """String that indicates the start of reasoning."""
+ think_end_str: str = ""
+ """String that indicates the end of reasoning content."""
+
+ _think_start_token_ids: list[int] | None = field(
+ default=None, init=False, repr=False
+ )
+ """Private backing field for `think_start_token_ids`. Set by
+ `initialize_token_ids`. Not intended to be configured directly."""
+ _think_end_token_ids: list[int] | None = field(default=None, init=False, repr=False)
+ """Private backing field for `think_end_token_ids`. Set by
+ `initialize_token_ids`. Not intended to be configured directly."""
+
+ @property
+ def think_start_token_ids(self) -> list[int] | None:
+ """Token IDs derived from `think_start_str`. Set automatically by
+ `initialize_token_ids`. Not intended to be configured directly."""
+ return self._think_start_token_ids
+
+ @property
+ def think_end_token_ids(self) -> list[int] | None:
+ """Token IDs derived from `think_end_str`. Set automatically by
+ `initialize_token_ids`. Not intended to be configured directly."""
+ return self._think_end_token_ids
+
+ def initialize_token_ids(self, model_config: ModelConfig) -> None:
+ """Initialize reasoning token IDs from strings using the tokenizer."""
+ if (
+ self._think_start_token_ids is not None
+ and self._think_end_token_ids is not None
+ ):
+ return
+
+ tokenizer = cached_tokenizer_from_config(model_config=model_config)
+
+ self._think_start_token_ids = tokenizer.encode(
+ self.think_start_str, add_special_tokens=False
+ )
+ self._think_end_token_ids = tokenizer.encode(
+ self.think_end_str, add_special_tokens=False
+ )
+
+ if not self._think_start_token_ids or not self._think_end_token_ids:
+ raise ValueError(
+ f"ReasoningConfig: failed to tokenize reasoning strings: "
+ f"think_start_str='{self.think_start_str}', "
+ f"think_end_str='{self.think_end_str}'. "
+ "Ensure the strings are valid tokens in the model's vocabulary."
+ )
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index a8088de63..c7449840d 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -40,6 +40,7 @@ from .observability import ObservabilityConfig
from .offload import OffloadConfig
from .parallel import ParallelConfig
from .profiler import ProfilerConfig
+from .reasoning import ReasoningConfig
from .scheduler import SchedulerConfig
from .speculative import EagleModelTypes, NgramGPUTypes, SpeculativeConfig
from .structured_outputs import StructuredOutputsConfig
@@ -302,6 +303,8 @@ class VllmConfig: # type: ignore[misc]
"""The configurations for event publishing."""
ec_transfer_config: ECTransferConfig | None = None
"""The configurations for distributed EC cache transfer."""
+ reasoning_config: ReasoningConfig | None = None
+ """The configurations for reasoning model."""
# some opaque config, only used to provide additional information
# for the hash computation, mainly used for testing, debugging or out of
# tree config registration.
@@ -1143,6 +1146,9 @@ class VllmConfig: # type: ignore[misc]
if not self.instance_id:
self.instance_id = random_uuid()[:5]
+ if self.reasoning_config is not None and self.model_config is not None:
+ self.reasoning_config.initialize_token_ids(self.model_config)
+
# Hybrid KV cache manager (HMA) runtime rules:
# - Explicit enable (--no-disable-kv-cache-manager): error if runtime
# disables it
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index e0d5236bc..b7276a345 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -53,6 +53,7 @@ from vllm.config import (
PoolerConfig,
PrefetchOffloadConfig,
ProfilerConfig,
+ ReasoningConfig,
SchedulerConfig,
SpeculativeConfig,
StructuredOutputsConfig,
@@ -581,6 +582,7 @@ class EngineArgs:
kv_events_config: KVEventsConfig | None = None
ec_transfer_config: ECTransferConfig | None = None
+ reasoning_config: ReasoningConfig = get_field(VllmConfig, "reasoning_config")
generation_config: str = ModelConfig.generation_config
enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
@@ -1297,6 +1299,7 @@ class EngineArgs:
vllm_group.add_argument(
"--attention-config", "-ac", **vllm_kwargs["attention_config"]
)
+ vllm_group.add_argument("--reasoning-config", **vllm_kwargs["reasoning_config"])
vllm_group.add_argument("--kernel-config", **vllm_kwargs["kernel_config"])
vllm_group.add_argument(
"--additional-config", **vllm_kwargs["additional_config"]
@@ -1958,6 +1961,7 @@ class EngineArgs:
kv_transfer_config=self.kv_transfer_config,
kv_events_config=self.kv_events_config,
ec_transfer_config=self.ec_transfer_config,
+ reasoning_config=self.reasoning_config,
profiler_config=self.profiler_config,
additional_config=self.additional_config,
optimization_level=self.optimization_level,
diff --git a/vllm/entrypoints/openai/chat_completion/protocol.py b/vllm/entrypoints/openai/chat_completion/protocol.py
index 61763a3b6..09ef448f4 100644
--- a/vllm/entrypoints/openai/chat_completion/protocol.py
+++ b/vllm/entrypoints/openai/chat_completion/protocol.py
@@ -180,6 +180,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
| None
) = "none"
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None
+ thinking_token_budget: int | None = None
include_reasoning: bool = True
parallel_tool_calls: bool | None = True
@@ -515,6 +516,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
structured_outputs=self.structured_outputs,
logit_bias=self.logit_bias,
bad_words=self.bad_words,
+ thinking_token_budget=self.thinking_token_budget,
allowed_token_ids=self.allowed_token_ids,
extra_args=extra_args or None,
skip_clone=True, # Created fresh per request, safe to skip clone
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index f7a2e8b3f..dc3e1d49c 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -281,6 +281,8 @@ class SamplingParams(
_bad_words_token_ids: list[list[int]] | None = None
skip_reading_prefix_cache: bool | None = None
+ thinking_token_budget: int | None = None
+    """Maximum number of tokens allowed in the reasoning (thinking) section."""
repetition_detection: RepetitionDetectionParams | None = None
"""Parameters for detecting repetitive N-gram patterns in output tokens.
@@ -304,6 +306,7 @@ class SamplingParams(
stop: str | list[str] | None = None,
stop_token_ids: list[int] | None = None,
bad_words: list[str] | None = None,
+ thinking_token_budget: int | None = None,
include_stop_str_in_output: bool = False,
ignore_eos: bool = False,
max_tokens: int | None = 16,
@@ -344,6 +347,7 @@ class SamplingParams(
stop=stop,
stop_token_ids=stop_token_ids,
bad_words=bad_words,
+ thinking_token_budget=thinking_token_budget,
include_stop_str_in_output=include_stop_str_in_output,
ignore_eos=ignore_eos,
max_tokens=max_tokens,
@@ -858,6 +862,7 @@ class SamplingParams(
f"stop={self.stop}, "
f"stop_token_ids={self.stop_token_ids}, "
f"bad_words={self.bad_words}, "
+ f"thinking_token_budget={self.thinking_token_budget}, "
f"include_stop_str_in_output={self.include_stop_str_in_output}, "
f"ignore_eos={self.ignore_eos}, "
f"max_tokens={self.max_tokens}, "
diff --git a/vllm/v1/engine/input_processor.py b/vllm/v1/engine/input_processor.py
index aab560544..b77b9277a 100644
--- a/vllm/v1/engine/input_processor.py
+++ b/vllm/v1/engine/input_processor.py
@@ -99,6 +99,16 @@ class InputProcessor:
self.structured_outputs_config,
self.tokenizer,
)
+
+ if (
+ params.thinking_token_budget is not None
+ and self.vllm_config.reasoning_config is None
+ ):
+ raise ValueError(
+ "thinking_token_budget is set but reasoning_config is "
+ "not configured. Please set --reasoning-config to use "
+ "thinking_token_budget."
+ )
elif isinstance(params, PoolingParams):
supported_pooling_tasks = [
task for task in supported_tasks if task in POOLING_TASKS
diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py
index 2cb89e1ea..fb4a046fc 100644
--- a/vllm/v1/sample/logits_processor/__init__.py
+++ b/vllm/v1/sample/logits_processor/__init__.py
@@ -18,6 +18,7 @@ from vllm.v1.sample.logits_processor.builtin import (
LogitBiasLogitsProcessor,
MinPLogitsProcessor,
MinTokensLogitsProcessor,
+ ThinkingTokenBudgetLogitsProcessor,
process_dict_updates,
)
from vllm.v1.sample.logits_processor.interface import (
@@ -50,6 +51,7 @@ BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [
MinTokensLogitsProcessor,
LogitBiasLogitsProcessor,
MinPLogitsProcessor,
+ ThinkingTokenBudgetLogitsProcessor,
]
@@ -354,4 +356,5 @@ __all__ = [
"STR_POOLING_REJECTS_LOGITSPROCS",
"LOGITSPROCS_GROUP",
"AdapterLogitsProcessor",
+ "ThinkingTokenBudgetLogitsProcessor",
]
diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py
index 11a52711d..c92f33402 100644
--- a/vllm/v1/sample/logits_processor/builtin.py
+++ b/vllm/v1/sample/logits_processor/builtin.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Sequence
-from typing import TYPE_CHECKING, TypeVar
+from typing import TYPE_CHECKING, Any, TypeVar
import numpy as np
import torch
@@ -291,6 +291,263 @@ class MinTokensLogitsProcessor(LogitsProcessor):
return logits
+class ThinkingTokenBudgetLogitsProcessor(LogitsProcessor):
+ """Limits the number of tokens allowed inside a 'thinking' section."""
+
+ def __init__(
+ self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
+ ):
+ reasoning_config = vllm_config.reasoning_config
+ max_num_reqs = vllm_config.scheduler_config.max_num_seqs
+
+ # Check if thinking is enabled
+ self.is_enabled = reasoning_config is not None
+
+ self.think_start_token_ids = getattr(
+ reasoning_config, "think_start_token_ids", []
+ )
+ self.think_end_token_ids = getattr(reasoning_config, "think_end_token_ids", [])
+
+ self.pin_memory = is_pin_memory
+ self.device = device
+ # Per-request state tracking for thinking token management
+ # Key: request_index, Value: state dict containing:
+ # "in_think": bool - currently in thinking mode
+ # "in_end": bool - currently forcing end tokens output
+ # "check_count_down": int - steps remaining until next think
+ # start/end token parsing
+ # "think_count": int - number of thinking tokens generated
+ # "end_count": int - number of end tokens forced so far
+ # "thinking_token_budget": int - max allowed thinking tokens
+ # "output_tok_ids": list[int] - generated output tokens
+ # "prev_output_length": int - previous output length for
+ # incremental processing
+ self._state: dict[int, dict[str, Any]] = {}
+
+ # Preallocate reusable tensors
+ self.mask = torch.zeros(max_num_reqs, dtype=torch.bool, device=device)
+ self.force_token_ids = torch.full(
+ (max_num_reqs,), -1, dtype=torch.long, device=device
+ )
+
+ @staticmethod
+ def _find_last_sequence_index(target_list: list[int], token_ids: list[int]) -> int:
+ """
+ Returns the index of the last occurrence of token_ids in target_list.
+
+ Args:
+ target_list (list[int]): The list of token IDs.
+ token_ids (list[int]): The sequence of token IDs to find.
+ """
+ if not token_ids:
+ return -1
+ for i in range(len(target_list) - len(token_ids), -1, -1):
+ if target_list[i : i + len(token_ids)] == token_ids:
+ return i
+ return -1
+
+ def _init_state_entry(
+ self, prompt_tok_ids: list[int] | None, thinking_token_budget: int
+ ) -> dict[str, Any]:
+ """Initializes the tracking state for a given sequence index."""
+ if prompt_tok_ids is None:
+ last_start = -1
+ last_end = -1
+ in_think = False
+ think_count = 0
+ else:
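+            # Detect whether the prompt already ends inside an open thinking
+            # block (the last think-start marker appears after the last
+            # think-end marker) and, if so, count the reasoning tokens that
+            # are already present in the prompt.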
+ last_start = self._find_last_sequence_index(
+ prompt_tok_ids, self.think_start_token_ids
+ )
+ last_end = self._find_last_sequence_index(
+ prompt_tok_ids, self.think_end_token_ids
+ )
+ in_think = last_start > last_end
+ if in_think:
+ think_count = len(prompt_tok_ids) - (
+ last_start + len(self.think_start_token_ids)
+ )
+ else:
+ think_count = 0
+
+ return {
+ "in_think": in_think, # Currently in thinking mode
+ "in_end": in_think and thinking_token_budget == 0,
+ "check_count_down": thinking_token_budget,
+ "think_count": think_count, # Number of tokens in thinking section
+ "end_count": 0, # Number of end tokens forced so far
+ "prompt_tok_ids": prompt_tok_ids,
+ "output_tok_ids": [],
+ "thinking_token_budget": thinking_token_budget,
+ "prev_output_length": 0,
+ # Track previous output length for incremental updates
+ }
+
+ def _update_think_state(self, state: dict[str, Any]):
+ """Updates the state based on newly generated output tokens."""
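+        # Fast path: while no end tokens are being forced and the countdown
+        # has not expired, only decrement the countdown and skip the more
+        # expensive scan of the output tokens below.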
+ if not state.get("in_end", False) and state.get("check_count_down", 0) > 0:
+ state["check_count_down"] -= 1
+ return
+
+ output = state.get("output_tok_ids", [])
+ if not output:
+ return
+
+ # Track previous output length for incremental processing
+ prev_length = state.get("prev_output_length", 0)
+ current_length = len(output)
+
+ if current_length <= prev_length:
+ return
+
+ # Process only newly added tokens
+ new_tokens = output[prev_length:]
+ state["prev_output_length"] = current_length
+
+ # Check if new tokens contain think start or end sequences
+ start_len = len(self.think_start_token_ids)
+ end_len = len(self.think_end_token_ids)
+
+ # Look for think sequences in recent tokens (including boundary)
+ # Check overlapping regions where sequences might span boundaries
+ check_start_idx = max(0, prev_length - max(start_len, end_len) + 1)
+ recent_tokens = output[check_start_idx:]
+
+ # Find any think start/end sequences in recent tokens
+ recent_start_pos = self._find_last_sequence_index(
+ recent_tokens, self.think_start_token_ids
+ )
+ recent_end_pos = self._find_last_sequence_index(
+ recent_tokens, self.think_end_token_ids
+ )
+
+ # Update state based on recent sequences
+ if not state["in_end"]:
+ if recent_start_pos >= 0 and recent_end_pos >= 0:
+ if recent_start_pos > recent_end_pos:
+                    # Case: ...</think>...<think>... - entering think mode
+ absolute_start_pos = check_start_idx + recent_start_pos
+ new_think_count = current_length - (absolute_start_pos + start_len)
+ state["in_think"] = True
+ state["think_count"] = new_think_count
+ else:
+                    # Case: ...<think>...</think>... - exiting think mode
+ state["in_think"] = False
+ state["think_count"] = 0
+ elif recent_start_pos >= 0:
+ # Found think start - entering think mode
+ absolute_start_pos = check_start_idx + recent_start_pos
+ new_think_count = current_length - (absolute_start_pos + start_len)
+ state["in_think"] = True
+ state["think_count"] = new_think_count
+ elif recent_end_pos >= 0:
+ # Found think end - exiting think mode
+ state["in_think"] = False
+ state["think_count"] = 0
+ elif state["in_think"]:
+ # Continue thinking mode, increment count by new tokens
+ state["think_count"] += len(new_tokens)
+
+ # Set countdown based on current state
+ if state["in_think"]:
+ remaining_budget = max(
+ 0, state["thinking_token_budget"] - state["think_count"]
+ )
+ state["check_count_down"] = max(0, remaining_budget - 1)
+ else:
+ state["check_count_down"] = state["thinking_token_budget"]
+
+ # Check if need to transition to end mode
+ if (
+ state["in_think"]
+ and state["think_count"] >= state["thinking_token_budget"]
+ ):
+ state["in_think"] = False
+ state["in_end"] = True
+ state["end_count"] = 0
+ state["check_count_down"] = state["thinking_token_budget"]
+ else:
+ # In end mode
+ state["end_count"] += 1
+ if state["end_count"] >= len(self.think_end_token_ids):
+ state.update(
+ {
+ "in_end": False,
+ "end_count": 0,
+ "check_count_down": state["thinking_token_budget"],
+ }
+ )
+
+ def is_argmax_invariant(self) -> bool:
+ """This logits processor can change the outcome of
+ greedy sampling by forcing that the thinking section
+ ends after a certain number of tokens."""
+ return False
+
+ def update_state(self, batch_update: BatchUpdate | None):
+ if not self.is_enabled:
+ return
+ if batch_update:
+ for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
+ thinking_token_budget = params.thinking_token_budget
+
+ if thinking_token_budget is not None:
+ self._state[index] = self._init_state_entry(
+ prompt_tok_ids, thinking_token_budget
+ )
+ self._state[index]["output_tok_ids"] = output_tok_ids
+ else:
+ # Remove state if no thinking budget
+ self._state.pop(index, None)
+
+ for index in batch_update.removed:
+ self._state.pop(index, {})
+
+ for i1, i2, direction in batch_update.moved:
+ if direction == MoveDirectionality.SWAP:
+ state1 = self._state.pop(i1, None)
+ state2 = self._state.pop(i2, None)
+ if state1 is not None:
+ self._state[i2] = state1
+ if state2 is not None:
+ self._state[i1] = state2
+ else:
+ state = self._state.pop(i1, None)
+ if state is not None:
+ self._state[i2] = state
+
+ for state in self._state.values():
+ self._update_think_state(state)
+
+ def apply(self, logits: torch.Tensor) -> torch.Tensor:
+ if not self.is_enabled or not self._state:
+ return logits
+
+ batch_size = logits.size(0)
+ self.mask[:batch_size] = False
+
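+        # Mark rows that are currently forcing think-end tokens and record
+        # which end token each of them must emit next.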
+ for i in range(batch_size):
+ state = self._state.get(i)
+ if state and state["in_end"]:
+ self.mask[i] = True
+ self.force_token_ids[i] = self.think_end_token_ids[state["end_count"]]
+
+        # Check on the CPU first so we avoid an unnecessary GPU sync
+ has_active_thinking = any(
+ state.get("in_end", False) for state in self._state.values()
+ )
+
+ if has_active_thinking:
+ current_mask = self.mask[:batch_size]
+ active_indices = current_mask.nonzero(as_tuple=False).view(-1)
+ if len(active_indices) > 0:
+ force_tokens = self.force_token_ids[active_indices]
+                # Boost the forced think-end token with a large logit value
+ logits[active_indices, force_tokens] = 1e9
+
+ return logits
+
+
def process_dict_updates(
req_entries: dict[int, T],
batch_update: BatchUpdate | None,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 7c6248b37..6465ca654 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -629,7 +629,10 @@ class GPUModelRunner(
),
# We currently don't know whether a particular custom logits processor
# uses output token ids so we set this conservatively.
- logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
+ # ThinkingTokenBudgetLogitsProcessor also needs output token ids to
+ # correctly track think start/end token sequences in async scheduling.
+ logitsprocs_need_output_token_ids=bool(custom_logitsprocs)
+ or self.vllm_config.reasoning_config is not None,
is_pooling_model=self.is_pooling_model,
cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size,
)