Compare commits

...

13 Commits

Author SHA1 Message Date
Michael
2a69949bda [Bugfix]: Fix Gemma4ToolParser.__init__() missing tools parameter (#38847)
Signed-off-by: Michael Hospedales <hospedales@me.com>
(cherry picked from commit bb39382b2b)
2026-04-02 16:45:38 -07:00
Luciano Martins
8adcf8c40a feat(models): implement Google Gemma 4 architecture support (MoE, Multimodal, Reasoning, Tool-Use) (#38826)
Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Signed-off-by: Luciano Martins <lucianomartins@google.com>
Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
(cherry picked from commit 08ed2b9688)
2026-04-02 11:49:53 -07:00
khluu
cfad6a509c Revert "[Bugfix] Restrict TRTLLM attention to SM100, fixing GB300 (SM103) hang (#38730)"
This reverts commit c284a6671c.
2026-04-01 15:14:58 -07:00
Stefano Castagnetta
c284a6671c [Bugfix] Restrict TRTLLM attention to SM100, fixing GB300 (SM103) hang (#38730)
Signed-off-by: Stefano Castagnetta <scastagnetta@nvidia.com>
(cherry picked from commit 6183cae1bd)
2026-04-01 12:11:03 -07:00
Chauncey
3a30a1a6a8 [Misc] Rename think_start_str/think_end_str to reasoning_start_str/reasoning_end_str (#38242)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
(cherry picked from commit cbe7d18096)
2026-04-01 12:10:53 -07:00
Juan Pérez de Algaba
29982d48b3 (security) Enforce frame limit in VideoMediaIO (#38636)
Signed-off-by: jperezde <jperezde@redhat.com>
(cherry picked from commit 58ee614221)
2026-04-01 12:10:40 -07:00
Yifan Qiao
1dbbafd3f3 [Feat][v1] Simple yet General CPU KV Cache Offloading (#37160)
Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
(cherry picked from commit 91e4521f9f)
2026-04-01 01:03:14 -07:00
Lucas Wilkinson
0ee3b7fc3d [Bugfix][MLA] Add logits size budget to sparse indexer prefill chunking (#36178)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
(cherry picked from commit eb47454987)
2026-04-01 01:02:58 -07:00
Matthew Bonanni
268bed9cf3 [Bugfix][Async] Fix async spec decoding with hybrid models (#38556)
Signed-off-by: SandishKumarHN <sandishkumarhn@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: SandishKumarHN <sandishkumarhn@gmail.com>
(cherry picked from commit 757068dc65)
2026-04-01 01:02:35 -07:00
Jiangyun Zhu
bcc0fdd0f3 [CI] fix LM Eval Qwen3.5 Models (B200) (#38632)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
(cherry picked from commit ea7bfde6e4)
2026-04-01 01:02:20 -07:00
wang.yuqi
69b8bd4b33 [CI Failure] pin colmodernvbert revision (#38612)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
(cherry picked from commit 719735d6c5)
2026-04-01 01:02:04 -07:00
Li, Jiang
12449f9492 [Bugfix][CPU] Skip set_num_threads after thread binding (#38535)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
(cherry picked from commit 6557f4937f)
2026-03-30 23:01:42 -07:00
haosdent
b92312dfd7 [CI] Fix SPLADE pooler test broken by #38139 (#38495)
Signed-off-by: haosdent <haosdent@gmail.com>
(cherry picked from commit a08b7733fd)
2026-03-30 21:52:13 -07:00
56 changed files with 8537 additions and 136 deletions

View File

@@ -1,9 +1,10 @@
#!/bin/bash
set -euox pipefail
export VLLM_CPU_CI_ENV=0
export VLLM_CPU_KVCACHE_SPACE=1 # avoid OOM
echo "--- PP+TP"
vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -pp=2 &
vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -pp=2 --max-model-len=4096 &
server_pid=$!
timeout 600 bash -c "until curl localhost:8000/v1/models > /dev/null 2>&1; do sleep 1; done" || exit 1
vllm bench serve \
@@ -23,7 +24,7 @@ if [ "$failed_req" -ne 0 ]; then
fi
echo "--- DP+TP"
vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -dp=2 &
vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -dp=2 --max-model-len=4096 &
server_pid=$!
timeout 600 bash -c "until curl localhost:8000/v1/models > /dev/null 2>&1; do sleep 1; done" || exit 1
vllm bench serve \

View File

@@ -244,12 +244,12 @@ response = client.chat.completions.create(
Some models, such as [Qwen3](https://qwen.readthedocs.io/en/latest/getting_started/quickstart.html#thinking-budget), [DeepSeek](https://www.alibabacloud.com/help/en/model-studio/deep-thinking), and [Nemotron3](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16), support a thinking budget that limits the maximum number of tokens used for reasoning.
Token counting starts from `think_start_str`. Once the reasoning token count reaches the configured `thinking_token_budget`, vLLM forces the model to produce `think_end_str`, effectively terminating the reasoning block.
Token counting starts from `reasoning_start_str`. Once the reasoning token count reaches the configured `thinking_token_budget`, vLLM forces the model to produce `reasoning_end_str`, effectively terminating the reasoning block.
To use this feature:
- `--reasoning-parser` enables reasoning extraction.
- `--reasoning-config` defines the reasoning boundary tokens (e.g., `think_start_str`, `think_end_str`).
- `--reasoning-config` defines the reasoning boundary tokens (e.g., `reasoning_start_str`, `reasoning_end_str`).
- `thinking_token_budget` (a sampling parameter) sets the per-request reasoning token limit.
If `thinking_token_budget` is not specified, no explicit reasoning limit is applied beyond normal generation constraints such as `max_tokens`.
@@ -257,20 +257,20 @@ If `thinking_token_budget` is not specified, no explicit reasoning limit is appl
`--reasoning-config` accepts a JSON object corresponding to
[ReasoningConfig][vllm.config.ReasoningConfig] with the following fields:
| Field | Type | Description |
|-------------------|----------------|--------------------------------------------------|
| `think_start_str` | `str \| null` | String that marks the start of reasoning content |
| `think_end_str` | `str \| null` | String that marks the end of reasoning content |
| Field | Type | Description |
|-----------------------|----------------|--------------------------------------------------|
| `reasoning_start_str` | `str \| null` | String that marks the start of reasoning content |
| `reasoning_end_str` | `str \| null` | String that marks the end of reasoning content |
!!! note
`think_end_str` can include a transition phrase before the think end token. For example, setting `think_end_str` to `"I have to give the solution based on the thinking directly now.</think>"` instructs the model to emit that phrase when the budget is exhausted, making the reasoning termination more natural.
`reasoning_end_str` can include a transition phrase before the reasoning end token. For example, setting `reasoning_end_str` to `"I have to give the solution based on the reasoning directly now.</think>"` instructs the model to emit that phrase when the budget is exhausted, making the reasoning termination more natural.
### Online Serving
```bash
vllm serve Qwen/Qwen3-0.6B \
--reasoning-parser qwen3 \
--reasoning-config '{"think_start_str": "<think>", "think_end_str": "I have to give the solution based on the thinking directly now.</think>"}'
--reasoning-config '{"reasoning_start_str": "<think>", "reasoning_end_str": "I have to give the solution based on the reasoning directly now.</think>"}'
```
Then make a request with `thinking_token_budget` to limit the reasoning tokens:
@@ -298,8 +298,8 @@ from vllm.config import ReasoningConfig
llm = LLM(
model="Qwen/Qwen3-0.6B",
reasoning_config=ReasoningConfig(
think_start_str="<think>",
think_end_str="I have to give the solution based on the thinking directly now.</think>",
reasoning_start_str="<think>",
reasoning_end_str="I have to give the solution based on the thinking directly now.</think>",
),
)

View File

@@ -7,3 +7,4 @@ server_args: >-
--max-model-len 4096
--data-parallel-size 2
--enable-expert-parallel
--max-num-seqs 512

View File

@@ -1,8 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import types
import pytest
import torch
import torch.nn as nn
@@ -11,6 +9,8 @@ from vllm.model_executor.models.bert import (
BertMLMHead,
SPLADESparsePooler,
)
from vllm.pooling_params import PoolingParams
from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
# ---------------------------------------------------------------------
# Functional test: SPLADE formula correctness (no HF download needed)
@@ -38,8 +38,12 @@ def test_splade_pooler_matches_reference_formula(B, T, H, V):
],
dtype=torch.long,
)
meta = types.SimpleNamespace(
prompt_lens=prompt_lens_tenser, prompt_token_ids=token_ids
meta = PoolingMetadata(
prompt_lens=prompt_lens_tenser,
prompt_token_ids=token_ids,
prompt_token_ids_cpu=token_ids,
pooling_params=[PoolingParams(task="embed")] * B,
pooling_states=[PoolingStates() for _ in range(B)],
)
# MLM head (prefer BertMLMHead, fallback to Linear if unavailable)

View File

@@ -394,6 +394,22 @@ VLM_TEST_SETTINGS = {
vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}},
patch_hf_runner=model_utils.gemma3_patch_hf_runner,
),
"gemma4": VLMTestInfo(
models=["google/gemma-4-E2B-it"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<bos><start_of_turn>user\n{img_prompt}<end_of_turn>\n<start_of_turn>model\n", # noqa: E501
single_image_prompts=IMAGE_ASSETS.prompts(
{
"stop_sign": "What's the content in the center of the image?",
"cherry_blossom": "What is the season?",
}
),
multi_image_prompt="Describe the two images in detail.",
max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForImageTextToText,
vllm_runner_kwargs={"limit_mm_per_prompt": {"image": 4}},
),
"granite_vision": VLMTestInfo(
models=["ibm-granite/granite-vision-3.3-2b"],
test_type=(VLMTestType.IMAGE),

View File

@@ -15,6 +15,10 @@ from vllm.entrypoints.pooling.score.utils import compute_maxsim_score
MODEL_NAME = "ModernVBERT/colmodernvbert-merged"
COLBERT_DIM = 128
DTYPE = "half"
# Fixme:
# Update colmodernvbert code to support the latest HF version
# and remove revision set.
REVISION = "4a0a9f3ac7a7992fec410bfa8e3d080ac9a5bcee"
# -----------------------------------------------------------------------
@@ -26,6 +30,7 @@ def test_colmodernvbert_text_token_embed(vllm_runner):
"""Text query produces per-token embeddings with shape (seq_len, 128)."""
with vllm_runner(
MODEL_NAME,
revision=REVISION,
runner="pooling",
dtype=DTYPE,
enforce_eager=True,
@@ -49,6 +54,7 @@ def test_colmodernvbert_text_relevance_ordering(vllm_runner):
with vllm_runner(
MODEL_NAME,
revision=REVISION,
runner="pooling",
dtype=DTYPE,
enforce_eager=True,
@@ -66,6 +72,7 @@ def test_colmodernvbert_text_late_interaction(vllm_runner):
with vllm_runner(
MODEL_NAME,
revision=REVISION,
runner="pooling",
dtype=DTYPE,
enforce_eager=True,
@@ -92,6 +99,7 @@ def test_colmodernvbert_image_token_embed(vllm_runner, image_assets):
"""Image input produces per-token embeddings including vision tokens."""
with vllm_runner(
MODEL_NAME,
revision=REVISION,
runner="pooling",
dtype=DTYPE,
enforce_eager=True,

View File

@@ -0,0 +1,44 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.multimodal import MULTIMODAL_REGISTRY
from ....conftest import ImageTestAssets
from ...utils import build_model_context
# TODO: to be updated to "google/gemma-4-e2b-it" once the models are available
GEMMA4_MODEL_ID = "google/gemma-4-E2B-it"
@pytest.mark.parametrize("model_id", [GEMMA4_MODEL_ID])
def test_limit_mm_per_prompt(
    image_assets: ImageTestAssets,
    model_id: str,
):
    """Test that limit_mm_per_prompt accurately restricts multiple images."""
    # Build a context that permits only a single image per prompt.
    ctx = build_model_context(
        model_id,
        mm_processor_kwargs={},
        limit_mm_per_prompt={"image": 1},
    )
    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)

    # Feed two image placeholders while only one image is allowed.
    prompt = "<image><image>"
    pil_images = [asset.pil_image for asset in image_assets][:2]
    if len(pil_images) < 2:
        # Duplicate the single available asset so we still send two images.
        pil_images = [pil_images[0], pil_images[0]]

    # Exceeding the per-prompt limit must raise a ValueError.
    with pytest.raises(ValueError, match="At most 1 image"):
        processor(
            prompt,
            mm_items=processor.info.parse_mm_data({"image": pil_images}),
            hf_processor_mm_kwargs={},
        )

View File

@@ -277,6 +277,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"google/gemma-2-9b", extras={"tiny": "google/gemma-2-2b-it"}
),
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
"Gemma4ForCausalLM": _HfExamplesInfo(
"google/gemma-4-E2B-it",
min_transformers_version="5.0.0",
),
"Gemma3nForCausalLM": _HfExamplesInfo("google/gemma-3n-E2B-it"),
"GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"),
"Glm4ForCausalLM": _HfExamplesInfo("zai-org/GLM-4-9B-0414"),
@@ -636,6 +640,7 @@ _LATE_INTERACTION_EXAMPLE_MODELS = {
# [Multimodal]
"ColModernVBertForRetrieval": _HfExamplesInfo(
"ModernVBERT/colmodernvbert-merged",
revision="4a0a9f3ac7a7992fec410bfa8e3d080ac9a5bcee",
),
"ColPaliForRetrieval": _HfExamplesInfo("vidore/colpali-v1.3-hf"),
"ColQwen3": _HfExamplesInfo(
@@ -804,6 +809,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
),
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
"Gemma4ForConditionalGeneration": _HfExamplesInfo(
"google/gemma-4-E2B-it",
min_transformers_version="5.5.0",
),
"Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it"),
"GlmAsrForConditionalGeneration": _HfExamplesInfo(
"zai-org/GLM-ASR-Nano-2512",

View File

@@ -239,6 +239,17 @@ def test_video_media_io_backend_env_var_fallback(monkeypatch: pytest.MonkeyPatch
assert metadata_missing["video_backend"] == "test_video_backend_override_2"
def _make_jpeg_b64_frames(n: int, width: int = 8, height: int = 8) -> list[str]:
    """Return *n* tiny base64-encoded JPEG frames."""

    def _frame(idx: int) -> str:
        # Vary the red channel so every frame is a distinct image.
        buf = io.BytesIO()
        Image.new("RGB", (width, height), color=(idx % 256, 0, 0)).save(
            buf, format="JPEG"
        )
        return pybase64.b64encode(buf.getvalue()).decode("ascii")

    return [_frame(i) for i in range(n)]
def test_load_base64_jpeg_returns_metadata():
"""Regression test: load_base64 with video/jpeg must return metadata.
@@ -248,16 +259,8 @@ def test_load_base64_jpeg_returns_metadata():
"""
num_test_frames = 3
frame_width, frame_height = 8, 8
# Build a few tiny JPEG frames and base64-encode them
b64_frames = []
for i in range(num_test_frames):
img = Image.new("RGB", (frame_width, frame_height), color=(i * 80, 0, 0))
buf = io.BytesIO()
img.save(buf, format="JPEG")
b64_frames.append(pybase64.b64encode(buf.getvalue()).decode("ascii"))
b64_frames = _make_jpeg_b64_frames(num_test_frames)
data = ",".join(b64_frames)
imageio = ImageMediaIO()
@@ -287,3 +290,52 @@ def test_load_base64_jpeg_returns_metadata():
# Default fps=1 → duration == num_frames
assert metadata["fps"] == 1.0
assert metadata["duration"] == float(num_test_frames)
def test_load_base64_jpeg_enforces_num_frames_limit():
    """Frames beyond num_frames must be truncated in the video/jpeg path.

    Without the limit an attacker can send thousands of base64 JPEG frames
    in a single request and exhaust server memory (OOM).
    """
    num_frames_limit = 4
    sent_frames = 20
    data = ",".join(_make_jpeg_b64_frames(sent_frames))

    videoio = VideoMediaIO(ImageMediaIO(), num_frames=num_frames_limit)
    frames, metadata = videoio.load_base64("video/jpeg", data)

    # Only the first ``num_frames_limit`` frames may survive truncation.
    assert frames.shape[0] == num_frames_limit
    assert metadata["total_num_frames"] == num_frames_limit
    assert metadata["frames_indices"] == list(range(num_frames_limit))
def test_load_base64_jpeg_no_limit_when_num_frames_negative():
    """When num_frames is -1, all frames should be loaded without truncation."""
    sent_frames = 10
    data = ",".join(_make_jpeg_b64_frames(sent_frames))

    videoio = VideoMediaIO(ImageMediaIO(), num_frames=-1)
    frames, metadata = videoio.load_base64("video/jpeg", data)

    # Every sent frame must come back untouched.
    assert frames.shape[0] == sent_frames
    assert metadata["total_num_frames"] == sent_frames
    assert metadata["frames_indices"] == list(range(sent_frames))
def test_load_base64_jpeg_raises_on_zero_num_frames():
    """num_frames=0 is invalid and should raise ValueError."""
    data = ",".join(_make_jpeg_b64_frames(3))
    videoio = VideoMediaIO(ImageMediaIO(), num_frames=0)
    # Zero is neither a positive limit nor the "no limit" sentinel (-1).
    with pytest.raises(ValueError, match="num_frames must be greater than 0 or -1"):
        videoio.load_base64("video/jpeg", data)

View File

@@ -0,0 +1,196 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from tests.reasoning.utils import run_reasoning_extraction
from vllm.reasoning import ReasoningParser, ReasoningParserManager
# Using mistral tokenizer as a generic mock since the actual model is not on HF
from vllm.tokenizers.registry import get_tokenizer
parser_name = "gemma4"
@pytest.fixture(scope="module")
def generic_tokenizer():
    # Module scope: load the (slow) real tokenizer once for all test cases.
    return get_tokenizer("google/gemma-4-E2B-it")
INVALID_SIMPLE_NONSTREAMING = {
"output": "This is a reasoning section<channel|>This is the rest",
"reasoning": "This is a reasoning section",
"content": "This is the rest",
"is_reasoning_end": True,
}
INVALID_SIMPLE_STREAMING = {
"output": "This is a reasoning section<channel|>This is the rest",
"reasoning": None,
"content": "This is a reasoning sectionThis is the rest",
"is_reasoning_end": True,
}
INVALID_COMPLETE_NONSTREAMING = {
"output": "This is a reasoning section<channel|>",
"reasoning": "This is a reasoning section",
"content": None,
"is_reasoning_end": True,
}
INVALID_COMPLETE_STREAMING = {
"output": "This is a reasoning section<channel|>",
"reasoning": None,
"content": "This is a reasoning section",
"is_reasoning_end": True,
}
NO_CONTENT = {
"output": "<|channel>This is reasoning",
"reasoning": "This is reasoning",
"content": None,
"is_reasoning_end": False,
}
NO_REASONING = {
"output": "This is content",
"reasoning": None,
"content": "This is content",
"is_reasoning_end": False,
}
REASONING_WITH_CHANNEL = {
"output": "<|channel>This is a reasoning section<channel|>This is the rest",
"reasoning": "This is a reasoning section",
"content": "This is the rest",
"is_reasoning_end": True,
}
COMPLETE_REASONING_WITH_CHANNEL = {
"output": "<|channel>This is a reasoning section<channel|>",
"reasoning": "This is a reasoning section",
"content": None,
"is_reasoning_end": True,
}
MULTIPLE_LINES_WITH_CHANNEL = {
"output": "<|channel>This\nThat<channel|>This is the rest\nThat",
"reasoning": "This\nThat",
"content": "This is the rest\nThat",
"is_reasoning_end": True,
}
CHANNEL_NO_END = {
"output": "<|channel>This is a reasoning section",
"reasoning": "This is a reasoning section",
"content": None,
"is_reasoning_end": False,
}
EMPTY = {
"output": "",
"reasoning": None,
"content": "",
"is_reasoning_end": False,
}
NEW_LINE_NONSTREAMING = {
"output": (
"Before\n<|channel>This is a reasoning section<channel|>\nThis is the rest"
),
"reasoning": "This is a reasoning section",
"content": "\nThis is the rest",
"is_reasoning_end": True,
}
NEW_LINE_STREAMING = {
"output": (
"Before\n<|channel>This is a reasoning section<channel|>\nThis is the rest"
),
"reasoning": "This is a reasoning section",
"content": "Before\n\nThis is the rest",
"is_reasoning_end": True,
}
TEST_CASES = [
pytest.param(False, INVALID_SIMPLE_NONSTREAMING, id="invalid_simple"),
pytest.param(True, INVALID_SIMPLE_STREAMING, id="invalid_simple_streaming"),
pytest.param(False, INVALID_COMPLETE_NONSTREAMING, id="invalid_complete"),
pytest.param(True, INVALID_COMPLETE_STREAMING, id="invalid_complete_streaming"),
pytest.param(False, NO_CONTENT, id="no_content"),
pytest.param(False, NO_REASONING, id="no_reasoning"),
pytest.param(False, REASONING_WITH_CHANNEL, id="reasoning"),
pytest.param(True, REASONING_WITH_CHANNEL, id="reasoning_streaming"),
pytest.param(False, COMPLETE_REASONING_WITH_CHANNEL, id="complete_reasoning"),
pytest.param(
True, COMPLETE_REASONING_WITH_CHANNEL, id="complete_reasoning_streaming"
),
pytest.param(False, MULTIPLE_LINES_WITH_CHANNEL, id="multiple_lines"),
pytest.param(True, MULTIPLE_LINES_WITH_CHANNEL, id="multiple_lines_streaming"),
pytest.param(False, CHANNEL_NO_END, id="no_end"),
pytest.param(True, CHANNEL_NO_END, id="no_end_streaming"),
pytest.param(False, EMPTY, id="empty"),
pytest.param(False, NEW_LINE_NONSTREAMING, id="new_line"),
pytest.param(True, NEW_LINE_STREAMING, id="new_line_streaming"),
]
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_gemma4_reasoning(
    streaming: bool,
    param_dict: dict,
    generic_tokenizer,
):
    """Tokenize the expected output and verify reasoning/content extraction."""
    START, END = "<|channel>", "<channel|>"
    output = param_dict["output"]

    # Resolve the special-token IDs dynamically from the real tokenizer.
    vocab = generic_tokenizer.get_vocab()
    start_id = vocab[START]
    end_id = vocab[END]

    def _encode(text: str) -> list[int]:
        if not text:
            return []
        # Handle both raw transformers tokenizers and vLLM wrappers.
        enc = getattr(generic_tokenizer, "tokenizer", generic_tokenizer)
        try:
            return enc.encode(text, add_special_tokens=False)
        except TypeError:
            return enc.encode(text)

    start_at = output.find(START)
    end_at = output.find(END)

    # Rebuild the token stream the model would have emitted, inserting the
    # special marker IDs at the positions of the marker strings.
    output_tokens: list[int] = []
    if start_at != -1:
        output_tokens += _encode(output[:start_at])
        output_tokens.append(start_id)
        if end_at != -1:
            output_tokens += _encode(output[start_at + len(START) : end_at])
            output_tokens.append(end_id)
            output_tokens += _encode(output[end_at + len(END) :])
        else:
            output_tokens += _encode(output[start_at + len(START) :])
    elif end_at != -1:
        output_tokens += _encode(output[:end_at])
        output_tokens.append(end_id)
        output_tokens += _encode(output[end_at + len(END) :])
    else:
        output_tokens += _encode(output)

    parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
        parser_name
    )(generic_tokenizer)

    # Decode token-by-token so SentencePiece space markers become real spaces.
    output_token_strings = [generic_tokenizer.decode([t]) for t in output_tokens]
    reasoning, content = run_reasoning_extraction(
        parser, output_token_strings, streaming=streaming
    )

    assert reasoning == param_dict["reasoning"]
    assert content == param_dict["content"]

    # Test is_reasoning_end on the raw token-ID stream.
    assert parser.is_reasoning_end(output_tokens) == param_dict["is_reasoning_end"]

View File

@@ -0,0 +1,504 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from typing import Any
from unittest.mock import MagicMock
import pytest
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.tool_parsers.gemma4_tool_parser import (
TOOL_CALL_END,
TOOL_CALL_START,
Gemma4ToolParser,
_parse_gemma4_args,
_parse_gemma4_array,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def mock_tokenizer():
    """A MagicMock tokenizer whose vocab contains the tool-call markers."""
    tok = MagicMock()
    tok.encode.return_value = [1, 2, 3]
    # The parser looks these special tokens up at construction time.
    tok.get_vocab.return_value = {TOOL_CALL_START: 48, TOOL_CALL_END: 49}
    return tok
@pytest.fixture
def parser(mock_tokenizer):
    # Fresh parser per test so streaming state never leaks between tests.
    return Gemma4ToolParser(mock_tokenizer)
@pytest.fixture
def mock_request():
    """A minimal ChatCompletionRequest mock with auto tool choice."""
    req = MagicMock(spec=ChatCompletionRequest)
    req.tools = []
    req.tool_choice = "auto"
    return req
# ---------------------------------------------------------------------------
# Unit tests for _parse_gemma4_args (shared parser logic)
# ---------------------------------------------------------------------------
class TestParseGemma4Args:
    """Unit tests for the shared Gemma4 argument-string parser."""

    def test_empty_string(self):
        assert _parse_gemma4_args("") == {}

    def test_whitespace_only(self):
        assert _parse_gemma4_args("   ") == {}

    def test_single_string_value(self):
        assert _parse_gemma4_args('location:<|"|>Paris<|"|>') == {"location": "Paris"}

    def test_string_value_with_comma(self):
        # Commas inside quoted strings must not split the argument list.
        assert _parse_gemma4_args('location:<|"|>Paris, France<|"|>') == {
            "location": "Paris, France"
        }

    def test_multiple_string_values(self):
        parsed = _parse_gemma4_args(
            'location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>'
        )
        assert parsed == {"location": "San Francisco", "unit": "celsius"}

    def test_integer_value(self):
        assert _parse_gemma4_args("count:42") == {"count": 42}

    def test_float_value(self):
        assert _parse_gemma4_args("score:3.14") == {"score": 3.14}

    def test_boolean_true(self):
        assert _parse_gemma4_args("flag:true") == {"flag": True}

    def test_boolean_false(self):
        assert _parse_gemma4_args("flag:false") == {"flag": False}

    def test_mixed_types(self):
        parsed = _parse_gemma4_args(
            'name:<|"|>test<|"|>,count:42,active:true,score:3.14'
        )
        assert parsed == {"name": "test", "count": 42, "active": True, "score": 3.14}

    def test_nested_object(self):
        assert _parse_gemma4_args('nested:{inner:<|"|>value<|"|>}') == {
            "nested": {"inner": "value"}
        }

    def test_array_of_strings(self):
        assert _parse_gemma4_args('items:[<|"|>a<|"|>,<|"|>b<|"|>]') == {
            "items": ["a", "b"]
        }

    def test_unterminated_string(self):
        """Unterminated strings should take everything after the delimiter."""
        assert _parse_gemma4_args('key:<|"|>unterminated') == {"key": "unterminated"}

    def test_empty_value(self):
        """Key with no value after colon."""
        assert _parse_gemma4_args("key:") == {"key": ""}
class TestParseGemma4Array:
    """Unit tests for the Gemma4 array-literal parser."""

    def test_string_array(self):
        assert _parse_gemma4_array('<|"|>a<|"|>,<|"|>b<|"|>') == ["a", "b"]

    def test_empty_array(self):
        assert _parse_gemma4_array("") == []

    def test_bare_values(self):
        # Unquoted scalars are coerced to int / bool / float.
        assert _parse_gemma4_array("42,true,3.14") == [42, True, 3.14]
# ---------------------------------------------------------------------------
# Non-streaming extraction tests
# ---------------------------------------------------------------------------
class TestExtractToolCalls:
    """Non-streaming extraction tests for Gemma4ToolParser."""

    def test_no_tool_calls(self, parser, mock_request):
        text = "Hello, how can I help you today?"
        result = parser.extract_tool_calls(text, mock_request)
        assert result.tools_called is False
        assert result.tool_calls == []
        assert result.content == text

    def test_single_tool_call(self, parser, mock_request):
        text = '<|tool_call>call:get_weather{location:<|"|>London<|"|>}<tool_call|>'
        result = parser.extract_tool_calls(text, mock_request)
        assert result.tools_called is True
        assert len(result.tool_calls) == 1
        assert result.tool_calls[0].function.name == "get_weather"
        assert json.loads(result.tool_calls[0].function.arguments) == {
            "location": "London"
        }

    def test_multiple_arguments(self, parser, mock_request):
        text = (
            '<|tool_call>call:get_weather{location:<|"|>San Francisco<|"|>'
            ',unit:<|"|>celsius<|"|>}<tool_call|>'
        )
        result = parser.extract_tool_calls(text, mock_request)
        assert result.tools_called is True
        assert len(result.tool_calls) == 1
        assert result.tool_calls[0].function.name == "get_weather"
        assert json.loads(result.tool_calls[0].function.arguments) == {
            "location": "San Francisco",
            "unit": "celsius",
        }

    def test_text_before_tool_call(self, parser, mock_request):
        text = (
            "Let me check the weather for you. "
            '<|tool_call>call:get_weather{location:<|"|>Paris<|"|>}<tool_call|>'
        )
        result = parser.extract_tool_calls(text, mock_request)
        assert result.tools_called is True
        # Leading free text is preserved (trailing whitespace stripped).
        assert result.content == "Let me check the weather for you."
        assert len(result.tool_calls) == 1
        assert result.tool_calls[0].function.name == "get_weather"

    def test_multiple_tool_calls(self, parser, mock_request):
        text = (
            '<|tool_call>call:get_weather{location:<|"|>London<|"|>}<tool_call|>'
            '<|tool_call>call:get_time{location:<|"|>London<|"|>}<tool_call|>'
        )
        result = parser.extract_tool_calls(text, mock_request)
        assert result.tools_called is True
        assert len(result.tool_calls) == 2
        assert result.tool_calls[0].function.name == "get_weather"
        assert result.tool_calls[1].function.name == "get_time"

    def test_nested_arguments(self, parser, mock_request):
        text = (
            '<|tool_call>call:complex_function{nested:{inner:<|"|>value<|"|>}'
            ',list:[<|"|>a<|"|>,<|"|>b<|"|>]}<tool_call|>'
        )
        result = parser.extract_tool_calls(text, mock_request)
        assert result.tools_called is True
        assert len(result.tool_calls) == 1
        assert result.tool_calls[0].function.name == "complex_function"
        assert json.loads(result.tool_calls[0].function.arguments) == {
            "nested": {"inner": "value"},
            "list": ["a", "b"],
        }

    def test_tool_call_with_number_and_boolean(self, parser, mock_request):
        text = (
            "<|tool_call>call:set_status{is_active:true,count:42,score:3.14}"
            "<tool_call|>"
        )
        result = parser.extract_tool_calls(text, mock_request)
        assert result.tools_called is True
        assert len(result.tool_calls) == 1
        assert result.tool_calls[0].function.name == "set_status"
        assert json.loads(result.tool_calls[0].function.arguments) == {
            "is_active": True,
            "count": 42,
            "score": 3.14,
        }

    def test_incomplete_tool_call(self, parser, mock_request):
        text = '<|tool_call>call:get_weather{location:<|"|>London'
        result = parser.extract_tool_calls(text, mock_request)
        # Incomplete — no <tool_call|> end marker, regex won't match.
        assert result.tools_called is False
        assert result.content == text

    def test_hyphenated_function_name(self, parser, mock_request):
        """Ensure function names with hyphens are parsed correctly."""
        text = '<|tool_call>call:get-weather{location:<|"|>London<|"|>}<tool_call|>'
        result = parser.extract_tool_calls(text, mock_request)
        assert result.tools_called is True
        assert result.tool_calls[0].function.name == "get-weather"

    def test_dotted_function_name(self, parser, mock_request):
        """Ensure function names with dots are parsed correctly."""
        text = '<|tool_call>call:weather.get{location:<|"|>London<|"|>}<tool_call|>'
        result = parser.extract_tool_calls(text, mock_request)
        assert result.tools_called is True
        assert result.tool_calls[0].function.name == "weather.get"

    def test_no_arguments(self, parser, mock_request):
        """Tool calls with empty arguments."""
        result = parser.extract_tool_calls(
            "<|tool_call>call:get_status{}<tool_call|>", mock_request
        )
        assert result.tools_called is True
        assert result.tool_calls[0].function.name == "get_status"
        assert json.loads(result.tool_calls[0].function.arguments) == {}
# ---------------------------------------------------------------------------
# Streaming extraction tests
# ---------------------------------------------------------------------------
class TestStreamingExtraction:
"""Tests for the streaming tool call extraction.
These simulate the token-by-token streaming that vLLM performs,
feeding incremental text to extract_tool_calls_streaming() and
verifying that the accumulated argument deltas form valid JSON.
"""
def _simulate_streaming(
    self, parser: Gemma4ToolParser, mock_request: Any, chunks: list[str]
) -> list[tuple[Any, str]]:
    """Feed chunks through the streaming parser and collect results.

    Returns a list of (delta_message, accumulated_text) tuples.
    """
    collected: list[tuple[Any, str]] = []
    prev_text = ""
    prev_ids: list[int] = []
    for piece in chunks:
        text = prev_text + piece
        # Synthetic token IDs: 48 = tool_call start, 49 = end, 0 = other.
        if TOOL_CALL_START in piece:
            new_ids = [48]
        elif TOOL_CALL_END in piece:
            new_ids = [49]
        else:
            new_ids = [0]
        ids = prev_ids + new_ids
        delta = parser.extract_tool_calls_streaming(
            previous_text=prev_text,
            current_text=text,
            delta_text=piece,
            previous_token_ids=tuple(prev_ids),
            current_token_ids=tuple(ids),
            delta_token_ids=tuple(new_ids),
            request=mock_request,
        )
        collected.append((delta, text))
        prev_text = text
        prev_ids = list(ids)
    return collected
def _collect_arguments(self, results):
"""Collect all argument deltas from streaming results into one string."""
args_text = ""
for delta, _ in results:
if delta and delta.tool_calls:
for tc in delta.tool_calls:
func = tc.function if isinstance(tc.function, dict) else tc.function
if isinstance(func, dict):
arg = func.get("arguments", "")
else:
arg = getattr(func, "arguments", "") or ""
if arg:
args_text += arg
return args_text
def _collect_function_name(self, results):
"""Extract the function name from streaming results."""
for delta, _ in results:
if delta and delta.tool_calls:
for tc in delta.tool_calls:
func = tc.function if isinstance(tc.function, dict) else tc.function
if isinstance(func, dict):
name = func.get("name")
else:
name = getattr(func, "name", None)
if name:
return name
return None
def test_basic_streaming_single_tool(self, parser, mock_request):
"""Simulate the exact streaming scenario from the bug report.
Model generates:
<|tool_call>call:get_weather{location:<|"|>Paris, France<|"|>}<tool_call|>
Expected: arguments should be valid JSON {"location": "Paris, France"}
"""
chunks = [
"<|tool_call>",
"call:get_weather{",
'location:<|"|>Paris',
", France",
'<|"|>}',
"<tool_call|>",
]
results = self._simulate_streaming(parser, mock_request, chunks)
# Verify function name
name = self._collect_function_name(results)
assert name == "get_weather", f"Expected 'get_weather', got '{name}'"
# Verify arguments form valid JSON
args_text = self._collect_arguments(results)
assert args_text, "No arguments were streamed"
parsed_args = json.loads(args_text)
assert parsed_args == {"location": "Paris, France"}
def test_streaming_multi_arg(self, parser, mock_request):
"""Streaming with multiple arguments."""
chunks = [
"<|tool_call>",
"call:get_weather{",
'location:<|"|>Tokyo<|"|>,',
'unit:<|"|>celsius<|"|>}',
"<tool_call|>",
]
results = self._simulate_streaming(parser, mock_request, chunks)
name = self._collect_function_name(results)
assert name == "get_weather"
args_text = self._collect_arguments(results)
assert args_text
parsed_args = json.loads(args_text)
assert parsed_args == {"location": "Tokyo", "unit": "celsius"}
def test_streaming_no_extra_brace(self, parser, mock_request):
"""Verify the closing } is NOT leaked into arguments (Bug #2)."""
chunks = [
"<|tool_call>",
"call:get_weather{",
'location:<|"|>London<|"|>}',
"<tool_call|>",
]
results = self._simulate_streaming(parser, mock_request, chunks)
args_text = self._collect_arguments(results)
assert args_text
# The args text must be valid JSON (no extra })
parsed = json.loads(args_text)
assert parsed == {"location": "London"}
# Specifically assert no double-brace
assert args_text.count("}") <= 1, (
f"Arguments contain extra closing brace: {args_text!r}"
)
def test_streaming_no_unquoted_keys(self, parser, mock_request):
"""Verify keys are properly quoted in JSON (Bug #1)."""
chunks = [
"<|tool_call>",
"call:get_weather{",
'location:<|"|>Paris<|"|>}',
"<tool_call|>",
]
results = self._simulate_streaming(parser, mock_request, chunks)
args_text = self._collect_arguments(results)
# Must start with { and contain quoted key
assert args_text.lstrip().startswith("{"), (
f"Arguments don't start with '{{': {args_text!r}"
)
assert '"location"' in args_text, (
f"Key 'location' not properly quoted: {args_text!r}"
)
def test_streaming_name_no_call_prefix(self, parser, mock_request):
"""Verify function name has no 'call:' prefix."""
chunks = [
"<|tool_call>",
"call:get_weather{",
'location:<|"|>Paris<|"|>}',
"<tool_call|>",
]
results = self._simulate_streaming(parser, mock_request, chunks)
name = self._collect_function_name(results)
assert name == "get_weather"
assert not name.startswith("call:"), f"Name has 'call:' prefix: {name!r}"
def test_streaming_text_before_tool_call(self, parser, mock_request):
"""Text before tool call should be emitted as content."""
chunks = [
"Let me check ",
"the weather. ",
"<|tool_call>",
"call:get_weather{",
'location:<|"|>London<|"|>}',
"<tool_call|>",
]
results = self._simulate_streaming(parser, mock_request, chunks)
# First chunks should be content
content_parts = []
for delta, _ in results:
if delta and delta.content:
content_parts.append(delta.content)
assert "".join(content_parts).strip().startswith("Let me check")
def test_streaming_numeric_args(self, parser, mock_request):
"""Streaming with numeric and boolean argument values."""
chunks = [
"<|tool_call>",
"call:set_config{",
"count:42,",
"active:true}",
"<tool_call|>",
]
results = self._simulate_streaming(parser, mock_request, chunks)
args_text = self._collect_arguments(results)
if args_text:
parsed_args = json.loads(args_text)
assert parsed_args["count"] == 42
assert parsed_args["active"] is True
def test_streaming_empty_args(self, parser, mock_request):
"""Tool call with no arguments."""
chunks = [
"<|tool_call>",
"call:get_status{}",
"<tool_call|>",
]
results = self._simulate_streaming(parser, mock_request, chunks)
name = self._collect_function_name(results)
assert name == "get_status"

View File

@@ -42,6 +42,7 @@ from vllm.v1.attention.backends.mla.flashmla_sparse import (
FlashMLASparseBackend,
triton_convert_req_index_to_global_index,
)
from vllm.v1.attention.backends.mla.indexer import split_indexer_prefill_chunks
from vllm.v1.attention.backends.utils import split_prefill_chunks
from vllm.v1.attention.ops import flashmla
@@ -716,6 +717,81 @@ def test_split_prefill_chunks(seq_lens, max_buf, expected):
assert out == expected
# (seq_lens, query_lens, workspace_size, max_logits_bytes, expected) tuples.
_INDEXER_CHUNK_CASES = [
    # Logits constraint triggers split (M*N exceeds budget).
    # req0: M=10, N=100 -> 1000 elems (4000 bytes) - fits in 5000
    # req1: adding M=10, N=100 -> new_M=20, new_N=200 -> 4000 elems > 1250
    (
        torch.tensor([100, 100, 100]),
        torch.tensor([10, 10, 10]),
        1000,  # workspace allows all
        5000,  # 1250 float32 elems -> forces split
        [
            (slice(0, 1), slice(0, 10)),
            (slice(1, 2), slice(0, 10)),
            (slice(2, 3), slice(0, 10)),
        ],
    ),
    # Both constraints satisfied - all fit in one chunk.
    (
        torch.tensor([10, 10, 10]),
        torch.tensor([5, 5, 5]),
        100,
        10000,  # 2500 elems, M*N = 15*30 = 450 < 2500
        [(slice(0, 3), slice(0, 15))],
    ),
    # Workspace constraint triggers first.
    (
        torch.tensor([50, 50, 50]),
        torch.tensor([1, 1, 1]),
        50,  # workspace only fits one at a time
        1000000,  # logits budget is huge
        [
            (slice(0, 1), slice(0, 1)),
            (slice(1, 2), slice(0, 1)),
            (slice(2, 3), slice(0, 1)),
        ],
    ),
    # Greedy filling: first two fit, third doesn't.
    # req0: M=5, N=10 -> 50 elems
    # req0+1: M=10, N=20 -> 200 elems <= 250
    # req0+1+2: M=15, N=30 -> 450 elems > 250
    (
        torch.tensor([10, 10, 10]),
        torch.tensor([5, 5, 5]),
        100,
        1000,  # 250 elems
        [(slice(0, 2), slice(0, 10)), (slice(2, 3), slice(0, 5))],
    ),
]


@pytest.mark.parametrize(
    "seq_lens,query_lens,workspace_size,max_logits_bytes,expected",
    _INDEXER_CHUNK_CASES,
)
def test_split_indexer_prefill_chunks(
    seq_lens, query_lens, workspace_size, max_logits_bytes, expected
):
    """Chunking honours both the workspace and the logits-size budgets."""
    chunks = split_indexer_prefill_chunks(
        seq_lens, query_lens, workspace_size, max_logits_bytes
    )
    assert chunks == expected
def test_split_indexer_prefill_chunks_single_request_overflow():
    """Test that single request exceeding budget is sub-chunked on query dim."""
    seq_lens = torch.tensor([1000, 50])
    query_lens = torch.tensor([100, 5])
    result = split_indexer_prefill_chunks(seq_lens, query_lens, 2000, 1000)
    # max_logits_elems = 250; with N=1000, max_q = 1, so request 0 is split
    # into 100 single-query sub-chunks.
    want = [(slice(0, 1), slice(q, q + 1)) for q in range(100)]
    # Request 1: M=5, N=50 -> 250 elems, fits the budget whole.
    want.append((slice(1, 2), slice(0, 5)))
    assert result == want
def test_triton_convert_returns_valid_counts():
"""Test that return_valid_counts correctly counts non-negative indices."""
device = torch.device("cuda")

View File

@@ -20,7 +20,7 @@ def server():
"--reasoning-parser",
"qwen3",
"--reasoning-config",
'{"think_start_str": "<think>", "think_end_str": "</think>"}',
'{"reasoning_start_str": "<think>", "reasoning_end_str": "</think>"}',
"--max-model-len",
"2048",
"--enforce-eager",

View File

@@ -103,8 +103,8 @@ class LogitsProcsRequestParams:
class MockReasoningConfig:
"""Mock reasoning config for testing ThinkingTokenBudgetLogitsProcessor."""
think_start_token_ids = [THINK_START_TOKEN_ID]
think_end_token_ids = [THINK_END_TOKEN_ID]
reasoning_start_token_ids = [THINK_START_TOKEN_ID]
reasoning_end_token_ids = [THINK_END_TOKEN_ID]
def _generate_fake_sampling_metadata(
@@ -491,7 +491,7 @@ def _thinking_budget_validate(
# Find if thinking has started in output tokens
thinking_started = False
start_tokens = tb_processor.think_start_token_ids
start_tokens = tb_processor.reasoning_start_token_ids
if len(start_tokens) > 0:
for i in range(len(output_tokens) - len(start_tokens) + 1):
@@ -518,7 +518,7 @@ def _thinking_budget_validate(
)
# Validate that only end tokens are allowed
end_tokens = tb_processor.think_end_token_ids
end_tokens = tb_processor.reasoning_end_token_ids
if len(end_tokens) > 0:
expected_end_token_id = end_tokens[
min(state["end_count"], len(end_tokens) - 1)

View File

View File

@@ -0,0 +1,193 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Integration tests for SimpleCPUOffloadConnector with real models."""
import time
import pytest
from vllm import LLM, SamplingParams, TokensPrompt
from vllm.config import KVTransferConfig
from vllm.platforms import current_platform
if not current_platform.is_cuda():
pytest.skip("Requires CUDA", allow_module_level=True)
# Small models for default CI / local runs (accuracy only).
SMALL_MODELS = [
"meta-llama/Llama-3.2-1B-Instruct",
"google/gemma-3-1b-it",
]
# Large models for optional perf runs only (slow to load and execute).
PERF_MODELS = [
"meta-llama/Llama-3.1-8B",
"openai/gpt-oss-20b",
]
def _make_llm(model: str, lazy: bool, cpu_bytes_to_use: int) -> LLM:
    """Build an LLM wired up to the SimpleCPUOffloadConnector.

    `lazy` selects lazy vs eager offload; `cpu_bytes_to_use` is the CPU
    cache capacity in bytes.
    """
    connector_cfg = KVTransferConfig(
        kv_connector="SimpleCPUOffloadConnector",
        kv_role="kv_both",
        kv_connector_extra_config={
            "cpu_bytes_to_use": cpu_bytes_to_use,
            "lazy_offload": lazy,
        },
    )
    return LLM(
        model=model,
        kv_cache_memory_bytes=40 << 30,  # 40 GiB
        disable_hybrid_kv_cache_manager=False,
        enable_prefix_caching=True,
        kv_transfer_config=connector_cfg,
    )
def _flush_gpu_cache(llm: LLM, sampling_params: SamplingParams, seed: int = 0):
    """Generate enough filler requests to allocate the entire GPU KV cache.

    This pushes all prior blocks through the free queue so that the lazy
    cursor offloads them to CPU before they are evicted.

    `seed` makes the filler token ids unique per call so successive flushes
    never hit the prefix cache.
    """
    cache_config = llm.llm_engine.vllm_config.cache_config
    num_gpu_blocks = cache_config.num_gpu_blocks
    block_size = cache_config.block_size
    # Use 1.5x GPU capacity to give the lazy cursor enough scheduling steps
    # to walk past all target blocks near the tail of the free queue.
    # (Fix: the comment previously said 1.2x while the code used 1.5x.)
    total_tokens_needed = int(num_gpu_blocks * block_size * 1.5)
    # Use token-id prompts so each filler is unique (no prefix sharing).
    # Split into multiple requests to stay under max_model_len.
    max_tokens_per_req = 4096
    num_fillers = (total_tokens_needed + max_tokens_per_req - 1) // max_tokens_per_req
    batch_size = 10
    for i in range(0, num_fillers, batch_size):
        batch_end = min(i + batch_size, num_fillers)
        filler_prompts = []
        for j in range(i, batch_end):
            # Every token of a filler prompt carries the same unique id.
            ids = [seed * num_fillers + j + 1] * max_tokens_per_req
            filler_prompts.append(TokensPrompt(prompt_token_ids=ids))
        llm.generate(filler_prompts, sampling_params, use_tqdm=False)
def _accuracy_test(llm: LLM, lazy: bool = False):
    """Verify that CPU-loaded KV produces correct output."""
    params = SamplingParams(max_tokens=1, temperature=0)
    prompt = "hi " * 2000 + "Let's count to ten. One, two, three, "
    # Cold run — populate GPU cache and trigger CPU offload
    baseline = llm.generate(prompt, params, use_tqdm=False)[0].outputs[0].text
    attempts = 10
    matches = 0
    for attempt in range(attempts):
        if lazy:
            _flush_gpu_cache(llm, params, seed=attempt)
        time.sleep(2)  # let engine core drain pending transfers
        # Reset GPU prefix cache so next run must load from CPU
        if not llm.reset_prefix_cache():
            print(f"GPU prefix cache reset failed for iteration {attempt}")
        text = llm.generate(prompt, params, use_tqdm=False)[0].outputs[0].text
        if text == baseline:
            matches += 1
    assert matches >= 0.5 * attempts, (
        f"Accuracy too low: {matches}/{attempts} matched '{baseline}'"
    )
def _latency_test(llm: LLM, lazy: bool = False):
    """Verify CPU cache hit is faster than cold compute."""
    params = SamplingParams(max_tokens=1, seed=42)
    token_ids = [0] * 10001
    wins = 0
    trials = 10
    for trial in range(trials):
        token_ids[0] = trial  # make each trial's prompt unique
        batch = [TokensPrompt(prompt_token_ids=token_ids)]
        # Cold
        time.sleep(2)  # let engine core drain pending transfers
        if not llm.reset_prefix_cache():
            print(f"GPU prefix cache reset failed for iteration {trial}")
        t0 = time.time()
        llm.generate(batch, params, use_tqdm=False)
        cold_elapsed = time.time() - t0
        if lazy:
            _flush_gpu_cache(llm, params, seed=trial)
        else:
            # Eager mode: GPU hit ensures store completion is processed.
            llm.generate(batch, params, use_tqdm=False)
        time.sleep(2)  # let engine core drain pending transfers
        if not llm.reset_prefix_cache():
            print(f"GPU prefix cache reset failed for iteration {trial}")
        # CPU hit
        t0 = time.time()
        llm.generate(batch, params, use_tqdm=False)
        cpu_elapsed = time.time() - t0
        if cpu_elapsed < cold_elapsed:
            wins += 1
    assert wins >= 0.8 * trials, (
        f"CPU hit only faster {wins}/{trials} times"
    )
@pytest.mark.optional
@pytest.mark.slow_test
@pytest.mark.parametrize("model", SMALL_MODELS)
def test_simple_cpu_offload_accuracy(model: str):
    """Store to CPU, reset GPU, load from CPU; verify output matches baseline."""
    engine = _make_llm(model, False, 1 << 30)  # 1GB
    try:
        _accuracy_test(engine, lazy=False)
    finally:
        del engine


@pytest.mark.optional
@pytest.mark.slow_test
@pytest.mark.parametrize("model", PERF_MODELS)
def test_simple_cpu_offload_perf_latency(model: str):
    """CPU KV hit should beat cold prefill on long context (large models only)."""
    engine = _make_llm(model, False, 10 << 30)  # 10GB
    try:
        _latency_test(engine, lazy=False)
    finally:
        del engine


@pytest.mark.optional
@pytest.mark.slow_test
@pytest.mark.parametrize("model", SMALL_MODELS)
def test_simple_cpu_offload_accuracy_lazy(model: str):
    """Lazy mode: flush GPU cache to trigger CPU offload, then verify hit."""
    # CPU must be larger than GPU KV cache to avoid evicting offloaded blocks.
    engine = _make_llm(model, True, 80 << 30)  # 80GB
    try:
        _accuracy_test(engine, lazy=True)
    finally:
        del engine


@pytest.mark.optional
@pytest.mark.slow_test
@pytest.mark.parametrize("model", PERF_MODELS)
def test_simple_cpu_offload_perf_latency_lazy(model: str):
    """Lazy mode: CPU KV hit should beat cold prefill (large models only)."""
    # CPU must be larger than GPU KV cache to avoid evicting offloaded blocks.
    engine = _make_llm(model, True, 80 << 30)  # 80GB
    try:
        _latency_test(engine, lazy=True)
    finally:
        del engine

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Regression tests for the backup token fix in prepare_next_token_ids_padded.
Fixes #38098: with async scheduling, seq_lens_cpu is inflated by unaccepted
draft token placeholders, causing get_token_id() to return -1.
"""
from __future__ import annotations
import numpy as np
import pytest
import torch
class _FakeRequest:
def __init__(self, prompt_tokens: list[int], output_tokens: list[int]):
self.num_prompt_tokens = len(prompt_tokens)
self._prompt = prompt_tokens
self._output = output_tokens
@property
def num_tokens(self) -> int:
return self.num_prompt_tokens + len(self._output)
def get_token_id(self, idx: int) -> int:
if idx < self.num_prompt_tokens:
return self._prompt[idx]
out_idx = idx - self.num_prompt_tokens
if out_idx < len(self._output):
return self._output[out_idx]
return -1 # out of range
class _FakeInputBatch:
def __init__(
self,
req_ids: list[str],
num_tokens_no_spec: list[int],
vocab_size: int = 32000,
):
self.req_ids = req_ids
self.num_reqs = len(req_ids)
self.vocab_size = vocab_size
self.num_tokens_no_spec = np.array(num_tokens_no_spec, dtype=np.int64)
def _make_requests(
    req_ids: list[str],
    prompt_lens: list[int],
    output_lens: list[int],
) -> dict[str, _FakeRequest]:
    """Build fake requests with synthetic prompt/output token ids.

    Prompts are 0..plen-1; outputs are 1000..1000+olen-1, so prompt and
    output tokens are easy to tell apart in assertions.
    """
    return {
        rid: _FakeRequest(list(range(plen)), list(range(1000, 1000 + olen)))
        for rid, plen, olen in zip(req_ids, prompt_lens, output_lens)
    }
def _backup_buggy(
seq_lens_cpu: torch.Tensor,
requests: dict[str, _FakeRequest],
batch: _FakeInputBatch,
) -> list[int]:
"""Old logic: uses seq_lens_cpu directly (may be inflated)."""
n = batch.num_reqs
return [
requests[batch.req_ids[i]].get_token_id(int(seq_lens_cpu[i])) for i in range(n)
]
def _backup_fixed(
requests: dict[str, _FakeRequest],
batch: _FakeInputBatch,
) -> list[int]:
"""New logic: uses num_tokens_no_spec - 1 (last committed token)."""
n = batch.num_reqs
idx = (batch.num_tokens_no_spec[:n] - 1).tolist()
return [requests[batch.req_ids[i]].get_token_id(int(idx[i])) for i in range(n)]
class TestBackupTokenAsyncSpec:
    """Contrast the buggy and fixed backup-token selection.

    The buggy path indexes with (possibly inflated) seq_lens; the fixed
    path indexes with num_tokens_no_spec - 1, the last committed token.
    """

    def test_no_inflation_fixed_returns_last_token(self):
        """Fixed logic returns the final output token when lengths agree."""
        req_ids = ["r0", "r1"]
        requests = _make_requests(req_ids, [3, 3], [2, 2])
        batch = _FakeInputBatch(req_ids, [5, 5])
        # idx = 5-1 = 4 → output[1] = 1001
        assert _backup_fixed(requests, batch) == [1001, 1001]

    def test_inflation_buggy_returns_placeholder(self):
        """Buggy logic indexes past the end once spec tokens inflate seq_lens."""
        req_ids = ["r0", "r1"]
        requests = _make_requests(req_ids, [3, 3], [2, 2])
        batch = _FakeInputBatch(req_ids, [5, 5])
        # inflated by 3 spec tokens → idx 8 is out of range
        seq_lens = torch.tensor([8, 8], dtype=torch.int64)
        assert _backup_buggy(seq_lens, requests, batch) == [-1, -1]

    def test_inflation_fixed_returns_correct_token(self):
        """Fixed logic never reads seq_lens, so inflation cannot affect it."""
        req_ids = ["r0", "r1"]
        requests = _make_requests(req_ids, [3, 3], [2, 2])
        batch = _FakeInputBatch(req_ids, [5, 5])
        assert _backup_fixed(requests, batch) == [1001, 1001]

    def test_mixed_inflation_per_request(self):
        """Different inflation per request: buggy fails on all, fixed succeeds."""
        req_ids = ["r0", "r1", "r2"]
        requests = {
            "r0": _FakeRequest([0, 1], [1000, 1001, 1002]),
            "r1": _FakeRequest([0, 1, 2, 3], [2000]),
            "r2": _FakeRequest([0], [3000, 3001, 3002, 3003]),
        }
        batch = _FakeInputBatch(req_ids, [5, 5, 5])
        seq_lens = torch.tensor([7, 9, 5], dtype=torch.int64)
        assert _backup_buggy(seq_lens, requests, batch) == [-1, -1, -1]
        assert _backup_fixed(requests, batch) == [1002, 2000, 3003]

    def test_prefill_only_request(self):
        """No output tokens yet — backup should be the last prompt token."""
        req_ids = ["r0"]
        requests = {"r0": _FakeRequest([10, 20, 30], [])}
        batch = _FakeInputBatch(req_ids, [3])
        # idx = 3-1 = 2 → prompt[2] = 30
        assert _backup_fixed(requests, batch) == [30]

    @pytest.mark.parametrize("num_spec_tokens", [1, 2, 3, 4, 5])
    def test_various_spec_token_counts(self, num_spec_tokens: int):
        """Fixed logic is insensitive to the spec-token inflation amount.

        Fix: `num_spec_tokens` was previously unused, so all five
        parametrized cases ran identically; it now drives the inflated
        seq_lens fed to the buggy path.
        """
        req_ids = ["r0"]
        requests = {"r0": _FakeRequest([0, 1, 2], list(range(1000, 1005)))}
        batch = _FakeInputBatch(req_ids, [8])
        # Buggy path: 8 committed tokens + num_spec_tokens placeholders
        # pushes the index out of range.
        seq_lens = torch.tensor([8 + num_spec_tokens], dtype=torch.int64)
        assert _backup_buggy(seq_lens, requests, batch) == [-1]
        # Fixed path: idx = 8-1 = 7 → output[4] = 1004, for every count.
        assert _backup_fixed(requests, batch) == [1004]

    def test_buggy_code_was_always_off_by_one(self):
        """The original code used seq_len as index, which is always one past
        the end of output_token_ids even without async inflation."""
        req_ids = ["r0"]
        requests = {"r0": _FakeRequest([0, 1, 2], [1000, 1001])}
        batch = _FakeInputBatch(req_ids, [5])
        # no inflation: seq_len == num_tokens == 5 → idx 5 is out of range
        seq_lens = torch.tensor([5], dtype=torch.int64)
        assert _backup_buggy(seq_lens, requests, batch) == [-1]
        assert _backup_fixed(requests, batch) == [1001]
        # with inflation: still -1, fixed still correct
        seq_lens_inf = torch.tensor([8], dtype=torch.int64)
        assert _backup_buggy(seq_lens_inf, requests, batch) == [-1]
        assert _backup_fixed(requests, batch) == [1001]

View File

@@ -3,6 +3,7 @@
from unittest import mock
import numpy as np
import pytest
import torch
@@ -111,16 +112,14 @@ def test_prepare_next_token_ids():
num_requests = 4
num_speculative_tokens = 4
batch_spec = BatchSpec(
seq_lens=[num_speculative_tokens + 1] * num_requests,
query_lens=[num_speculative_tokens + 1] * num_requests,
)
req_ids = [f"req_{i + 1}" for i in range(num_requests)]
mock_input_batch = mock.MagicMock(spec=InputBatch)
mock_input_batch.req_ids = req_ids
mock_input_batch.num_reqs = num_requests
mock_input_batch.vocab_size = 100
mock_input_batch.num_tokens_no_spec = np.array(
[num_speculative_tokens + 1] * num_requests
)
mock_num_scheduled_tokens = {req_id: 0 for req_id in req_ids}
mock_requests = {}
@@ -165,19 +164,12 @@ def test_prepare_next_token_ids():
assert torch.equal(next_token_ids_from_cpu, expected_next_token_ids_tensor)
common_attn_metadata = create_common_attn_metadata(
batch_spec,
block_size=BLOCK_SIZE,
device=device,
)
expected_valid_sampled_tokens_count = torch.tensor(
[2, 5, 0, 0], dtype=torch.int32, device=device
)
next_token_ids_from_padded, valid_sampled_tokens_count = (
proposer.prepare_next_token_ids_padded(
common_attn_metadata.seq_lens_cpu,
sampled_token_ids_tensor,
mock_requests,
mock_input_batch,

View File

@@ -3,6 +3,7 @@
from unittest import mock
import numpy as np
import pytest
import torch
@@ -132,16 +133,12 @@ def test_prepare_next_token_ids_padded():
device = torch.device(current_platform.device_type)
num_requests = 4
batch_spec = BatchSpec(
seq_lens=[5] * num_requests,
query_lens=[5] * num_requests,
)
req_ids = [f"req_{i + 1}" for i in range(num_requests)]
mock_input_batch = mock.MagicMock(spec=InputBatch)
mock_input_batch.req_ids = req_ids
mock_input_batch.num_reqs = num_requests
mock_input_batch.vocab_size = 100
mock_input_batch.num_tokens_no_spec = np.array([5] * num_requests)
mock_requests = {}
for req_id in req_ids:
@@ -174,12 +171,6 @@ def test_prepare_next_token_ids_padded():
proposer = _create_proposer(num_speculative_tokens=1)
common_attn_metadata = create_common_attn_metadata(
batch_spec,
block_size=16,
device=device,
)
# valid_sampled_tokens_count tracks if token is valid (not -1 and in vocab range)
# It doesn't depend on whether the request is discarded
expected_valid_sampled_tokens_count = torch.tensor(
@@ -187,7 +178,6 @@ def test_prepare_next_token_ids_padded():
)
next_token_ids, valid_sampled_tokens_count = proposer.prepare_next_token_ids_padded(
common_attn_metadata.seq_lens_cpu,
sampled_token_ids,
mock_requests,
mock_input_batch,

View File

@@ -12,7 +12,7 @@ from vllm.tokenizers import cached_tokenizer_from_config
class ReasoningConfig:
"""Configuration for reasoning models.
Set `think_start_str` and `think_end_str` to the strings that delimit
Set `reasoning_start_str` and `reasoning_end_str` to the strings that delimit
the reasoning block (e.g. `"<think>"` and `"</think>"`). The
corresponding token IDs are derived automatically via
`initialize_token_ids` and are not intended to be set directly.
@@ -20,53 +20,55 @@ class ReasoningConfig:
# NOTE: These parameters are temporary, the intent is to derive them
# automatically from the reasoning parser in a future version.
think_start_str: str = "<think>"
reasoning_start_str: str = "<think>"
"""String that indicates the start of reasoning."""
think_end_str: str = "</think>"
reasoning_end_str: str = "</think>"
"""String that indicates the end of reasoning content."""
_think_start_token_ids: list[int] | None = field(
_reasoning_start_token_ids: list[int] | None = field(
default=None, init=False, repr=False
)
"""Private backing field for `think_start_token_ids`. Set by
"""Private backing field for `reasoning_start_token_ids`. Set by
`initialize_token_ids`. Not intended to be configured directly."""
_think_end_token_ids: list[int] | None = field(default=None, init=False, repr=False)
"""Private backing field for `think_end_token_ids`. Set by
_reasoning_end_token_ids: list[int] | None = field(
default=None, init=False, repr=False
)
"""Private backing field for `reasoning_end_token_ids`. Set by
`initialize_token_ids`. Not intended to be configured directly."""
@property
def think_start_token_ids(self) -> list[int] | None:
"""Token IDs derived from `think_start_str`. Set automatically by
def reasoning_start_token_ids(self) -> list[int] | None:
"""Token IDs derived from `reasoning_start_str`. Set automatically by
`initialize_token_ids`. Not intended to be configured directly."""
return self._think_start_token_ids
return self._reasoning_start_token_ids
@property
def think_end_token_ids(self) -> list[int] | None:
"""Token IDs derived from `think_end_str`. Set automatically by
def reasoning_end_token_ids(self) -> list[int] | None:
"""Token IDs derived from `reasoning_end_str`. Set automatically by
`initialize_token_ids`. Not intended to be configured directly."""
return self._think_end_token_ids
return self._reasoning_end_token_ids
def initialize_token_ids(self, model_config: ModelConfig) -> None:
"""Initialize reasoning token IDs from strings using the tokenizer."""
if (
self._think_start_token_ids is not None
and self._think_end_token_ids is not None
self._reasoning_start_token_ids is not None
and self._reasoning_end_token_ids is not None
):
return
tokenizer = cached_tokenizer_from_config(model_config=model_config)
self._think_start_token_ids = tokenizer.encode(
self.think_start_str, add_special_tokens=False
self._reasoning_start_token_ids = tokenizer.encode(
self.reasoning_start_str, add_special_tokens=False
)
self._think_end_token_ids = tokenizer.encode(
self.think_end_str, add_special_tokens=False
self._reasoning_end_token_ids = tokenizer.encode(
self.reasoning_end_str, add_special_tokens=False
)
if not self._think_start_token_ids or not self._think_end_token_ids:
if not self._reasoning_start_token_ids or not self._reasoning_end_token_ids:
raise ValueError(
f"ReasoningConfig: failed to tokenize reasoning strings: "
f"think_start_str='{self.think_start_str}', "
f"think_end_str='{self.think_end_str}'. "
f"reasoning_start_str='{self.reasoning_start_str}', "
f"reasoning_end_str='{self.reasoning_end_str}'. "
"Ensure the strings are valid tokens in the model's vocabulary."
)

View File

@@ -657,7 +657,11 @@ class VllmConfig:
)
if kv_offloading_backend == "native":
self.kv_transfer_config.kv_connector = "OffloadingConnector"
if envs.VLLM_USE_SIMPLE_KV_OFFLOAD:
config_connector = "SimpleCPUOffloadConnector"
else:
config_connector = "OffloadingConnector"
self.kv_transfer_config.kv_connector = config_connector
self.kv_transfer_config.kv_connector_extra_config.update(
{"cpu_bytes_to_use": kv_offloading_size * (1 << 30)}
)

View File

@@ -202,6 +202,7 @@ KVConnectorFactory.register_connector(
"vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector",
"DecodeBenchConnector",
)
KVConnectorFactory.register_connector(
"MooncakeConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector",
@@ -213,3 +214,9 @@ KVConnectorFactory.register_connector(
"vllm.distributed.kv_transfer.kv_connector.v1.flexkv_connector",
"FlexKVConnectorV1",
)
KVConnectorFactory.register_connector(
"SimpleCPUOffloadConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.simple_cpu_offload_connector",
"SimpleCPUOffloadConnector",
)

View File

@@ -0,0 +1,247 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""SimpleCPUOffloadConnector: minimal CPU KV cache offloading."""
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any
import torch
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
SupportsHMA,
)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.simple_kv_offload.manager import (
SimpleCPUOffloadScheduler,
)
from vllm.v1.simple_kv_offload.metadata import (
SimpleCPUOffloadMetadata,
)
from vllm.v1.simple_kv_offload.worker import (
SimpleCPUOffloadWorker,
)
if TYPE_CHECKING:
from vllm.forward_context import ForwardContext
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
logger = init_logger(__name__)
# Default CPU capacity: 8 GB
DEFAULT_CPU_CAPACITY_BYTES = 8 * (1024**3)
class SimpleCPUOffloadConnector(KVConnectorBase_V1, SupportsHMA):
"""CPU KV cache offloading with custom kernel transfers and BlockPool LRU."""
def __init__(
    self,
    vllm_config: VllmConfig,
    role: KVConnectorRole,
    kv_cache_config: "KVCacheConfig | None" = None,
):
    """Create the scheduler- or worker-side handler for this role.

    If prefix caching is disabled, both handlers stay None and the
    connector is effectively a no-op.
    """
    super().__init__(vllm_config, role, kv_cache_config)
    prefix_caching_on = vllm_config.cache_config.enable_prefix_caching
    extra = self._kv_transfer_config.kv_connector_extra_config or {}
    total_cpu_bytes = int(extra.get("cpu_bytes_to_use", DEFAULT_CPU_CAPACITY_BYTES))
    # cpu_bytes_to_use is server-wide for compatibility;
    # cpu_bytes_to_use_per_rank overrides for per-rank capacity.
    world_size = vllm_config.parallel_config.world_size
    per_rank_bytes = total_cpu_bytes // world_size
    if "cpu_bytes_to_use_per_rank" in extra:
        override = int(extra["cpu_bytes_to_use_per_rank"])
        if override != per_rank_bytes:
            logger.warning(
                "cpu_bytes_to_use_per_rank (%.2f GB) != "
                "cpu_bytes_to_use/world_size (%.2f GB). Using per-rank value.",
                override / (1024**3),
                per_rank_bytes / (1024**3),
            )
        per_rank_bytes = override
    lazy_offload = bool(extra.get("lazy_offload", False))
    # Exactly one of these becomes non-None below, depending on `role`.
    self.scheduler_manager: SimpleCPUOffloadScheduler | None = None
    self.worker_handler: SimpleCPUOffloadWorker | None = None
    if not prefix_caching_on:
        logger.warning(
            "Detected prefix caching disabled, disabling CPU offload "
            "since it requires prefix caching."
        )
        return
    logger.info(
        "SimpleCPUOffloadConnector: role=%s, "
        "per_rank=%.2f GB, world_size=%d, mode=%s",
        role.name,
        per_rank_bytes / (1024**3),
        world_size,
        "lazy" if lazy_offload else "eager",
    )
    if role == KVConnectorRole.SCHEDULER:
        self.scheduler_manager = SimpleCPUOffloadScheduler(
            vllm_config,
            kv_cache_config,
            per_rank_bytes,
            lazy_offload=lazy_offload,
        )
    elif role == KVConnectorRole.WORKER:
        self.worker_handler = SimpleCPUOffloadWorker(
            vllm_config, kv_cache_config, per_rank_bytes
        )
# --- Worker-side methods ---
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]) -> None:
if self.worker_handler is not None:
self.worker_handler.register_kv_caches(kv_caches)
def bind_connector_metadata(
self,
connector_metadata: KVConnectorMetadata,
) -> None:
super().bind_connector_metadata(connector_metadata)
if self.worker_handler is not None:
assert isinstance(connector_metadata, SimpleCPUOffloadMetadata)
self.worker_handler.bind_connector_metadata(connector_metadata)
def clear_connector_metadata(self) -> None:
super().clear_connector_metadata()
if self.worker_handler is not None:
self.worker_handler.clear_connector_metadata()
def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata) -> None:
if self.worker_handler is not None:
assert isinstance(kv_connector_metadata, SimpleCPUOffloadMetadata)
self.worker_handler.handle_preemptions(kv_connector_metadata)
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
pass # Launch loads ops in get_finished() after launching model execution
def wait_for_layer_load(self, layer_name: str) -> None:
pass # Always load asynchronously and deferred to get_finished()
def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
**kwargs: Any,
) -> None:
pass # Always save asynchronously and deferred to get_finished()
def wait_for_save(self) -> None:
pass # All stores are driven by get_finished() and no wait needed
def get_finished(
self,
finished_req_ids: set[str],
) -> tuple[set[str] | None, set[str] | None]:
if self.worker_handler is not None:
return self.worker_handler.get_finished(finished_req_ids)
return None, None
def build_connector_worker_meta(self):
    """Delegate worker metadata construction to the worker handler.

    Returns ``None`` on scheduler-side instances.
    """
    handler = self.worker_handler
    return None if handler is None else handler.build_connector_worker_meta()
# --- Scheduler-side methods ---
# NOTE: New API only for SimpleCPUOffloadConnector.
def bind_gpu_block_pool(self, gpu_block_pool: "BlockPool") -> None:
    """Hand the scheduler manager a reference to the GPU block pool.

    No-op on worker-side instances (no scheduler manager).
    """
    manager = self.scheduler_manager
    if manager is None:
        return
    manager.bind_gpu_block_pool(gpu_block_pool)
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int | None, bool]:
if self.scheduler_manager is not None:
return self.scheduler_manager.get_num_new_matched_tokens(
request, num_computed_tokens
)
return 0, False
def update_state_after_alloc(
    self,
    request: "Request",
    blocks: "KVCacheBlocks",
    num_external_tokens: int,
) -> None:
    """Notify the scheduler manager of a block allocation for *request*.

    No-op on worker-side instances.
    """
    manager = self.scheduler_manager
    if manager is None:
        return
    manager.update_state_after_alloc(request, blocks, num_external_tokens)
def build_connector_meta(
    self,
    scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
    """Build per-step connector metadata via the scheduler manager.

    Worker-side instances return an empty metadata object.
    """
    manager = self.scheduler_manager
    if manager is None:
        return SimpleCPUOffloadMetadata()
    return manager.build_connector_meta(scheduler_output)
def update_connector_output(
    self,
    connector_output: KVConnectorOutput,
) -> None:
    """Feed worker-produced connector output back to the scheduler manager.

    No-op on worker-side instances.
    """
    manager = self.scheduler_manager
    if manager is None:
        return
    manager.update_connector_output(connector_output)
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
if self.scheduler_manager is not None:
return self.scheduler_manager.request_finished(request, block_ids)
return False, None
def request_finished_all_groups(
self,
request: "Request",
block_ids: tuple[list[int], ...],
) -> tuple[bool, dict[str, Any] | None]:
if self.scheduler_manager is not None:
return self.scheduler_manager.request_finished_all_groups(
request, block_ids
)
return False, None
# NOTE: New API only for SimpleCPUOffloadConnector.
def has_pending_transfers(self) -> bool:
    """Return True while the scheduler manager still has stores in flight.

    Worker-side instances (no scheduler manager) report False.
    """
    manager = self.scheduler_manager
    return manager is not None and manager.has_pending_stores()
def take_events(self) -> Iterable[KVCacheEvent]:
    """Drain KV-cache events from the scheduler manager.

    Worker-side instances yield no events.
    """
    manager = self.scheduler_manager
    return [] if manager is None else manager.take_events()
def reset_cache(self) -> bool | None:
raise NotImplementedError(
"SimpleCPUOffloadConnector does not support reset_cache(). "
"reset_prefix_cache() requires synchronizing all pending "
"CPU offload transfers before clearing GPU prefix cache blocks, "
"which is not yet implemented."
)

View File

@@ -54,6 +54,7 @@ if TYPE_CHECKING:
VLLM_ZENTORCH_WEIGHT_PREPACK: bool = True
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
VLLM_XLA_CHECK_RECOMPILATION: bool = False
VLLM_SPARSE_INDEXER_MAX_LOGITS_MB: int = 512
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto"
VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True
@@ -842,6 +843,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
),
# Enable SPMD mode for TPU backend.
"VLLM_XLA_USE_SPMD": lambda: bool(int(os.getenv("VLLM_XLA_USE_SPMD", "0"))),
# Maximum size (in MB) for logits tensor in sparse MLA indexer prefill chunks.
# Bounds the [M, N] float32 logits tensor to prevent CUDA OOM.
# Default: 512 MB
"VLLM_SPARSE_INDEXER_MAX_LOGITS_MB": lambda: int(
os.getenv("VLLM_SPARSE_INDEXER_MAX_LOGITS_MB", "512")
),
# If set, the OpenAI API server will stay alive even after the underlying
# AsyncLLMEngine errors and stops serving requests
"VLLM_KEEP_ALIVE_ON_ENGINE_DEATH": lambda: bool(
@@ -1655,6 +1662,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_XPU_ENABLE_XPU_GRAPH": lambda: bool(
int(os.getenv("VLLM_XPU_ENABLE_XPU_GRAPH", "0"))
),
# Enable simple KV offload.
"VLLM_USE_SIMPLE_KV_OFFLOAD": lambda: bool(
int(os.getenv("VLLM_USE_SIMPLE_KV_OFFLOAD", "0"))
),
}

View File

@@ -12,6 +12,7 @@ from .dual_chunk_rope import DualChunkRotaryEmbedding
from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
from .dynamic_ntk_scaling_rope import DynamicNTKScalingRotaryEmbedding
from .fope import FourierRotaryEmbedding
from .gemma4_rope import Gemma4RotaryEmbedding
from .linear_scaling_rope import LinearScalingRotaryEmbedding
from .llama3_rope import Llama3RotaryEmbedding
from .llama4_vision_rope import Llama4VisionRotaryEmbedding
@@ -134,6 +135,17 @@ def get_rope(
is_neox_style,
dtype,
)
elif scaling_type == "proportional":
# Proportional RoPE is used by Gemma4 for global (full) attention.
# Gemma4 uses a sparse/fractional RoPE with cross-mixing between halves.
rotary_emb = Gemma4RotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
)
elif scaling_type == "llama3":
scaling_factor = rope_parameters["factor"]
low_freq_factor = rope_parameters["low_freq_factor"]

View File

@@ -0,0 +1,84 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Gemma4-specific Rotary Positional Embeddings (proportional scaling).
Gemma4 uses "proportional" RoPE which computes inv_freq frequencies scaled
by head_dim (not rotary_dim), and zero-pads for non-rotated dimensions when
partial_rotary_factor < 1. The actual rotation uses standard neox-style
rotate_half, matching HF transformers' apply_rotary_pos_emb.
"""
import torch
from .base import RotaryEmbedding
class Gemma4RotaryEmbedding(RotaryEmbedding):
    """Gemma4 "proportional" rotary embedding.

    Reuses the base class's neox-style rotation kernel but overrides the
    inverse-frequency computation to match HF's
    ``_compute_proportional_rope_parameters``:

    - frequency exponents are divided by ``head_dim`` (not ``rotary_dim``);
    - dimensions that should not rotate get a zero frequency, which makes
      their rotation the identity (cos=1, sin=0).

    With ``partial_rotary_factor == 1.0`` every dimension rotates and this
    reduces to standard RoPE with head_dim-scaled frequencies.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        # Rotating angle pairs, derived from partial_rotary_factor.
        self.rope_angles = rotary_dim // 2
        # Remaining (non-rotating) angle pairs per half.
        self.nope_angles = head_size // 2 - self.rope_angles
        # Hand the base class head_size as its rotary_dim so the cos/sin
        # cache is applied to the FULL head; the zero-padded frequencies
        # from _compute_inv_freq make the extra dims identity rotations.
        super().__init__(
            head_size,
            head_size,  # rotary_dim = head_size (full application)
            max_position_embeddings,
            base,
            is_neox_style,
            dtype,
        )

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Return inverse frequencies in the HF proportional-RoPE form.

        The exponent denominator is ``head_size`` (not ``rotary_dim``),
        and non-rotated dimensions are padded with zeros.
        """
        # HF formula: base ** (arange(0, 2*rope_angles, 2) / head_dim)
        exponents = torch.arange(0, 2 * self.rope_angles, 2, dtype=torch.float)
        inv_freq = 1.0 / (base ** (exponents / self.head_size))
        if self.nope_angles > 0:
            # Zero frequency => cos=1, sin=0 => identity rotation.
            padding = torch.zeros(self.nope_angles, dtype=torch.float)
            inv_freq = torch.cat([inv_freq, padding])
        return inv_freq

    def extra_repr(self) -> str:
        fields = [
            f"head_size={self.head_size}",
            f"rotary_dim={self.rotary_dim}",
            f"rope_angles={self.rope_angles}",
            f"nope_angles={self.nope_angles}",
            f"max_position_embeddings={self.max_position_embeddings}",
            f"base={self.base}",
            f"is_neox_style={self.is_neox_style}",
        ]
        return ", ".join(fields)

View File

@@ -4,6 +4,7 @@
import torch
import vllm.envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
@@ -51,6 +52,14 @@ def sparse_attn_indexer(
((total_seq_lens, head_dim), torch.float8_e4m3fn),
((total_seq_lens, 4), torch.uint8),
)
# Dummy allocation to simulate for peak logits tensor memory during inference.
# FP8 elements so elements == bytes
max_logits_elems = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024
_ = torch.empty(
max_logits_elems, dtype=torch.uint8, device=hidden_states.device
)
return sparse_attn_indexer_fake(
hidden_states,
k_cache_prefix,
@@ -101,13 +110,16 @@ def sparse_attn_indexer(
for chunk in prefill_metadata.chunks:
k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens]
ops.cp_gather_indexer_k_quant_cache(
kv_cache,
k_fp8,
k_scale,
chunk.block_table,
chunk.cu_seq_lens,
)
if not chunk.skip_kv_gather:
ops.cp_gather_indexer_k_quant_cache(
kv_cache,
k_fp8,
k_scale,
chunk.block_table,
chunk.cu_seq_lens,
)
logits = fp8_mqa_logits(
q_fp8[chunk.token_start : chunk.token_end],
(k_fp8, k_scale.view(torch.float32).flatten()),

View File

@@ -14,6 +14,7 @@ from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentio
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
logger = init_logger(__name__)
@@ -57,6 +58,58 @@ class Gemma3TextModelConfig(VerifyAndUpdateConfig):
hf_config.is_causal = not hf_config.use_bidirectional_attention
class Gemma4Config(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Force a single attention backend for Gemma4 variants with
        heterogeneous head dimensions.

        Some Gemma4 variants use different head dimensions for sliding
        window (head_dim) vs full attention (global_head_dim) layers. When
        the larger of the two exceeds 256, FlashAttention rejects those
        layers (head_size <= 256 kernel limit) and vLLM would pick a
        different backend per layer type; that mixed-backend execution
        produces numerical divergence and output corruption. TRITON_ATTN
        has no head-size ceiling, so it is forced for all layers — but only
        when the user has not explicitly chosen a backend.

        TODO: Heterogeneous head_sizes (head_dim != global_head_dim)
        require NixlConnector changes to support per-layer KV transfer
        with different head dimensions for prefill-decode disaggregation.
        """
        hf_text_config = vllm_config.model_config.hf_text_config
        head_dim = getattr(hf_text_config, "head_dim", None)
        global_head_dim = getattr(hf_text_config, "global_head_dim", None)

        # Nothing to do unless both dims are present and actually differ.
        if head_dim is None or global_head_dim is None:
            return
        if head_dim == global_head_dim:
            return
        # If both fit FlashAttention's limit, all layers can share the FA
        # backend and no forcing is needed.
        if max(head_dim, global_head_dim) <= 256:
            return
        # Respect an explicit user backend choice.
        if vllm_config.attention_config.backend is not None:
            return

        from vllm.v1.attention.backends.registry import (
            AttentionBackendEnum,
        )

        vllm_config.attention_config.backend = AttentionBackendEnum.TRITON_ATTN
        logger.info(
            "Gemma4 model has heterogeneous head dimensions "
            "(head_dim=%d, global_head_dim=%d). Forcing TRITON_ATTN "
            "backend to prevent mixed-backend numerical divergence.",
            head_dim,
            global_head_dim,
        )
class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
@@ -668,6 +721,8 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"Ernie4_5_VLMoeForConditionalGeneration": Ernie4_5_VLMoeForConditionalGenerationConfig, # noqa: E501
"FalconMambaForCausalLM": MambaModelConfig,
"Gemma3TextModel": Gemma3TextModelConfig,
"Gemma4ForCausalLM": Gemma4Config,
"Gemma4ForConditionalGeneration": Gemma4Config,
"GptOssForCausalLM": GptOssForCausalLMConfig,
"GteModel": SnowflakeGteNewModelConfig,
"GteNewForSequenceClassification": GteNewModelConfig,

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,292 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
"""Gemma4 output parsing utilities for offline inference.
Standalone functions that parse decoded model text to extract structured
thinking content and tool calls from Gemma4 models. These are pure-Python
utilities with zero heavy dependencies — they work on raw decoded strings
from any inference backend (vLLM, HuggingFace, TGI, etc.).
Usage with vLLM offline inference::
from vllm import LLM, SamplingParams
from vllm.model_executor.models.gemma4_utils import (
parse_output,
parse_tool_calls,
)
llm = LLM(model="google/gemma-4-it")
outputs = llm.generate(prompt, SamplingParams(...))
text = tokenizer.decode(outputs[0].outputs[0].token_ids, skip_special_tokens=False)
# Extract thinking / answer (works with or without enable_thinking)
result = parse_output(text)
print(result["thinking"]) # chain-of-thought or None
print(result["answer"]) # final answer
# Extract tool calls
tool_calls = parse_tool_calls(text)
for tc in tool_calls:
print(f"{tc['name']}({tc['arguments']})")
Ported from ``transformers.models.gemma4.utils_gemma4`` so that vLLM users
do not need a transformers dependency for output parsing.
"""
import json
import regex as re
# ---- Thinking Mode Utility ----
# Thinking delimiter tokens as they appear in decoded text.
# Gemma4 uses <|channel> (start) and <channel|> (end) as thinking delimiters.
_THINKING_START_TAG = "<|channel>"
_THINKING_END_TAG = "<channel|>"
# Sentinel tokens that may appear in decoded output.
_TURN_END_TAG = "<turn|>"
def parse_thinking_output(text: str) -> dict[str, str | None]:
"""Parse decoded Gemma4 model output.
Use this on **all** Gemma4 output regardless of whether thinking mode
was enabled. It handles three cases:
1. **Thinking enabled, tags present** — splits on ``<|channel>``/
``<channel|>`` to separate chain-of-thought from the answer and
strips the ``thought\\n`` role label.
2. **Thinking disabled, spurious label** — strips the bare
``thought\\n`` prefix that some Gemma4 models emit even
without thinking mode.
3. **Clean output** — returns the text unchanged.
The answer text is always cleaned of trailing sentinel tokens
(``<turn|>``, ``<eos>``, etc.).
Args:
text: Decoded model output text (from ``tokenizer.decode(...)``).
Returns:
A dict with keys:
- ``"thinking"``: The chain-of-thought text, or ``None`` if no
thinking delimiters were found.
- ``"answer"``: The final answer text.
Example::
>>> from vllm.model_executor.models.gemma4_utils import parse_thinking_output
>>> output_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
>>> result = parse_thinking_output(output_text)
>>> print(result["thinking"]) # chain-of-thought reasoning or None
>>> print(result["answer"]) # final answer
"""
if _THINKING_END_TAG in text:
parts = text.split(_THINKING_END_TAG, 1)
thinking_block = parts[0]
answer = _clean_answer(parts[1])
# Extract thinking content: strip the start tag if present
if _THINKING_START_TAG in thinking_block:
thinking = thinking_block.split(_THINKING_START_TAG, 1)[1]
else:
thinking = thinking_block
# Strip the "thought\n" channel role label the model emits inside
# <|channel>thought\n...<channel|> (analogous to "user\n" in
# <|turn>user\n...<turn|>).
thinking = _strip_thought_label(thinking.strip())
thinking = thinking.strip()
return {"thinking": thinking, "answer": answer}
# No thinking delimiters found.
# Strip spurious "thought\n" role label that some Gemma4 models sometimes
# emit even without thinking mode enabled, then clean trailing tokens.
answer = _strip_thought_label(text)
answer = _clean_answer(answer)
return {"thinking": None, "answer": answer}
def _strip_thought_label(text: str) -> str:
"""Strip the spurious ``thought\\n`` label from the start of text.
Only strips when ``thought`` appears as the very first word followed by
a newline — preserving the word ``thought`` in any other context.
"""
if text.startswith("thought\n"):
return text[len("thought\n") :]
return text
def _clean_answer(text: str) -> str:
"""Clean trailing sentinel tokens from the answer text.
Strips ``<turn|>``, ``<eos>``, and surrounding whitespace that the
model appends at the end of its response.
"""
text = text.strip()
# Strip trailing <turn|> (Gemma4 turn-end marker)
if text.endswith(_TURN_END_TAG):
text = text[: -len(_TURN_END_TAG)].rstrip()
# Strip trailing <eos> if present
if text.endswith("<eos>"):
text = text[:-5].rstrip()
return text
# ---- Tool Call Parsing Utility ----
#
# NOTE: For the OpenAI-compatible API server tool parser (streaming +
# non-streaming), see vllm/tool_parsers/gemma4_tool_parser.py.
# This module provides offline inference utilities for direct user import.
# Tool call delimiter tokens as they appear in decoded text.
# Standard format: <|tool_call>call:name{args}<tool_call|>
_TOOL_CALL_START_TAG = "<|tool_call>"
_TOOL_CALL_END_TAG = "<tool_call|>"
_TOOL_RESPONSE_START_TAG = "<|tool_response>"
# Gemma4 escape token as it appears in decoded text.
_ESCAPE_TOKEN = '<|"|>'
def _parse_tool_arguments(args_str: str) -> dict[str, str]:
"""Parse tool call arguments from the Gemma4 compact format.
Handles the ``key:<|"|>value<|"|>`` format used by Gemma4, with fallback
to heuristic key-value extraction. Also tolerates the slightly different
``key: "value"`` format (space + plain quotes) that some chat templates
produce.
Args:
args_str: Raw argument string from inside ``call:name{...}``.
Returns:
Dictionary of argument name → value.
"""
if not args_str or not args_str.strip():
return {}
# Replace Gemma4 escape tokens with standard quotes.
cleaned = args_str.replace(_ESCAPE_TOKEN, '"')
# Try JSON parsing first (handles nested values, arrays, etc.).
try:
parsed = json.loads("{" + cleaned + "}")
# Ensure all values are strings for consistency.
return {k: str(v) if not isinstance(v, str) else v for k, v in parsed.items()}
except (json.JSONDecodeError, ValueError):
pass
# Fallback: extract key:"value" pairs (allow optional space after colon).
arguments = {}
for key, value in re.findall(r'(\w+):\s*"([^"]*)"', cleaned):
arguments[key] = value
if not arguments:
# Last resort: extract key:value pairs (unquoted).
for key, value in re.findall(r"(\w+):\s*([^,}]+)", args_str):
arguments[key] = value.strip().strip('"').replace(_ESCAPE_TOKEN, "")
return arguments
def parse_tool_calls(text: str, *, strict: bool = False) -> list[dict]:
    """Parse tool calls from decoded Gemma4 model output.

    Tiered strategy for known Gemma4 output variations:

    1. Standard: ``<|tool_call>call:name{args}<tool_call|>`` (some models
       terminate with ``<turn|>`` instead of ``<tool_call|>``).
    2. Fallback (``strict=False`` only): bare ``call:name{args}`` and
       ``<call>name{args}`` patterns from fragmented tokens.

    Args:
        text: Decoded model output (``skip_special_tokens=False``).
        strict: When True, only the standard format is matched.

    Returns:
        List of ``{"name": str, "arguments": dict}`` entries.
    """

    def collect(pattern: str) -> list[dict]:
        # Gather {name, arguments} for every match of *pattern* in text.
        found = []
        for match in re.finditer(pattern, text, re.DOTALL):
            found.append(
                {
                    "name": match.group(1),
                    "arguments": _parse_tool_arguments(match.group(2)),
                }
            )
        return found

    # Tier 1: standard special-token format.
    results = collect(r"<\|tool_call\>call:(\w+)\{(.*?)\}(?:<tool_call\|>|<turn\|>)")
    if results or strict:
        return results

    # Tier 2: fallback for known Gemma4 output variations
    # (<call>name{args}, call:name{args}, bare call:name{args}<eos>).
    return collect(r"(?:<call>|(?:^|\s)call:)(\w+)\{(.*?)\}")
def has_tool_response_tag(text: str) -> bool:
    """Check whether model output ends with ``<|tool_response>``.

    Some Gemma4 models emit ``<eos>`` instead of ``<|tool_response>``
    after a tool call; callers can use this to decide whether to inject
    ``<|tool_response>`` into the next prompt themselves.

    Args:
        text: Decoded model output text.

    Returns:
        True when the (whitespace-stripped) output ends with the tag.
    """
    return text.rstrip().endswith("<|tool_response>")

View File

@@ -111,6 +111,7 @@ _TEXT_GENERATION_MODELS = {
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
"Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
"Gemma4ForCausalLM": ("gemma4", "Gemma4ForCausalLM"),
"Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
"Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
@@ -381,6 +382,7 @@ _MULTIMODAL_MODELS = {
"gemma3n_mm",
"Gemma3nForConditionalGeneration",
),
"Gemma4ForConditionalGeneration": ("gemma4_mm", "Gemma4ForConditionalGeneration"),
"GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),

View File

@@ -233,8 +233,15 @@ class AutoWeightsLoader:
):
"""
Add tensor names that are not in the model params that may be in the
safetensors, e.g., batch normalization stats.
safetensors, e.g., batch normalization stats and registered buffers.
"""
# Add persistent registered buffers.
# Non-persistent buffers are excluded, matching PyTorch state_dict().
non_persistent = getattr(module, "_non_persistent_buffers_set", set())
for buf_name, buf in module.named_buffers(recurse=False):
if buf_name not in child_params and buf_name not in non_persistent:
child_params[buf_name] = buf
if isinstance(
module,
(

View File

@@ -80,8 +80,15 @@ class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]):
"image/jpeg",
)
if self.num_frames > 0:
frame_parts = data.split(",", self.num_frames)[: self.num_frames]
elif self.num_frames == 0:
raise ValueError("num_frames must be greater than 0 or -1")
else:
frame_parts = data.split(",")
frames = np.stack(
[np.asarray(load_frame(frame_data)) for frame_data in data.split(",")]
[np.asarray(load_frame(frame_data)) for frame_data in frame_parts]
)
total = int(frames.shape[0])
fps = float(self.kwargs.get("fps", 1))

View File

@@ -32,6 +32,10 @@ _REASONING_PARSERS_TO_REGISTER = {
"ernie45_reasoning_parser",
"Ernie45ReasoningParser",
),
"gemma4": (
"gemma4_reasoning_parser",
"Gemma4ReasoningParser",
),
"glm45": (
"deepseek_v3_reasoning_parser",
"DeepSeekV3ReasoningWithThinkingParser",

View File

@@ -0,0 +1,193 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import TYPE_CHECKING
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
from vllm.tokenizers import TokenizerLike
if TYPE_CHECKING:
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
# Role label that Gemma4 emits at the start of the thinking channel.
# The model generates: <|channel>thought\n...reasoning...<channel|>
# This prefix must be stripped to expose only the actual reasoning content.
_THOUGHT_PREFIX = "thought\n"
class Gemma4ReasoningParser(BaseThinkingReasoningParser):
    """
    Reasoning parser for Google Gemma4 thinking models.

    Gemma4 uses <|channel>...<channel|> tokens to delimit reasoning/thinking
    content within its output. Thinking mode is activated by passing
    ``enable_thinking=True`` in the chat template kwargs, which injects a
    system turn containing <|think|> (token 98) to trigger chain-of-thought
    reasoning.

    Output pattern when thinking is enabled::

        <|channel>thought
        ...chain of thought reasoning...<channel|>
        Final answer text here.

    The ``thought\\n`` role label inside the channel delimiters is a
    structural artefact (analogous to ``user\\n`` in ``<|turn>user\\n...``).
    This parser strips it so that downstream consumers see only the
    actual reasoning text, consistent with the offline parser
    (``vllm.reasoning.gemma4_utils._strip_thought_label``).
    """

    def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        # Instance state for streaming prefix stripping.
        # Tracks only the reasoning text received from the base parser,
        # independent of current_text (which may contain pre-reasoning
        # content and lacks special token text due to
        # skip_special_tokens=True).
        # _reasoning_text: all reasoning deltas seen so far, concatenated.
        self._reasoning_text: str = ""
        # _prefix_stripped: True once the "thought\n" label has been fully
        # consumed (or ruled out) — later deltas pass through untouched.
        self._prefix_stripped: bool = False

    @property
    def start_token(self) -> str:
        """The token that starts reasoning content."""
        return "<|channel>"

    @property
    def end_token(self) -> str:
        """The token that ends reasoning content."""
        return "<channel|>"

    # ------------------------------------------------------------------
    # Non-streaming path
    # ------------------------------------------------------------------
    def extract_reasoning(
        self,
        model_output: str,
        request: "ChatCompletionRequest | ResponsesRequest",
    ) -> tuple[str | None, str | None]:
        """Extract reasoning, stripping the ``thought\\n`` role label."""
        if self.start_token not in model_output and self.end_token not in model_output:
            # Default to content history if no tags are present
            # (or if they were stripped)
            return None, model_output
        reasoning, content = super().extract_reasoning(model_output, request)
        if reasoning is not None:
            reasoning = _strip_thought_label(reasoning)
        return reasoning, content

    # ------------------------------------------------------------------
    # Streaming path
    # ------------------------------------------------------------------
    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Extract streaming reasoning, stripping ``thought\\n`` from the
        first reasoning delta(s).

        The ``thought\\n`` prefix may arrive as a single delta or split
        across multiple deltas (e.g. ``"thought"`` then ``"\\n"``). We
        buffer early reasoning tokens until we can determine whether the
        prefix is present, then emit the buffered content minus the
        prefix.

        Unlike the previous implementation which reconstructed accumulated
        reasoning from ``current_text``, this uses instance state
        (``_reasoning_text``) to track only the reasoning content returned
        by the base parser. This is necessary because
        ``skip_special_tokens=True`` (the vLLM default) causes the
        ``<|channel>`` delimiter to be invisible in ``current_text``,
        making it impossible to separate pre-reasoning content from
        reasoning content via string matching.
        """
        # Let the base parser split the delta into reasoning vs content.
        result = super().extract_reasoning_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
        )
        if result is None:
            return None
        if result.reasoning is None:
            # Pure content delta — nothing to strip.
            return result
        # Accumulate ONLY the reasoning text from base parser results.
        # This is immune to pre-reasoning content pollution.
        self._reasoning_text += result.reasoning
        # Once the prefix has been handled, all subsequent reasoning
        # deltas pass through unchanged.
        if self._prefix_stripped:
            return result
        # ---- Prefix stripping logic ----
        # Case 1: We've accumulated enough to confirm the prefix is
        # present. Strip it and pass through the remainder.
        if self._reasoning_text.startswith(_THOUGHT_PREFIX):
            prefix_len = len(_THOUGHT_PREFIX)
            # How much reasoning was accumulated before this delta?
            prev_reasoning_len = len(self._reasoning_text) - len(result.reasoning)
            if prev_reasoning_len >= prefix_len:
                # Prefix was already consumed by prior deltas; this
                # delta is entirely real content — pass through.
                self._prefix_stripped = True
                return result
            else:
                # Part or all of the prefix is in this delta.
                chars_of_prefix_in_delta = prefix_len - prev_reasoning_len
                stripped = result.reasoning[chars_of_prefix_in_delta:]
                if stripped:
                    self._prefix_stripped = True
                    result.reasoning = stripped
                    return result
                else:
                    # This entire delta was prefix — suppress it.
                    # Don't set _prefix_stripped yet; there may be more
                    # prefix chars to consume in the next delta.
                    if len(self._reasoning_text) >= prefix_len:
                        self._prefix_stripped = True
                    return None
        # Case 2: Accumulated text is a strict prefix of
        # _THOUGHT_PREFIX (e.g. we've only seen "thou" so far).
        # Buffer by suppressing — we can't yet tell if this will
        # become the full prefix or diverge.
        if _THOUGHT_PREFIX.startswith(self._reasoning_text):
            return None
        # Case 3: Accumulated text doesn't match the thought prefix
        # at all. This means prior deltas were buffered (suppressed
        # by Case 2) but the text diverged. Re-emit the full
        # accumulated text to avoid data loss.
        self._prefix_stripped = True
        result.reasoning = self._reasoning_text
        return result
def _strip_thought_label(text: str) -> str:
"""Remove the ``thought\\n`` role label from the beginning of text.
Mirrors ``vllm.reasoning.gemma4_utils._strip_thought_label`` from the
offline parser.
"""
if text.startswith(_THOUGHT_PREFIX):
return text[len(_THOUGHT_PREFIX) :]
return text

View File

@@ -0,0 +1,130 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
"""Gemma4 thinking/reasoning output parsing utilities for offline inference.
Standalone functions that parse decoded model text to extract structured
thinking content from Gemma4 models. These are pure-Python utilities with
zero heavy dependencies — they work on raw decoded strings from any
inference backend (vLLM, HuggingFace, TGI, etc.).
For the OpenAI-compatible API reasoning parser (streaming +
non-streaming), see ``vllm.reasoning.gemma4_reasoning_parser``.
For tool call parsing, see ``vllm.tool_parsers.gemma4_utils``.
Usage with vLLM offline inference::
from vllm import LLM, SamplingParams
from vllm.reasoning.gemma4_utils import parse_thinking_output
llm = LLM(model="google/gemma-4-it")
outputs = llm.generate(prompt, SamplingParams(...))
text = tokenizer.decode(outputs[0].outputs[0].token_ids, skip_special_tokens=False)
# Extract thinking / answer (works with or without enable_thinking)
result = parse_thinking_output(text)
print(result["thinking"]) # chain-of-thought or None
print(result["answer"]) # final answer
Ported from ``transformers.models.gemma4.utils_gemma4`` so that vLLM users
do not need a transformers dependency for output parsing.
"""
# ---- Thinking Mode Utility ----
# Thinking delimiter tokens as they appear in decoded text.
# Gemma4 uses <|channel> (start) and <channel|> (end) as thinking delimiters.
_THINKING_START_TAG = "<|channel>"
_THINKING_END_TAG = "<channel|>"
# Sentinel tokens that may appear in decoded output.
_TURN_END_TAG = "<turn|>"
def parse_thinking_output(text: str) -> dict[str, str | None]:
    """Parse decoded Gemma4 model output into thinking and answer parts.

    Safe to call on **all** Gemma4 output, whether or not thinking mode
    was enabled. Three cases are handled:

    1. **Thinking enabled, tags present** — splits on ``<|channel>``/
       ``<channel|>`` to separate chain-of-thought from the answer and
       strips the ``thought\\n`` role label.
    2. **Thinking disabled, spurious label** — strips the bare
       ``thought\\n`` prefix that some Gemma4 models emit even
       without thinking mode.
    3. **Clean output** — returns the text unchanged.

    The answer text is always cleaned of trailing sentinel tokens
    (``<turn|>``, ``<eos>``, etc.).

    Args:
        text: Decoded model output text (from ``tokenizer.decode(...)``).

    Returns:
        A dict with keys:

        - ``"thinking"``: The chain-of-thought text, or ``None`` if no
          thinking delimiters were found.
        - ``"answer"``: The final answer text.

    Example::

        >>> from vllm.reasoning.gemma4_utils import parse_thinking_output
        >>> output_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
        >>> result = parse_thinking_output(output_text)
        >>> print(result["thinking"])  # chain-of-thought reasoning or None
        >>> print(result["answer"])  # final answer
    """
    if _THINKING_END_TAG not in text:
        # No thinking delimiters: drop any spurious "thought\n" role label
        # some Gemma4 models emit, then clean trailing sentinel tokens.
        return {
            "thinking": None,
            "answer": _clean_answer(_strip_thought_label(text)),
        }
    before, after = text.split(_THINKING_END_TAG, 1)
    # Everything after the (optional) start tag is the chain-of-thought.
    _head, sep, tail = before.partition(_THINKING_START_TAG)
    raw_thinking = tail if sep else before
    # Strip the "thought\n" channel role label the model emits inside
    # <|channel>thought\n...<channel|> (analogous to "user\n" in
    # <|turn>user\n...<turn|>).
    thinking = _strip_thought_label(raw_thinking.strip()).strip()
    return {"thinking": thinking, "answer": _clean_answer(after)}
def _strip_thought_label(text: str) -> str:
"""Strip the spurious ``thought\\n`` label from the start of text.
Only strips when ``thought`` appears as the very first word followed by
a newline — preserving the word ``thought`` in any other context.
"""
if text.startswith("thought\n"):
return text[len("thought\n") :]
return text
def _clean_answer(text: str) -> str:
"""Clean trailing sentinel tokens from the answer text.
Strips ``<turn|>``, ``<eos>``, and surrounding whitespace that the
model appends at the end of its response.
"""
text = text.strip()
# Strip trailing <turn|> (Gemma4 turn-end marker)
if text.endswith(_TURN_END_TAG):
text = text[: -len(_TURN_END_TAG)].rstrip()
# Strip trailing <eos> if present
if text.endswith("<eos>"):
text = text[:-5].rstrip()
return text

View File

@@ -154,6 +154,10 @@ _TOOL_PARSERS_TO_REGISTER = {
"functiongemma_tool_parser",
"FunctionGemmaToolParser",
),
"gemma4": (
"gemma4_tool_parser",
"Gemma4ToolParser",
),
}

View File

@@ -0,0 +1,724 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tool call parser for Google Gemma4 models.
Gemma4 uses a custom serialization format (not JSON) for tool calls::
<|tool_call>call:func_name{key:<|"|>value<|"|>,num:42}<tool_call|>
Strings are delimited by ``<|"|>`` (token 52), keys are unquoted, and
multiple tool calls are concatenated without separators.
Used when ``--enable-auto-tool-choice --tool-call-parser gemma4`` are set.
For offline inference tool call parsing (direct ``tokenizer.decode()`` output),
see ``vllm.tool_parsers.gemma4_utils.parse_tool_calls``.
"""
import json
from collections.abc import Sequence
import regex as re
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import (
ResponsesRequest,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser
from vllm.tool_parsers.utils import find_common_prefix
logger = init_logger(__name__)
# Gemma4 special tokens for tool calls
TOOL_CALL_START = "<|tool_call>"  # opens a tool call block
TOOL_CALL_END = "<tool_call|>"  # closes a tool call block
STRING_DELIM = '<|"|>'  # delimits string values (token 52), in place of quotes
# ---------------------------------------------------------------------------
# Gemma4 argument parser (used by both streaming and non-streaming paths)
# ---------------------------------------------------------------------------
def _parse_gemma4_value(value_str: str) -> object:
"""Parse a single Gemma4 value (after key:) into a Python object."""
value_str = value_str.strip()
if not value_str:
return value_str
# Boolean
if value_str == "true":
return True
if value_str == "false":
return False
# Number (int or float)
try:
if "." in value_str:
return float(value_str)
return int(value_str)
except ValueError:
pass
# Bare string (no <|"|> delimiters — shouldn't happen but be safe)
return value_str
def _parse_gemma4_args(args_str: str) -> dict:
    """Parse Gemma4's custom key:value format into a Python dict.

    Format examples::

        location:<|"|>Tokyo<|"|>
        location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>
        count:42,flag:true
        nested:{inner_key:<|"|>val<|"|>}
        items:[<|"|>a<|"|>,<|"|>b<|"|>]

    Returns a dict ready for ``json.dumps()``.
    """
    if not args_str or not args_str.strip():
        return {}
    result: dict = {}
    # Single left-to-right scan; `i` is the cursor into args_str.
    i = 0
    n = len(args_str)
    while i < n:
        # Skip whitespace and commas
        while i < n and args_str[i] in (" ", ",", "\n", "\t"):
            i += 1
        if i >= n:
            break
        # Parse key (unquoted, ends at ':')
        key_start = i
        while i < n and args_str[i] != ":":
            i += 1
        if i >= n:
            # Trailing text with no ':' is not a key/value pair — drop it.
            break
        key = args_str[key_start:i].strip()
        i += 1  # skip ':'
        # Parse value
        if i >= n:
            # "key:" at end of input — record an empty value.
            result[key] = ""
            break
        # Skip whitespace after ':'
        while i < n and args_str[i] in (" ", "\n", "\t"):
            i += 1
        if i >= n:
            result[key] = ""
            break
        # String value: <|"|>...<|"|>
        if args_str[i:].startswith(STRING_DELIM):
            i += len(STRING_DELIM)
            val_start = i
            end_pos = args_str.find(STRING_DELIM, i)
            if end_pos == -1:
                # Unterminated string — take rest
                result[key] = args_str[val_start:]
                break
            result[key] = args_str[val_start:end_pos]
            i = end_pos + len(STRING_DELIM)
        # Nested object: {...}
        elif args_str[i] == "{":
            depth = 1
            obj_start = i + 1
            i += 1
            # Scan to the matching '}' by brace depth.
            while i < n and depth > 0:
                if args_str[i:].startswith(STRING_DELIM):
                    # Skip over string contents to avoid counting { inside strings
                    i += len(STRING_DELIM)
                    next_delim = args_str.find(STRING_DELIM, i)
                    i = n if next_delim == -1 else next_delim + len(STRING_DELIM)
                    continue
                if args_str[i] == "{":
                    depth += 1
                elif args_str[i] == "}":
                    depth -= 1
                i += 1
            # Recurse on the body between the braces.
            result[key] = _parse_gemma4_args(args_str[obj_start : i - 1])
        # Array: [...]
        elif args_str[i] == "[":
            depth = 1
            arr_start = i + 1
            i += 1
            # Scan to the matching ']' by bracket depth, skipping strings.
            while i < n and depth > 0:
                if args_str[i:].startswith(STRING_DELIM):
                    i += len(STRING_DELIM)
                    next_delim = args_str.find(STRING_DELIM, i)
                    i = n if next_delim == -1 else next_delim + len(STRING_DELIM)
                    continue
                if args_str[i] == "[":
                    depth += 1
                elif args_str[i] == "]":
                    depth -= 1
                i += 1
            arr_content = args_str[arr_start : i - 1]
            result[key] = _parse_gemma4_array(arr_content)
        # Bare value (number, boolean, etc.)
        else:
            val_start = i
            while i < n and args_str[i] not in (",", "}", "]"):
                i += 1
            result[key] = _parse_gemma4_value(args_str[val_start:i])
    return result
def _parse_gemma4_array(arr_str: str) -> list:
    """Parse a Gemma4 array content string into a Python list.

    Elements may be ``<|"|>``-delimited strings, nested objects ``{...}``,
    nested arrays ``[...]``, or bare scalars, separated by commas.

    Args:
        arr_str: The content between the array's ``[`` and ``]``.

    Returns:
        A list of parsed elements.
    """
    items: list = []
    i = 0
    n = len(arr_str)
    while i < n:
        # Skip whitespace and comma separators between elements.
        while i < n and arr_str[i] in (" ", ",", "\n", "\t"):
            i += 1
        if i >= n:
            break
        # String element
        if arr_str[i:].startswith(STRING_DELIM):
            i += len(STRING_DELIM)
            end_pos = arr_str.find(STRING_DELIM, i)
            if end_pos == -1:
                # Unterminated string — take the rest.
                items.append(arr_str[i:])
                break
            items.append(arr_str[i:end_pos])
            i = end_pos + len(STRING_DELIM)
        # Nested object
        elif arr_str[i] == "{":
            depth = 1
            obj_start = i + 1
            i += 1
            while i < n and depth > 0:
                if arr_str[i:].startswith(STRING_DELIM):
                    # Skip string contents so braces inside strings do not
                    # affect the nesting depth.
                    i += len(STRING_DELIM)
                    nd = arr_str.find(STRING_DELIM, i)
                    i = nd + len(STRING_DELIM) if nd != -1 else n
                    continue
                if arr_str[i] == "{":
                    depth += 1
                elif arr_str[i] == "}":
                    depth -= 1
                i += 1
            items.append(_parse_gemma4_args(arr_str[obj_start : i - 1]))
        # Nested array
        elif arr_str[i] == "[":
            depth = 1
            sub_start = i + 1
            i += 1
            while i < n and depth > 0:
                # Bugfix: skip string contents here too, mirroring the
                # object branch above — previously '[' or ']' inside a
                # <|"|>-delimited string element corrupted the depth count.
                if arr_str[i:].startswith(STRING_DELIM):
                    i += len(STRING_DELIM)
                    nd = arr_str.find(STRING_DELIM, i)
                    i = nd + len(STRING_DELIM) if nd != -1 else n
                    continue
                if arr_str[i] == "[":
                    depth += 1
                elif arr_str[i] == "]":
                    depth -= 1
                i += 1
            items.append(_parse_gemma4_array(arr_str[sub_start : i - 1]))
        # Bare value
        else:
            val_start = i
            while i < n and arr_str[i] not in (",", "]"):
                i += 1
            items.append(_parse_gemma4_value(arr_str[val_start:i]))
    return items
# ---------------------------------------------------------------------------
# Parser
# ---------------------------------------------------------------------------
class Gemma4ToolParser(ToolParser):
    """
    Tool call parser for Google Gemma4 models.

    Handles the Gemma4 function call format::

        <|tool_call>call:func_name{key:<|"|>value<|"|>}<tool_call|>

    Used when ``--enable-auto-tool-choice --tool-call-parser gemma4``
    are set.

    Streaming strategy: **accumulate-then-parse-then-diff**

    Instead of trying to convert Gemma4's custom format to JSON
    token-by-token (which fails because Gemma4 uses bare keys, custom
    delimiters, and structural braces that differ from JSON), this parser:

    1. Accumulates the raw Gemma4 argument string during streaming
    2. Parses it with ``_parse_gemma4_args()`` into a Python dict
    3. Converts to JSON with ``json.dumps()``
    4. Diffs against the previously-streamed JSON string
    5. Emits only the new JSON fragment as the delta

    This follows the same pattern used by FunctionGemma, Hermes, and Llama
    tool parsers.
    """

    def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None) -> None:
        super().__init__(tokenizer, tools)
        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )
        # Token strings
        self.tool_call_start_token = TOOL_CALL_START
        self.tool_call_end_token = TOOL_CALL_END
        # Token IDs (None when the token is absent from the vocab)
        self.tool_call_start_token_id = self.vocab.get(TOOL_CALL_START)
        self.tool_call_end_token_id = self.vocab.get(TOOL_CALL_END)
        if self.tool_call_start_token_id is None:
            raise RuntimeError(
                "Gemma4 ToolParser could not locate the tool call start "
                f"token '{TOOL_CALL_START}' in the tokenizer!"
            )
        # Regex for non-streaming: extract complete tool calls.
        # Supports function names with letters, digits, underscores,
        # hyphens, and dots (e.g. "get-weather", "module.func").
        self.tool_call_regex = re.compile(
            r"<\|tool_call>call:([\w\-\.]+)\{(.*?)\}<tool_call\|>",
            re.DOTALL,
        )
        # Streaming state — reset per-request via _reset_streaming_state()
        self._reset_streaming_state()
        # Delta buffer for handling multi-token special sequences
        self.buffered_delta_text: str = ""

    def _reset_streaming_state(self) -> None:
        """Reset all streaming state for a new request."""
        # NOTE(review): nothing in this class calls this per-request after
        # __init__; presumably the framework creates one parser instance
        # per request — confirm against the ToolParser lifecycle.
        self.current_tool_id = -1
        self.current_tool_name_sent = False
        self.prev_tool_call_arr: list[dict] = []
        self.streamed_args_for_tool: list[str] = []

    def adjust_request(
        self, request: ChatCompletionRequest | ResponsesRequest
    ) -> ChatCompletionRequest | ResponsesRequest:
        """Tweak the incoming request so tool-call tokens survive decoding."""
        request = super().adjust_request(request)
        if (
            isinstance(request, ChatCompletionRequest)
            and request.tools
            and request.tool_choice != "none"
        ):
            # Don't skip special tokens — <|tool_call> etc. are needed
            request.skip_special_tokens = False
        return request

    # ------------------------------------------------------------------
    # Delta buffering for multi-token special sequences
    # ------------------------------------------------------------------
    def _buffer_delta_text(self, delta_text: str) -> str:
        """Buffer incoming delta text to handle multi-token special sequences.

        Accumulates partial tokens that could be the start of
        ``<|tool_call>`` or ``<tool_call|>`` and only flushes them
        when the complete sequence is recognized or the sequence breaks.
        This prevents partial special tokens (e.g., ``<|tool``) from being
        emitted prematurely as content text.
        """
        combined = self.buffered_delta_text + delta_text
        # Check if combined ends with a complete special token
        if combined.endswith(TOOL_CALL_START) or combined.endswith(TOOL_CALL_END):
            self.buffered_delta_text = ""
            return combined
        # Check if combined ends with a partial prefix of a special token
        for tag in [TOOL_CALL_START, TOOL_CALL_END]:
            for i in range(1, len(tag)):
                if combined.endswith(tag[:i]):
                    # Hold back the partial tag; flush everything before it.
                    self.buffered_delta_text = combined[-i:]
                    return combined[:-i]
        # No partial match — flush everything
        self.buffered_delta_text = ""
        return combined

    # ------------------------------------------------------------------
    # Non-streaming extraction
    # ------------------------------------------------------------------
    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract all complete tool calls from a finished model response.

        Returns the parsed tool calls plus any content that preceded the
        first tool call; on any parse error, falls back to returning the
        whole output as plain content.
        """
        if self.tool_call_start_token not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )
        try:
            matches = self.tool_call_regex.findall(model_output)
            if not matches:
                return ExtractedToolCallInformation(
                    tools_called=False, tool_calls=[], content=model_output
                )
            tool_calls: list[ToolCall] = []
            for func_name, args_str in matches:
                arguments = _parse_gemma4_args(args_str)
                tool_calls.append(
                    ToolCall(
                        type="function",
                        function=FunctionCall(
                            name=func_name,
                            arguments=json.dumps(arguments, ensure_ascii=False),
                        ),
                    )
                )
            # Content = text before first tool call (if any)
            content_end = model_output.find(self.tool_call_start_token)
            content = model_output[:content_end].strip() if content_end > 0 else None
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if content else None,
            )
        except Exception:
            logger.exception("Error extracting tool calls from Gemma4 response")
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

    # ------------------------------------------------------------------
    # Streaming extraction — accumulate-then-parse-then-diff
    # ------------------------------------------------------------------
    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Produce the next streaming delta (content or tool-call fragment).

        Returns ``None`` when nothing should be emitted for this delta.
        """
        # Buffer delta text to handle multi-token special sequences
        delta_text = self._buffer_delta_text(delta_text)
        # Reconstruct current_text after buffering to stay in sync
        current_text = previous_text + delta_text
        # If no tool call token seen yet, emit as content
        if self.tool_call_start_token not in current_text:
            if delta_text:
                return DeltaMessage(content=delta_text)
            return None
        try:
            return self._extract_streaming(
                previous_text=previous_text,
                current_text=current_text,
                delta_text=delta_text,
            )
        except Exception:
            logger.exception("Error in Gemma4 streaming tool call extraction")
            return None

    def _extract_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
    ) -> DeltaMessage | None:
        """Tag-counting streaming parser.

        Uses the proven approach from FunctionGemma/Hermes: count start/end
        tags in previous vs current text to determine phase, then
        accumulate-parse-diff for arguments.

        Format: ``<|tool_call>call:name{args}<tool_call|>``
        """
        start_count = current_text.count(self.tool_call_start_token)
        end_count = current_text.count(self.tool_call_end_token)
        prev_start_count = previous_text.count(self.tool_call_start_token)
        prev_end_count = previous_text.count(self.tool_call_end_token)
        # Case 1: Not inside any tool call — emit as content
        if (
            start_count == end_count
            and prev_end_count == end_count
            and self.tool_call_end_token not in delta_text
        ):
            if delta_text:
                return DeltaMessage(content=delta_text)
            return None
        # Case 2: Starting a new tool call
        if start_count > prev_start_count and start_count > end_count:
            self.current_tool_id += 1
            self.current_tool_name_sent = False
            self.streamed_args_for_tool.append("")
            self.prev_tool_call_arr.append({})
            logger.debug("Starting new tool call %d", self.current_tool_id)
            # Don't return yet — fall through to try parsing if there's
            # content after <|tool_call> in this same delta
            # (but usually it's just the token itself, so return None)
            if len(delta_text) <= len(self.tool_call_start_token):
                return None
        # Case 3: Tool call just ended
        if end_count > prev_end_count:
            return self._handle_tool_call_end(current_text)
        # Case 4: In the middle of a tool call — parse partial content
        if start_count > end_count:
            return self._handle_tool_call_middle(current_text)
        # Default: generate text outside tool calls
        if delta_text:
            # Remove any stray tool-call markers before emitting as content.
            text = delta_text.replace(self.tool_call_start_token, "")
            text = text.replace(self.tool_call_end_token, "")
            if text:
                return DeltaMessage(content=text)
        return None

    def _extract_partial_call(self, current_text: str) -> tuple[str | None, str]:
        """Extract function name and raw argument string from partial text.

        Returns (func_name, raw_args_str) or (None, "") if not parseable yet.
        """
        # Get the text after the last <|tool_call> token
        last_start = current_text.rfind(self.tool_call_start_token)
        if last_start == -1:
            return None, ""
        partial_call = current_text[last_start + len(self.tool_call_start_token) :]
        # Strip end token if present
        if self.tool_call_end_token in partial_call:
            partial_call = partial_call.split(self.tool_call_end_token)[0]
        # Expect "call:name{args...}" or "call:name{args...}"
        if not partial_call.startswith("call:"):
            return None, ""
        func_part = partial_call[5:]  # skip "call:"
        if "{" not in func_part:
            # Still accumulating function name, not ready yet
            return None, ""
        func_name, _, args_part = func_part.partition("{")
        func_name = func_name.strip()
        # Strip trailing '}' if present (Gemma4 structural brace)
        if args_part.endswith("}"):
            args_part = args_part[:-1]
        return func_name, args_part

    def _handle_tool_call_middle(self, current_text: str) -> DeltaMessage | None:
        """Handle streaming when we're inside an active tool call.

        Accumulates the raw Gemma4 arguments, parses them into JSON, and
        diffs against the previously-streamed JSON to emit only the new
        fragment.
        """
        func_name, args_part = self._extract_partial_call(current_text)
        if func_name is None:
            return None
        # Step 1: Send function name (once)
        if not self.current_tool_name_sent and func_name:
            self.current_tool_name_sent = True
            self.prev_tool_call_arr[self.current_tool_id] = {
                "name": func_name,
                "arguments": {},
            }
            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_id,
                        type="function",
                        id=make_tool_call_id(),
                        function=DeltaFunctionCall(
                            name=func_name,
                            arguments="",
                        ).model_dump(exclude_none=True),
                    )
                ]
            )
        # Step 2: Parse and diff arguments
        if self.current_tool_name_sent and args_part:
            return self._emit_argument_diff(args_part)
        return None

    def _handle_tool_call_end(self, current_text: str) -> DeltaMessage | None:
        """Handle streaming when a tool call has just completed.

        Performs a final parse of the complete tool call and flushes
        any remaining un-streamed argument fragments.
        """
        if self.current_tool_id < 0 or self.current_tool_id >= len(
            self.prev_tool_call_arr
        ):
            logger.debug(
                "Tool call end detected but no active tool call (current_tool_id=%d)",
                self.current_tool_id,
            )
            return None
        # Parse the complete tool call using regex for accuracy
        all_matches = self.tool_call_regex.findall(current_text)
        if self.current_tool_id < len(all_matches):
            _, args_str = all_matches[self.current_tool_id]
            final_args = _parse_gemma4_args(args_str)
            final_args_json = json.dumps(final_args, ensure_ascii=False)
            prev_streamed = self.streamed_args_for_tool[self.current_tool_id]
            if len(final_args_json) > len(prev_streamed):
                # Flush the tail that streaming withheld (closing chars etc.).
                diff = final_args_json[len(prev_streamed) :]
                self.streamed_args_for_tool[self.current_tool_id] = final_args_json
                self.prev_tool_call_arr[self.current_tool_id]["arguments"] = final_args
                return DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_id,
                            function=DeltaFunctionCall(arguments=diff).model_dump(
                                exclude_none=True
                            ),
                        )
                    ]
                )
        return None

    def _emit_argument_diff(self, raw_args_str: str) -> DeltaMessage | None:
        """Parse raw Gemma4 arguments, convert to JSON, diff, and emit.

        This is the core of the accumulate-then-parse-then-diff strategy:

        1. Parse ``raw_args_str`` with ``_parse_gemma4_args()``
        2. Convert to JSON string with ``json.dumps()``
        3. Withhold trailing closing characters (``"}``) that may move
           as more tokens arrive
        4. Diff against previously streamed JSON and emit only new chars

        **Why withholding is necessary:**

        Gemma4's custom format produces *structurally incomplete* JSON
        during streaming. For example, when ``<|"|>Paris`` arrives
        without a closing delimiter, ``_parse_gemma4_args`` treats it
        as a complete value and produces ``{"location": "Paris"}``. But
        when ``, France<|"|>`` arrives next, the JSON becomes
        ``{"location": "Paris, France"}``. If we had sent the closing
        ``"}`` from the first parse, the concatenated client output
        would be ``{"location": "Paris"}France"}``, which is garbage.

        The solution: **never send trailing closing chars during
        streaming**. They get flushed by ``_handle_tool_call_end()``
        when the ``<tool_call|>`` end marker arrives.

        Args:
            raw_args_str: The raw Gemma4 argument text accumulated so far
                (without the surrounding ``{`` ``}``).

        Returns:
            DeltaMessage with the argument diff, or None if no new content.
        """
        try:
            current_args = _parse_gemma4_args(raw_args_str)
        except Exception:
            logger.debug(
                "Could not parse partial Gemma4 args yet: %s",
                raw_args_str[:100],
            )
            return None
        if not current_args:
            return None
        current_args_json = json.dumps(current_args, ensure_ascii=False)
        # Withhold trailing closing characters that may shift as more
        # tokens arrive. Strip trailing '}', '"', and ']' sequences
        # to get the "safe prefix".
        safe_json = current_args_json
        while safe_json and safe_json[-1] in ("}", '"', "]"):
            safe_json = safe_json[:-1]
        prev_streamed = self.streamed_args_for_tool[self.current_tool_id]
        if not safe_json or safe_json == prev_streamed:
            return None
        # Use find_common_prefix to handle cases where the value changed
        # structurally (e.g., a string grew).
        if prev_streamed:
            prefix = find_common_prefix(prev_streamed, safe_json)
            sent_len = len(prev_streamed)
            prefix_len = len(prefix)
            if prefix_len < sent_len:
                # Structure changed — we sent too much. Truncate our
                # tracking to the common prefix and wait for the final
                # flush in _handle_tool_call_end.
                self.streamed_args_for_tool[self.current_tool_id] = prefix
                return None
            # Stream the new stable portion
            diff = safe_json[sent_len:]
        else:
            # First emission
            diff = safe_json
        if diff:
            self.streamed_args_for_tool[self.current_tool_id] = safe_json
            self.prev_tool_call_arr[self.current_tool_id]["arguments"] = current_args
            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_id,
                        function=DeltaFunctionCall(arguments=diff).model_dump(
                            exclude_none=True
                        ),
                    )
                ]
            )
        return None

View File

@@ -0,0 +1,183 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
"""Gemma4 tool call parsing utilities for offline inference.
Standalone functions that parse decoded model text to extract tool calls
from Gemma4 models. These are pure-Python utilities with zero heavy
dependencies — they work on raw decoded strings from any inference
backend (vLLM, HuggingFace, TGI, etc.).
For the OpenAI-compatible API server tool parser (streaming +
non-streaming), see ``vllm.tool_parsers.gemma4_tool_parser``.
For thinking/reasoning output parsing, see
``vllm.reasoning.gemma4_utils``.
Usage with vLLM offline inference::
from vllm import LLM, SamplingParams
from vllm.tool_parsers.gemma4_utils import (
parse_tool_calls,
has_tool_response_tag,
)
llm = LLM(model="google/gemma-4-it")
outputs = llm.generate(prompt, SamplingParams(...))
text = tokenizer.decode(outputs[0].outputs[0].token_ids, skip_special_tokens=False)
# Extract tool calls
tool_calls = parse_tool_calls(text)
for tc in tool_calls:
print(f"{tc['name']}({tc['arguments']})")
Ported from ``transformers.models.gemma4.utils_gemma4`` so that vLLM users
do not need a transformers dependency for output parsing.
"""
import json
import regex as re
# Tool call delimiter tokens as they appear in decoded text.
# Standard format: <|tool_call>call:name{args}<tool_call|>
_TOOL_CALL_START_TAG = "<|tool_call>"
_TOOL_CALL_END_TAG = "<tool_call|>"
_TOOL_RESPONSE_START_TAG = "<|tool_response>"
# Gemma4 escape token as it appears in decoded text.
# Delimits string values in place of quotes (special token 52).
_ESCAPE_TOKEN = '<|"|>'
def _parse_tool_arguments(args_str: str) -> dict[str, str]:
"""Parse tool call arguments from the Gemma4 compact format.
Handles the ``key:<|"|>value<|"|>`` format used by Gemma4, with fallback
to heuristic key-value extraction. Also tolerates the slightly different
``key: "value"`` format (space + plain quotes) that some chat templates
produce.
Args:
args_str: Raw argument string from inside ``call:name{...}``.
Returns:
Dictionary of argument name → value.
"""
if not args_str or not args_str.strip():
return {}
# Replace Gemma4 escape tokens with standard quotes.
cleaned = args_str.replace(_ESCAPE_TOKEN, '"')
# Try JSON parsing first (handles nested values, arrays, etc.).
try:
parsed = json.loads("{" + cleaned + "}")
# Ensure all values are strings for consistency.
return {k: str(v) if not isinstance(v, str) else v for k, v in parsed.items()}
except (json.JSONDecodeError, ValueError):
pass
# Fallback: extract key:"value" pairs (allow optional space after colon).
arguments = {}
for key, value in re.findall(r'(\w+):\s*"([^"]*)"', cleaned):
arguments[key] = value
if not arguments:
# Last resort: extract key:value pairs (unquoted).
for key, value in re.findall(r"(\w+):\s*([^,}]+)", args_str):
arguments[key] = value.strip().strip('"').replace(_ESCAPE_TOKEN, "")
return arguments
def parse_tool_calls(text: str, *, strict: bool = False) -> list[dict]:
    """Parse tool calls from decoded Gemma4 model output.

    Uses a tiered parsing strategy to handle known output variations in
    Gemma4 models, which may emit non-standard tool call formats.

    Parsing tiers:

    1. **Standard**: ``<|tool_call>call:name{args}<tool_call|>``
       (special token IDs 48/49 in decoded text; some models terminate
       with ``<turn|>`` instead of ``<tool_call|>``)
    2. **Fallback** (when ``strict=False``): bare ``call:name{args}``
       patterns, including ``<call>name{args}`` (fragmented tokens from
       multimodal inputs)

    Args:
        text: Decoded model output text (from ``tokenizer.decode(...,
            skip_special_tokens=False)``).
        strict: If ``True``, only match the standard ``<|tool_call>``
            format. If ``False`` (default), also try fallback patterns
            for known Gemma4 output variations.

    Returns:
        A list of dicts, each with keys:

        - ``"name"``: The tool function name (e.g. ``"get_weather"``).
        - ``"arguments"``: A dict of argument name → value.

    Example::

        >>> from vllm.tool_parsers.gemma4_utils import parse_tool_calls
        >>> output = tokenizer.decode(outputs[0], skip_special_tokens=False)
        >>> tool_calls = parse_tool_calls(output)
        >>> for tc in tool_calls:
        ...     print(f"Call: {tc['name']}({tc['arguments']})")
    """

    def _collect(pattern: str) -> list[dict]:
        # Build one result dict per regex match of `pattern`.
        return [
            {
                "name": m.group(1),
                "arguments": _parse_tool_arguments(m.group(2)),
            }
            for m in re.finditer(pattern, text, re.DOTALL)
        ]

    # Tier 1: Standard format with special tokens.
    results = _collect(
        r"<\|tool_call\>call:(\w+)\{(.*?)\}(?:<tool_call\|>|<turn\|>)"
    )
    if results or strict:
        return results
    # Tier 2: Fallback for known Gemma4 output variations.
    return _collect(r"(?:<call>|(?:^|\s)call:)(\w+)\{(.*?)\}")
def has_tool_response_tag(text: str) -> bool:
    """Check if model output properly ends with a tool response tag.

    Some Gemma4 models sometimes emit ``<eos>`` instead of
    ``<|tool_response>`` after a tool call. This helper detects whether
    the model used the proper termination, so callers can decide whether
    to inject ``<|tool_response>`` into the next prompt.

    Args:
        text: Decoded model output text.

    Returns:
        ``True`` if the output ends with ``<|tool_response>``
        (proper behavior), ``False`` otherwise.
    """
    # Trailing whitespace after the tag is tolerated.
    return text.rstrip().endswith("<|tool_response>")

View File

@@ -448,6 +448,16 @@ class LongCatFlashMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
return getattr(self.hf_text_config, "num_nextn_predict_layers", 1)
class Gemma4ModelArchConfigConvertor(ModelArchConfigConvertorBase):
    def get_head_size(self) -> int:
        """Return the largest of Gemma4's two head dimensions.

        Gemma4 uses dual head dims: ``head_dim`` (sliding attention) and
        ``global_head_dim`` (full attention). Attention backends must
        allocate buffers large enough for both, so report the larger of
        the two; when neither is set (falsy), defer to the base class.
        """
        sliding_dim = getattr(self.hf_text_config, "head_dim", 0)
        full_dim = getattr(self.hf_text_config, "global_head_dim", 0)
        largest = max(sliding_dim, full_dim)
        return largest if largest else super().get_head_size()
# hf_config.model_type -> convertor class
MODEL_ARCH_CONFIG_CONVERTORS = {
"cohere_asr": CohereAsrModelArchConfigConvertor,
@@ -471,4 +481,6 @@ MODEL_ARCH_CONFIG_CONVERTORS = {
"ernie_mtp": ErnieMTPModelArchConfigConvertor,
"pangu_ultra_moe_mtp": PanguUltraMoeMTPModelArchConfigConvertor,
"longcat_flash_mtp": LongCatFlashMTPModelArchConfigConvertor,
"gemma4": Gemma4ModelArchConfigConvertor,
"gemma4_text": Gemma4ModelArchConfigConvertor,
}

View File

@@ -4,6 +4,7 @@ from dataclasses import dataclass
import torch
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
@@ -22,7 +23,6 @@ from vllm.v1.attention.backend import (
)
from vllm.v1.attention.backends.utils import (
split_decodes_and_prefills,
split_prefill_chunks,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.cp_utils import get_total_cp_world_size
@@ -30,6 +30,55 @@ from vllm.v1.worker.cp_utils import get_total_cp_world_size
logger = init_logger(__name__)
def split_indexer_prefill_chunks(
    seq_lens_cpu: torch.Tensor,
    query_lens_cpu: torch.Tensor,
    workspace_size: int,
    max_logits_bytes: int,
    request_offset: int = 0,
) -> list[tuple[slice, slice]]:
    """
    Split prefill requests into chunks for the sparse indexer, respecting:
    - N constraint: total_seq_lens <= workspace_size (existing O(N) workspace)
    - Logits constraint: M * N * 4 <= max_logits_bytes

    When a single request-level chunk still exceeds the logits budget,
    sub-chunks on the query dimension (M) to bound peak memory.

    Returns list of (req_slice, query_slice) tuples.
    """
    chunks: list[tuple[slice, slice]] = []
    num_reqs = len(seq_lens_cpu)
    # fp32 logits => 4 bytes per element.
    logits_budget = max_logits_bytes // 4
    cursor = 0
    while cursor < num_reqs:
        first = cursor
        chunk_q = 0
        chunk_s = 0
        # Greedily absorb requests while both budgets hold.
        while cursor < num_reqs:
            q = query_lens_cpu[cursor].item()
            s = seq_lens_cpu[cursor].item()
            over_workspace = chunk_s + s > workspace_size
            over_logits = (chunk_q + q) * (chunk_s + s) > logits_budget
            if over_workspace or over_logits:
                break
            chunk_q += q
            chunk_s += s
            cursor += 1
        if cursor == first:
            # A single request can exceed the budget on its own; admit it
            # alone and rely on query-dimension sub-chunking below.
            chunk_q = query_lens_cpu[cursor].item()
            chunk_s = seq_lens_cpu[cursor].item()
            cursor += 1
        req_slice = slice(first + request_offset, cursor + request_offset)
        q_step = max(1, logits_budget // chunk_s) if chunk_s > 0 else chunk_q
        for q_start in range(0, chunk_q, q_step):
            sub_q = min(q_step, chunk_q - q_start)
            chunks.append((req_slice, slice(q_start, q_start + sub_q)))
    return chunks
class DeepseekV32IndexerBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
@@ -81,6 +130,7 @@ class DeepseekV32IndexerPrefillChunkMetadata:
token_start: int
token_end: int
num_reqs: int
skip_kv_gather: bool = False
@dataclass
@@ -271,43 +321,51 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
)
def build_one_prefill_chunk(
    self,
    req_slice: slice,
    query_slice: slice,
    query_start_loc_cpu,
    seq_lens_cpu,
    block_table,
    skip_kv_gather: bool = False,
) -> DeepseekV32IndexerPrefillChunkMetadata:
    """Build metadata for one prefill chunk of the sparse indexer.

    Args:
        req_slice: Range of request indices covered by this chunk.
        query_slice: Sub-range of this chunk's query tokens, used when the
            chunk is sub-chunked on the query dimension to bound the
            logits memory; ``slice(0, M)`` when the whole chunk fits.
        query_start_loc_cpu: CPU tensor of cumulative query lengths.
        seq_lens_cpu: CPU tensor of per-request sequence lengths.
        block_table: Per-request block table tensor.
        skip_kv_gather: True for query sub-chunks after the first, whose
            KV was already gathered by the first sub-chunk
            (callers pass ``query_slice.start > 0``).
    """
    # Query start offsets local to this chunk (first request rebased to 0).
    prefill_query_start_loc = (
        query_start_loc_cpu[req_slice.start : req_slice.stop + 1]
        - query_start_loc_cpu[req_slice.start]
    )
    cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
        prefill_query_start_loc, seq_lens_cpu[req_slice], self.device
    )
    token_start = query_start_loc_cpu[req_slice.start].item()
    total_seq_lens = seq_lens_cpu[req_slice].sum()
    num_reqs = req_slice.stop - req_slice.start
    # Map each gathered KV token back to its request index within the chunk.
    seq_idx = torch.arange(0, num_reqs, dtype=torch.int32)
    token_to_seq = torch.repeat_interleave(seq_idx, seq_lens_cpu[req_slice]).to(
        self.device
    )
    # The chunking policy must have kept us within the O(N) workspace.
    assert total_seq_lens <= self.max_prefill_buffer_size
    cu_seq_lens = (
        torch.cat(
            [
                torch.zeros(1, dtype=torch.int32),
                seq_lens_cpu[req_slice].cumsum(dim=0),
            ]
        )
        .to(torch.int32)
        .to(self.device)
    )
    return DeepseekV32IndexerPrefillChunkMetadata(
        # Per-query-token KV spans, restricted to this query sub-chunk.
        cu_seqlen_ks=cu_seqlen_ks[query_slice],
        cu_seqlen_ke=cu_seqlen_ke[query_slice],
        cu_seq_lens=cu_seq_lens,
        token_to_seq=token_to_seq,
        total_seq_lens=total_seq_lens,
        block_table=block_table[req_slice],
        # Global token range covered by this (sub-)chunk.
        token_start=token_start + query_slice.start,
        token_end=token_start + query_slice.stop,
        num_reqs=num_reqs,
        skip_kv_gather=skip_kv_gather,
    )
def build(
@@ -333,20 +391,27 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
prefill_metadata = None
if num_prefills > 0:
chunk_seq_ids = split_prefill_chunks(
prefill_query_lens_cpu = torch.diff(
query_start_loc_cpu[num_decodes : num_decodes + num_prefills + 1]
)
max_logits_bytes = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024
chunk_specs = split_indexer_prefill_chunks(
common_attn_metadata.seq_lens_cpu[num_decodes:],
prefill_query_lens_cpu,
self.max_prefill_buffer_size,
max_logits_bytes,
request_offset=num_decodes,
)
chunks = [
self.build_one_prefill_chunk(
reqs_start,
reqs_end,
req_slice,
query_slice,
query_start_loc_cpu,
common_attn_metadata.seq_lens_cpu,
common_attn_metadata.block_table_tensor,
skip_kv_gather=query_slice.start > 0,
)
for reqs_start, reqs_end in chunk_seq_ids
for req_slice, query_slice in chunk_specs
]
prefill_metadata = DeepseekV32IndexerPrefillMetadata(
chunks=chunks,

View File

@@ -234,6 +234,13 @@ class Scheduler(SchedulerInterface):
hash_block_size=self.block_size,
metrics_collector=self.kv_metrics_collector,
)
# Bind GPU block pool to the KV connector. This must happen after
# kv_cache_manager is constructed so block_pool is available.
if self.connector is not None and hasattr(
self.connector, "bind_gpu_block_pool"
):
self.connector.bind_gpu_block_pool(self.kv_cache_manager.block_pool)
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
self.scheduler_reserve_full_isl = (

View File

@@ -281,8 +281,16 @@ class PromptTokenStats:
self.computed += prompt_len - num_cached_tokens
self.external_kv_transfer += num_external_computed_tokens
self.local_cache_hit += (
num_cached_tokens + recomputed - num_external_computed_tokens
# FIXME(yifan): local_cache_hit can go negative after preemption.
# num_cached_tokens is a one-time snapshot from first scheduling and
# is never reset on preemption, while num_external_computed_tokens is
# overwritten on re-scheduling. If CPU offload finds more tokens on
# the second pass than the original total, the subtraction underflows.
# A fundamental fix is to track the first-time num_external_computed_tokens
# as a separate metric rather than reusing num_external_computed_tokens
# for metric directly.
self.local_cache_hit += max(
0, (num_cached_tokens + recomputed - num_external_computed_tokens)
)
self.cached_tokens += num_cached_tokens
self.recomputed_tokens += recomputed

View File

@@ -303,10 +303,12 @@ class ThinkingTokenBudgetLogitsProcessor(LogitsProcessor):
# Check if thinking is enabled
self.is_enabled = reasoning_config is not None
self.think_start_token_ids = getattr(
reasoning_config, "think_start_token_ids", []
self.reasoning_start_token_ids = getattr(
reasoning_config, "reasoning_start_token_ids", []
)
self.reasoning_end_token_ids = getattr(
reasoning_config, "reasoning_end_token_ids", []
)
self.think_end_token_ids = getattr(reasoning_config, "think_end_token_ids", [])
self.pin_memory = is_pin_memory
self.device = device
@@ -357,15 +359,15 @@ class ThinkingTokenBudgetLogitsProcessor(LogitsProcessor):
think_count = 0
else:
last_start = self._find_last_sequence_index(
prompt_tok_ids, self.think_start_token_ids
prompt_tok_ids, self.reasoning_start_token_ids
)
last_end = self._find_last_sequence_index(
prompt_tok_ids, self.think_end_token_ids
prompt_tok_ids, self.reasoning_end_token_ids
)
in_think = last_start > last_end
if in_think:
think_count = len(prompt_tok_ids) - (
last_start + len(self.think_start_token_ids)
last_start + len(self.reasoning_start_token_ids)
)
else:
think_count = 0
@@ -405,8 +407,8 @@ class ThinkingTokenBudgetLogitsProcessor(LogitsProcessor):
state["prev_output_length"] = current_length
# Check if new tokens contain think start or end sequences
start_len = len(self.think_start_token_ids)
end_len = len(self.think_end_token_ids)
start_len = len(self.reasoning_start_token_ids)
end_len = len(self.reasoning_end_token_ids)
# Look for think sequences in recent tokens (including boundary)
# Check overlapping regions where sequences might span boundaries
@@ -415,10 +417,10 @@ class ThinkingTokenBudgetLogitsProcessor(LogitsProcessor):
# Find any think start/end sequences in recent tokens
recent_start_pos = self._find_last_sequence_index(
recent_tokens, self.think_start_token_ids
recent_tokens, self.reasoning_start_token_ids
)
recent_end_pos = self._find_last_sequence_index(
recent_tokens, self.think_end_token_ids
recent_tokens, self.reasoning_end_token_ids
)
# Update state based on recent sequences
@@ -469,7 +471,7 @@ class ThinkingTokenBudgetLogitsProcessor(LogitsProcessor):
else:
# In end mode
state["end_count"] += 1
if state["end_count"] >= len(self.think_end_token_ids):
if state["end_count"] >= len(self.reasoning_end_token_ids):
state.update(
{
"in_end": False,
@@ -530,7 +532,9 @@ class ThinkingTokenBudgetLogitsProcessor(LogitsProcessor):
state = self._state.get(i)
if state and state["in_end"]:
self.mask[i] = True
self.force_token_ids[i] = self.think_end_token_ids[state["end_count"]]
self.force_token_ids[i] = self.reasoning_end_token_ids[
state["end_count"]
]
# Check in CPU first not to sync with GPU
has_active_thinking = any(

View File

View File

@@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""DMA copy backend for GPU<->CPU block transfers."""
from __future__ import annotations
import queue
import threading
import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.simple_kv_offload.cuda_mem_ops import (
BatchMemcpyParams,
build_params,
copy_blocks,
)
logger = init_logger(__name__)
class DmaCopyBackend:
    """cuMemcpyBatchAsync copy backend (background thread).

    ``init()`` precomputes per-direction batch-memcpy parameters and starts
    a daemon worker thread; ``launch_copy()`` enqueues transfer jobs; the
    worker issues each batched copy and records a CUDA event on the issuing
    stream so callers can poll completion via the shared ``events_list``.
    """

    def __init__(self) -> None:
        # Direction-specific DMA parameter sets; populated by init().
        self._store_params: BatchMemcpyParams | None = None
        self._load_params: BatchMemcpyParams | None = None
        self._load_stream: torch.cuda.Stream | None = None
        self._store_stream: torch.cuda.Stream | None = None
        # Job queue consumed by the background copy thread.
        self._queue: queue.SimpleQueue | None = None
        self._thread: threading.Thread | None = None
        self._shutdown: bool = False

    def init(
        self,
        gpu_caches: dict[str, torch.Tensor],
        cpu_caches: dict[str, torch.Tensor],
        device: torch.device,
        load_stream: torch.cuda.Stream,
        store_stream: torch.cuda.Stream,
    ) -> None:
        """Build copy parameters for both directions and start the worker.

        Stores copy GPU->CPU on ``store_stream``; loads copy CPU->GPU on
        ``load_stream``.
        """
        self._load_stream = load_stream
        self._store_stream = store_stream
        # store: GPU caches are the source; load: CPU caches are the source.
        self._store_params = build_params(gpu_caches, cpu_caches, store_stream)
        self._load_params = build_params(cpu_caches, gpu_caches, load_stream)
        self._queue = queue.SimpleQueue()
        self._thread = threading.Thread(
            target=self._copy_loop,
            args=(self._queue, device, load_stream, store_stream),
            daemon=True,
        )
        self._thread.start()

    def launch_copy(
        self,
        src_blocks: list[int],
        dst_blocks: list[int],
        is_store: bool,
        event_idx: int,
        events_list: list[tuple[int, torch.Event]],
    ) -> None:
        """Enqueue one copy job; completion is signalled by the worker
        appending ``(event_idx, cuda_event)`` to ``events_list``."""
        params = self._store_params if is_store else self._load_params
        assert params is not None and self._queue is not None
        self._queue.put(
            (src_blocks, dst_blocks, params, is_store, event_idx, events_list)
        )

    def shutdown(self) -> None:
        """Stop the worker thread; safe to call more than once."""
        if self._shutdown:
            return
        self._shutdown = True
        if self._queue is not None:
            # None is the sentinel that makes _copy_loop exit.
            self._queue.put(None)
        if self._thread is not None:
            self._thread.join(timeout=5.0)

    @staticmethod
    def _copy_loop(
        q: queue.SimpleQueue,
        device: torch.device,
        load_stream: torch.cuda.Stream,
        store_stream: torch.cuda.Stream,
    ) -> None:
        # Worker thread body: drain jobs until the None sentinel arrives.
        current_platform.set_device(device)
        while True:
            item = q.get()
            if item is None:
                return
            src_blocks, dst_blocks, params, is_store, event_idx, events_list = item
            copy_blocks(src_blocks, dst_blocks, params)
            stream = store_stream if is_store else load_stream
            # Record an event on the issuing stream so the caller can tell
            # when this batch has actually completed on the device.
            event = torch.Event()
            event.record(stream)
            # NOTE(review): events_list is appended from this thread and
            # presumably read by another; list.append is atomic under the
            # CPython GIL — confirm consumers tolerate concurrent appends.
            events_list.append((event_idx, event))

View File

@@ -0,0 +1,153 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Low-level CUDA memory helpers: pinning and batch DMA transfers."""
import ctypes
from typing import Any, NamedTuple
import numpy as np
import torch
from vllm.logger import init_logger
logger = init_logger(__name__)
def pin_tensor(tensor: torch.Tensor) -> None:
    """Page-lock an existing CPU tensor in place via cudaHostRegister.

    Avoids PyTorch's CUDACachingHostAllocator, which rounds every
    ``pin_memory=True`` allocation up to the next power of 2
    (e.g. 100 GB becomes 128 GB).
    """
    status = torch.cuda.cudart().cudaHostRegister(
        tensor.data_ptr(), tensor.nbytes, 0
    )
    # cudaSuccess == 0; anything else is a hard failure.
    if status.value != 0:
        raise RuntimeError(f"cudaHostRegister failed: {status}")
class _CUmemLocation(ctypes.Structure):
    # Mirrors the driver's CUmemLocation: a (type, id) pair naming a
    # memory location (e.g. a device ordinal).
    _fields_ = [("type", ctypes.c_uint), ("id", ctypes.c_int)]


class _CUmemcpyAttributes(ctypes.Structure):
    # Mirrors the driver's CUmemcpyAttributes used by cuMemcpyBatchAsync.
    _fields_ = [
        ("srcAccessOrder", ctypes.c_uint),
        ("srcLocHint", _CUmemLocation),
        ("dstLocHint", _CUmemLocation),
        ("flags", ctypes.c_uint),
    ]


# ctypes prototype matching the call made in copy_blocks():
#   CUresult fn(dsts, srcs, sizes, count, attrs, attrsIdxs, numAttrs,
#               failIdx, stream)
_BATCH_MEMCPY_FUNC_TYPE = ctypes.CFUNCTYPE(
    ctypes.c_uint,  # CUresult
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_size_t,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_size_t,
    ctypes.c_void_p,
    ctypes.c_void_p,
)

# Resolved lazily on first use.
_batch_memcpy_fn: Any = None
def _resolve_batch_memcpy():
    """Resolve cuMemcpyBatchAsync via cuGetProcAddress (one-time).

    Returns a ctypes callable bound to the driver entry point; raises
    RuntimeError when the driver does not expose the symbol.
    """
    # Local import: the cuda-python bindings are an optional dependency.
    from cuda.bindings import driver as drv

    # 12080 requests the CUDA 12.8 version of the symbol.
    err, ptr, _ = drv.cuGetProcAddress(b"cuMemcpyBatchAsync", 12080, 0)
    if err != drv.CUresult.CUDA_SUCCESS:
        raise RuntimeError(f"cuGetProcAddress(cuMemcpyBatchAsync) failed: {err}")
    return _BATCH_MEMCPY_FUNC_TYPE(ptr)
class BatchMemcpyParams(NamedTuple):
    """Precomputed, direction-specific arguments for cuMemcpyBatchAsync."""

    src_bases: np.ndarray  # [num_layers] uint64 — data_ptr per layer
    dst_bases: np.ndarray  # [num_layers] uint64
    bpb: np.ndarray  # [num_layers] uint64 — bytes per block
    num_layers: int
    # Shared copy attributes and their index array (single entry, index 0).
    attrs: _CUmemcpyAttributes
    attrs_idx: ctypes.c_size_t
    # NOTE: cuMemcpyBatchAsync_v2() removed fail_idx field, but we use
    # cuMemcpyBatchAsync() with fail_idx for backward compatibility
    fail_idx: ctypes.c_size_t
    stream_handle: int  # raw cudaStream_t / CUstream
def build_params(
    src_caches: dict[str, torch.Tensor],
    dst_caches: dict[str, torch.Tensor],
    stream: torch.cuda.Stream,
) -> BatchMemcpyParams:
    """Precompute per-layer base pointers and per-block byte counts for
    batched DMA from *src_caches* to *dst_caches*, bound to *stream*."""
    global _batch_memcpy_fn
    # Lazily resolve the driver entry point the first time any params
    # are built.
    if _batch_memcpy_fn is None:
        _batch_memcpy_fn = _resolve_batch_memcpy()

    assert list(src_caches.keys()) == list(dst_caches.keys())

    src_ptrs: list[int] = []
    dst_ptrs: list[int] = []
    block_bytes: list[int] = []
    for src_t, dst_t in zip(src_caches.values(), dst_caches.values()):
        # Bytes per block = stride of the leading (block) dimension.
        nbytes = src_t.stride(0) * src_t.element_size()
        assert nbytes == dst_t.stride(0) * dst_t.element_size()
        src_ptrs.append(src_t.data_ptr())
        dst_ptrs.append(dst_t.data_ptr())
        block_bytes.append(nbytes)

    # Refer to https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g6f1ff58e3065df3eb4b573dba77ad31f for details. # noqa: E501
    attrs = _CUmemcpyAttributes(srcAccessOrder=3)  # ANY
    return BatchMemcpyParams(
        src_bases=np.array(src_ptrs, dtype=np.uint64),
        dst_bases=np.array(dst_ptrs, dtype=np.uint64),
        bpb=np.array(block_bytes, dtype=np.uint64),
        num_layers=len(src_ptrs),
        attrs=attrs,
        attrs_idx=ctypes.c_size_t(0),
        fail_idx=ctypes.c_size_t(0),
        stream_handle=stream.cuda_stream,
    )
def copy_blocks(
    src_block_ids: list[int],
    dst_block_ids: list[int],
    params: BatchMemcpyParams,
) -> None:
    """Copy blocks via cuMemcpyBatchAsync.

    Issues one asynchronous batched copy of ``n * num_layers`` chunks on
    ``params.stream_handle`` and returns without synchronizing. Raises
    RuntimeError when the driver call fails.
    """
    n = len(src_block_ids)
    if n == 0:
        return
    src_ids = np.array(src_block_ids, dtype=np.uint64)
    dst_ids = np.array(dst_block_ids, dtype=np.uint64)
    # Outer product: the address of block b in layer l is
    # base[l] + block_id[b] * bytes_per_block[l]. ravel() flattens
    # layer-major, matching the np.repeat(bpb, n) sizes below.
    src_all = (
        params.src_bases[:, None] + src_ids[None, :] * params.bpb[:, None]
    ).ravel()
    dst_all = (
        params.dst_bases[:, None] + dst_ids[None, :] * params.bpb[:, None]
    ).ravel()
    sz_all = np.repeat(params.bpb, n)
    total = n * params.num_layers
    # Argument order: dsts, srcs, sizes, count, attrs, attrsIdxs,
    # numAttrs (1: one shared attribute set), failIdx, stream.
    err = _batch_memcpy_fn(
        dst_all.ctypes.data,
        src_all.ctypes.data,
        sz_all.ctypes.data,
        total,
        ctypes.addressof(params.attrs),
        ctypes.byref(params.attrs_idx),
        1,
        ctypes.byref(params.fail_idx),
        params.stream_handle,
    )
    if err != 0:
        raise RuntimeError(
            f"cuMemcpyBatchAsync failed: err={err} failIdx={params.fail_idx.value}"
        )

View File

@@ -0,0 +1,739 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Scheduler-side manager for SimpleCPUOffloadConnector."""
import contextlib
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data
from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_coordinator import (
KVCacheCoordinator,
get_kv_cache_coordinator,
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
MambaSpec,
SlidingWindowSpec,
)
from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.simple_kv_offload.metadata import (
SimpleCPUOffloadMetadata,
SimpleCPUOffloadWorkerMetadata,
)
if TYPE_CHECKING:
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import KVCacheBlock
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
logger = init_logger(__name__)
@dataclass
class TransferMeta:
    """Paired block-id lists describing one GPU<->CPU transfer batch."""

    # gpu_block_ids[i] is transferred to/from cpu_block_ids[i].
    gpu_block_ids: list[int]
    cpu_block_ids: list[int]
@dataclass
class LoadRequestState:
    """Tracks an in-flight CPU->GPU load for one request."""

    request: "Request"
    transfer_meta: TransferMeta
    # Index of the load event this request was batched into; None until
    # build_connector_meta assigns one.
    load_event: int | None = None
    finished: bool = False
# NOTE: This per-request state is only used in eager mode.
@dataclass
class StoreRequestState:
    """Tracks pending GPU->CPU stores for one request (eager mode)."""

    request: "Request"
    # Accumulated block IDs from scheduler_output via yield_req_data.
    block_ids: tuple[list[int], ...]
    # Per-group cursors tracking how many blocks have been stored/skipped.
    num_stored_blocks: list[int]
    # Store events this request still has outstanding.
    store_events: set[int] = field(default_factory=set)
    finished: bool = False
class SimpleCPUOffloadScheduler:
"""Scheduler-side manager for CPU offloading."""
def __init__(
    self,
    vllm_config: VllmConfig,
    kv_cache_config: "KVCacheConfig | None",
    cpu_capacity_bytes: int,
    lazy_offload: bool = False,
):
    """Build CPU-side KV cache bookkeeping mirroring the GPU config.

    Args:
        vllm_config: Global engine configuration.
        kv_cache_config: GPU KV cache layout; must not be None.
        cpu_capacity_bytes: Bytes of CPU memory dedicated to offloaded
            KV blocks.
        lazy_offload: If True, offload blocks via a cursor walk over the
            GPU free queue; otherwise offload eagerly per request.
    """
    self.vllm_config = vllm_config
    self.kv_cache_config = kv_cache_config
    self.enable_kv_cache_events = (
        vllm_config.kv_events_config is not None
        and vllm_config.kv_events_config.enable_kv_cache_events
    )
    # NOTE: We use the same block size for both GPU and CPU.
    self.block_size = vllm_config.cache_config.block_size
    # Derive a CPU KVCacheConfig from the GPU config and build a coordinator
    assert kv_cache_config is not None
    self.cpu_kv_cache_config = self._derive_cpu_config(
        kv_cache_config, cpu_capacity_bytes
    )
    self.num_cpu_blocks = self.cpu_kv_cache_config.num_blocks
    # Find the full attention kv group for prefix cache matching.
    self.fa_gidx = -1
    for g_idx, g in enumerate(self.cpu_kv_cache_config.kv_cache_groups):
        if isinstance(g.kv_cache_spec, FullAttentionSpec):
            self.fa_gidx = g_idx
            break
    assert 0 <= self.fa_gidx < len(self.cpu_kv_cache_config.kv_cache_groups)
    logger.info(
        "SimpleCPUOffloadScheduler: Allocating %d CPU blocks (%.2f GB, mode=%s)",
        self.num_cpu_blocks,
        cpu_capacity_bytes / (1024**3),
        "lazy" if lazy_offload else "eager",
    )
    # TODO (yifan): maybe need to enable kv_cache_events and metrics_collector here.
    dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
    pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size
    assert dcp_world_size == 1 and pcp_world_size == 1
    self.cpu_coordinator: KVCacheCoordinator = get_kv_cache_coordinator(
        kv_cache_config=self.cpu_kv_cache_config,
        max_model_len=vllm_config.model_config.max_model_len,
        use_eagle=False,
        enable_caching=True,
        enable_kv_cache_events=self.enable_kv_cache_events,
        dcp_world_size=dcp_world_size,
        pcp_world_size=pcp_world_size,
        hash_block_size=self.block_size,
    )
    self.cpu_block_pool: BlockPool = self.cpu_coordinator.block_pool
    # GPU block pool reference - bound after scheduler builds kv_cache_manager
    self._gpu_block_pool: BlockPool | None = None
    # Load metadata
    self._reqs_to_load: dict[str, LoadRequestState] = {}
    # Inverse map: load_event_idx -> req_ids. Keyed by load_event_idx because
    # the worker reports completions by event index, not request id.
    self._load_event_to_reqs: dict[int, list[str]] = {}
    # Store metadata
    self._lazy_mode = lazy_offload
    # Lazy mode: use a cursor to track the last scanned block in the GPU free queue.
    # Quoted annotation: KVCacheBlock is imported under TYPE_CHECKING only,
    # and annotations on attribute targets are evaluated at runtime; this
    # keeps it consistent with the file's other quoted forward references.
    self._cursor: "KVCacheBlock | None" = None
    if self._lazy_mode:
        self._target_free = self._estimate_lazy_target_blocks(
            kv_cache_config,
            vllm_config.scheduler_config.max_num_batched_tokens,
        )
    else:
        self._target_free = 0
    self._store_event_to_blocks: dict[int, TransferMeta] = {}
    # Eager mode only
    self._reqs_to_store: dict[str, StoreRequestState] = {}
    self._store_event_to_reqs: dict[int, list[str]] = {}
    # Event counters
    self._load_event_counter: int = 0
    self._store_event_counter: int = 0
    # For TP/PP: track partial store completions across steps.
    # Events must be reported by all world_size workers before considered complete.
    self._expected_worker_count = vllm_config.parallel_config.world_size
    self._store_event_pending_counts: dict[int, int] = {}
@staticmethod
def _derive_cpu_config(
    gpu_config: "KVCacheConfig", cpu_capacity_bytes: int
) -> "KVCacheConfig":
    """Derive a CPU KVCacheConfig from the GPU config.
    Same kv_cache_groups, num_blocks scaled by CPU/GPU memory ratio."""
    # Import here to avoid potential circular imports
    from vllm.v1.kv_cache_interface import KVCacheConfig as KVCacheConfigCls
    from vllm.v1.kv_cache_interface import KVCacheTensor

    assert len(gpu_config.kv_cache_tensors) > 0
    gpu_bytes_total = 0
    for tensor in gpu_config.kv_cache_tensors:
        gpu_bytes_total += tensor.size
    n_gpu = gpu_config.num_blocks
    # Scale the block count by the CPU/GPU capacity ratio (at least one).
    n_cpu = max(1, n_gpu * cpu_capacity_bytes // gpu_bytes_total)
    # Mirror the GPU tensor layout, scaling each tensor size by the
    # per-block byte cost times the CPU block count.
    cpu_tensors = []
    for tensor in gpu_config.kv_cache_tensors:
        cpu_tensors.append(
            KVCacheTensor(
                size=tensor.size // n_gpu * n_cpu,
                shared_by=list(tensor.shared_by),
            )
        )
    return KVCacheConfigCls(
        num_blocks=n_cpu,
        kv_cache_tensors=cpu_tensors,
        kv_cache_groups=gpu_config.kv_cache_groups,
    )
@staticmethod
def _estimate_lazy_target_blocks(
    kv_cache_config: "KVCacheConfig", max_num_batched_tokens: int
) -> int:
    """GPU blocks to keep available (free/offloaded) per step in lazy mode."""
    WATERMARK_RATIO = 1.0  # Reserve larger space to avoid running out of GPU blocks

    def _blocks_for(spec) -> int:
        # Per-group contribution toward the availability target.
        if isinstance(spec, MambaSpec):
            return 2
        if isinstance(spec, SlidingWindowSpec):
            return cdiv(spec.sliding_window, spec.block_size) + 1
        return cdiv(max_num_batched_tokens, spec.block_size)

    base = sum(
        _blocks_for(group.kv_cache_spec)
        for group in kv_cache_config.kv_cache_groups
    )
    return int(base * (1 + WATERMARK_RATIO))
def bind_gpu_block_pool(self, gpu_block_pool: BlockPool) -> None:
    """Bind GPU block pool so that we can touch blocks during stores.
    Called by Scheduler after kv_cache_manager is ready; until then the
    store/load paths that read ``self._gpu_block_pool`` are inert."""
    self._gpu_block_pool = gpu_block_pool
def get_num_new_matched_tokens(
    self, request: "Request", num_computed_tokens: int
) -> tuple[int | None, bool]:
    """Return (num_new_tokens, is_async) from consecutive CPU cache hits."""
    num_skipped_blocks = num_computed_tokens // self.block_size
    candidate_hashes = request.block_hashes[num_skipped_blocks:]
    if not candidate_hashes:
        return 0, False
    # Must recompute at least the last token, matching the logic in
    # kv_cache_manager.get_computed_blocks().
    max_hit_len = request.num_tokens - 1 - num_computed_tokens
    if max_hit_len <= 0:
        return 0, False
    _, hit_len = self.cpu_coordinator.find_longest_cache_hit(
        candidate_hashes, max_hit_len
    )
    # A positive hit is loaded asynchronously.
    return (hit_len, True) if hit_len > 0 else (0, False)
# TODO(yifan): this API now only matches the suffix part of the prefix cache. A more
# general API should scan blocks in both GPU and CPU block pool in a single pass.
def update_state_after_alloc(
    self,
    request: "Request",
    blocks: "KVCacheBlocks",
    num_external_tokens: int,
) -> None:
    """Register store tracking for *request* and, if the scheduler
    committed to loading ``num_external_tokens`` from CPU, build the
    GPU<->CPU transfer pairs and pin both sides against eviction.
    """
    req_id = request.request_id
    block_ids_by_group = blocks.get_block_ids()
    num_groups = len(block_ids_by_group)
    # Store tracking (eager mode only). Register the request;
    # block IDs are accumulated from scheduler_output in
    # _prepare_eager_store_specs via yield_req_data.
    if not self._lazy_mode and req_id not in self._reqs_to_store:
        self._reqs_to_store[req_id] = StoreRequestState(
            request=request,
            block_ids=tuple([] for _ in range(num_groups)),
            num_stored_blocks=[0] * num_groups,
        )
    if num_external_tokens == 0:
        return
    num_blocks_to_load = num_external_tokens // self.block_size
    assert num_blocks_to_load > 0
    # Blocks already hashed in the full-attention group are GPU prefix
    # hits; the CPU hit starts right after them.
    skipped = sum(blk.block_hash is not None for blk in blocks.blocks[self.fa_gidx])
    num_computed_tokens = skipped * self.block_size
    hashes_to_load = request.block_hashes[skipped : skipped + num_blocks_to_load]
    # Find CPU cached blocks across all groups.
    max_hit_len = len(hashes_to_load) * self.block_size
    cpu_hit_blocks, hit_length = self.cpu_coordinator.find_longest_cache_hit(
        hashes_to_load, max_hit_len
    )
    # The hit must match what get_num_new_matched_tokens promised.
    assert hit_length == num_external_tokens, (
        f"Expected {num_external_tokens} hit tokens, got {hit_length}"
    )
    # Build transfer pairs across all groups.
    total_computed_tokens = num_computed_tokens + num_external_tokens
    kv_cache_groups = self.cpu_kv_cache_config.kv_cache_groups
    gpu_block_ids: list[int] = []
    cpu_block_ids: list[int] = []
    cpu_blocks_to_touch: list[KVCacheBlock] = []
    for g in range(num_groups):
        cpu_blocks_g = cpu_hit_blocks[g]
        n_ext_g = len(cpu_blocks_g)
        if n_ext_g == 0:
            continue
        # Number of blocks in the computed range for this group.
        g_block_size = kv_cache_groups[g].kv_cache_spec.block_size
        n_computed_g = cdiv(total_computed_tokens, g_block_size)
        # Back-trace: ext blocks sit at the tail of the computed range.
        gpu_ext_start = n_computed_g - n_ext_g
        group_gpu_ids = block_ids_by_group[g]
        for i, cpu_blk in enumerate(cpu_blocks_g):
            # Skip null blocks (e.g. sliding window or mamba padding).
            if cpu_blk.is_null:
                continue
            gpu_block_ids.append(group_gpu_ids[gpu_ext_start + i])
            cpu_block_ids.append(cpu_blk.block_id)
            cpu_blocks_to_touch.append(cpu_blk)
    # Touch CPU blocks to prevent eviction during async load.
    self.cpu_block_pool.touch(cpu_blocks_to_touch)
    # Touch GPU blocks to prevent freeing during async load
    assert self._gpu_block_pool is not None
    self._gpu_block_pool.touch(
        [self._gpu_block_pool.blocks[bid] for bid in gpu_block_ids]
    )
    assert self._reqs_to_load.get(req_id) is None
    self._reqs_to_load[req_id] = LoadRequestState(
        request=request, transfer_meta=TransferMeta(gpu_block_ids, cpu_block_ids)
    )
def build_connector_meta(
    self,
    scheduler_output: SchedulerOutput,
) -> SimpleCPUOffloadMetadata:
    """Assemble this step's transfer metadata for the worker side.

    Batches all pending stores into at most one store event and all
    not-yet-scheduled loads into at most one load event, registering the
    bookkeeping needed to process completions later.
    """
    # --- Stores ---
    store_event = -1  # -1 signals "no store event this step"
    store_gpu, store_cpu, store_req_ids = self.prepare_store_specs(scheduler_output)
    if store_gpu:
        store_event = self._store_event_counter
        self._store_event_counter += 1
        self._store_event_to_blocks[store_event] = TransferMeta(
            store_gpu, store_cpu
        )
        if store_req_ids:  # For eager mode only, track req->blocks mapping
            self._store_event_to_reqs[store_event] = store_req_ids
            for req_id in store_req_ids:
                store_state = self._reqs_to_store.get(req_id)
                if store_state is not None:
                    store_state.store_events.add(store_event)
    # --- Loads ---
    load_event = -1
    load_gpu: list[int] = []
    load_cpu: list[int] = []
    load_req_ids: list[str] = []
    for req_id, load_state in self._reqs_to_load.items():
        if load_state.load_event is not None:
            # Already batched into an event in a previous step.
            continue
        assert load_state.transfer_meta is not None
        load_gpu.extend(load_state.transfer_meta.gpu_block_ids)
        load_cpu.extend(load_state.transfer_meta.cpu_block_ids)
        load_req_ids.append(req_id)
    if load_req_ids:
        load_event = self._load_event_counter
        self._load_event_counter += 1
        for req_id in load_req_ids:
            self._reqs_to_load[req_id].load_event = load_event
        self._load_event_to_reqs[load_event] = load_req_ids
    result = SimpleCPUOffloadMetadata(
        load_event=load_event,
        load_gpu_blocks=load_gpu,
        load_cpu_blocks=load_cpu,
        # NOTE(review): passes a live reference to the mutable dict, not a
        # copy — consumers observe later mutations; confirm that is intended.
        load_event_to_reqs=self._load_event_to_reqs,
        store_event=store_event,
        store_gpu_blocks=store_gpu,
        store_cpu_blocks=store_cpu,
        need_flush=bool(scheduler_output.preempted_req_ids),
    )
    return result
def prepare_store_specs(
    self, scheduler_output: SchedulerOutput
) -> tuple[list[int], list[int], list[str]]:
    """Prepare store specs for the store event."""
    # Dispatch on the offload mode chosen at construction time.
    if self._lazy_mode:
        return self._prepare_lazy_store_specs()
    return self._prepare_eager_store_specs(scheduler_output)
def _prepare_lazy_store_specs(
    self,
) -> tuple[list[int], list[int], list[str]]:
    """Single-pass cursor walk: offload cached GPU blocks near eviction.
    Walks the GPU free queue from the cursor, counting blocks that are
    free-or-offloaded (safe for the allocator to evict). Stops when
    target_free blocks are covered or CPU capacity is reached.
    """
    gpu_pool = self._gpu_block_pool
    if gpu_pool is None or self._target_free <= 0:
        return [], [], []
    free_queue = gpu_pool.free_block_queue
    cpu_pool = self.cpu_block_pool
    num_cpu_free = cpu_pool.get_num_free_blocks()
    # Validate cursor: stale if block was removed from free queue.
    if self._cursor is not None and self._cursor.ref_cnt > 0:
        self._cursor = None
    # Determine start node.
    if self._cursor is None:
        node = free_queue.fake_free_list_head.next_free_block
    else:
        node = self._cursor.next_free_block
    tail = free_queue.fake_free_list_tail
    gpu_ids: list[int] = []
    block_hashes: list[bytes] = []
    covered = 0  # blocks visited that are (or will become) safe to evict
    last_visited = self._cursor
    while (
        node is not None
        and node is not tail
        and covered < self._target_free
        and len(gpu_ids) < num_cpu_free
    ):
        last_visited = node
        bhash = node.block_hash
        # Offload only hashed, non-null blocks not already cached on CPU;
        # everything else on the free queue is already safe to evict.
        if (
            bhash is not None
            and not node.is_null
            and cpu_pool.cached_block_hash_to_block.get_one_block(bhash) is None
        ):
            gpu_ids.append(node.block_id)
            block_hashes.append(bhash)
        covered += 1
        node = node.next_free_block
    self._cursor = last_visited
    # Batch-allocate CPU blocks and stamp hashes.
    if gpu_ids:
        cpu_blocks = cpu_pool.get_new_blocks(len(gpu_ids))
        cpu_ids = [blk.block_id for blk in cpu_blocks]
        for cpu_blk, bhash in zip(cpu_blocks, block_hashes):  # type: ignore[assignment]
            cpu_blk._block_hash = bhash  # type: ignore[assignment]
        # Touch GPU blocks to prevent eviction during async copy.
        gpu_pool.touch([gpu_pool.blocks[bid] for bid in gpu_ids])
    else:
        cpu_ids = []
    # Lazy mode reports no per-request ids (third element empty).
    return gpu_ids, cpu_ids, []
def _prepare_eager_store_specs(
    self, scheduler_output: SchedulerOutput
) -> tuple[list[int], list[int], list[str]]:
    """Identify newly computed blocks to offload from scheduler requests.
    Only considers blocks whose KV data has been **confirmed computed** by
    the GPU. This means blocks from the current step are NOT stored until the
    next step. If a request finishes in the same step as its last full block,
    that block may be missed. (TODO: flush on finish.)
    Returns:
        (gpu_block_ids, cpu_block_ids, req_ids) for the store event.
    """
    merged_gpu_block_ids: list[int] = []
    merged_cpu_block_ids: list[int] = []
    req_ids: list[str] = []
    gpu_block_pool = self._gpu_block_pool
    if gpu_block_pool is None:
        return [], [], []
    cpu_block_pool = self.cpu_block_pool
    # Running CPU-capacity budget decremented as blocks are claimed below.
    num_free = cpu_block_pool.get_num_free_blocks()
    kv_cache_groups = self.cpu_kv_cache_config.kv_cache_groups
    num_groups = len(kv_cache_groups)
    # Blocks already scheduled this step, to avoid duplicate stores.
    gpu_blocks_this_step: set[int] = set()
    for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output):
        state = self._reqs_to_store.get(req_id)
        if state is None or state.finished:
            continue
        # Accumulate new block IDs.
        if preempted:
            # Preemption invalidates accumulated blocks; restart tracking.
            state.block_ids = tuple([] for _ in range(num_groups))
            state.num_stored_blocks = [0] * num_groups
        if new_block_id_groups:
            for g in range(min(num_groups, len(new_block_id_groups))):
                if new_block_id_groups[g] is not None:
                    state.block_ids[g].extend(new_block_id_groups[g])
        num_new_tokens = scheduler_output.num_scheduled_tokens.get(req_id, 0)
        if num_new_tokens == 0:
            continue
        block_ids_by_group = state.block_ids
        if not block_ids_by_group:
            continue
        # --- Phase 1: Scan blocks, classify as cached vs to-store ---
        gpu_block_ids: list[int] = []
        block_hashes_to_store: list[bytes] = []
        advanced_per_group: list[int] = [0] * num_groups
        out_of_space = False
        # Confirmed tokens: KV data written and visible to all streams.
        req = state.request
        confirmed_tokens = req.num_computed_tokens - req.num_output_placeholders
        for g in range(num_groups):
            # FIXME (yifan): handle CPU cache eviction, where
            # num_stored_blocks can be stale and omit evicted blocks in
            # the middle of the request.
            already_stored_g = state.num_stored_blocks[g]
            group_gpu_ids = block_ids_by_group[g]
            # Cap to blocks with confirmed KV data.
            g_block_size = kv_cache_groups[g].kv_cache_spec.block_size
            ready_blocks_g = confirmed_tokens // g_block_size
            scannable = group_gpu_ids[already_stored_g:ready_blocks_g]
            for gpu_block_id in scannable:
                gpu_block = gpu_block_pool.blocks[gpu_block_id]
                if gpu_block.is_null:
                    # Null blocks carry no data; just advance the cursor.
                    advanced_per_group[g] += 1
                    continue
                bhash_with_group = gpu_block.block_hash
                if bhash_with_group is None:
                    # Unhashed block: later blocks in this group are not
                    # ready either, so stop scanning the group.
                    break
                # Check if this group's data is already scheduled for store
                # in this step or already cached in CPU.
                if (
                    gpu_block_id in gpu_blocks_this_step
                    or cpu_block_pool.cached_block_hash_to_block.get_one_block(
                        bhash_with_group
                    )
                    is not None
                ):
                    advanced_per_group[g] += 1
                    continue
                if num_free <= 0:
                    out_of_space = True
                    break
                num_free -= 1
                gpu_block_ids.append(gpu_block_id)
                block_hashes_to_store.append(bhash_with_group)
                advanced_per_group[g] += 1
            if out_of_space:
                break
        # --- Phase 2: Batch allocate CPU blocks and stamp hashes ---
        n_to_alloc = len(gpu_block_ids)
        if n_to_alloc > 0:
            cpu_blocks_alloc = cpu_block_pool.get_new_blocks(n_to_alloc)
            cpu_block_ids = [blk.block_id for blk in cpu_blocks_alloc]
            for cpu_blk, bhash in zip(cpu_blocks_alloc, block_hashes_to_store):
                cpu_blk._block_hash = bhash  # type: ignore[assignment]
        else:
            cpu_block_ids = []
        if cpu_block_ids:
            req_ids.append(req_id)
            merged_gpu_block_ids.extend(gpu_block_ids)
            merged_cpu_block_ids.extend(cpu_block_ids)
            gpu_blocks_this_step.update(gpu_block_ids)
            # Touch GPU blocks to prevent freeing during async copy
            gpu_block_pool.touch(
                [gpu_block_pool.blocks[bid] for bid in gpu_block_ids]
            )
            logger.debug(
                "Request %s: Scheduling store of %d blocks to CPU (%d groups)",
                req_id,
                len(cpu_block_ids),
                num_groups,
            )
        # Advance per-group cursors (includes cached hits + newly stored)
        for g in range(num_groups):
            state.num_stored_blocks[g] += advanced_per_group[g]
    return merged_gpu_block_ids, merged_cpu_block_ids, req_ids
def update_connector_output(self, connector_output: KVConnectorOutput) -> None:
    """Consume async transfer completions reported by the worker.

    Loads complete per-request via ``finished_recving``. Stores complete
    per-event: each worker reports a count in the worker metadata, and we
    accumulate those counts across steps, processing a store event only
    once every worker has checked in for it.
    """
    # Finished loads: release load-side bookkeeping per request.
    finished = connector_output.finished_recving or []
    for req_id in list(finished):
        self._cleanup_load_request(req_id)
    # Finished stores: arrive as typed worker metadata; ignore anything else.
    meta = connector_output.kv_connector_worker_meta
    if not isinstance(meta, SimpleCPUOffloadWorkerMetadata):
        return
    pending = self._store_event_pending_counts
    for event_idx, count in meta.completed_store_events.items():
        accumulated = pending.get(event_idx, 0) + count
        if accumulated < self._expected_worker_count:
            # Not all workers have reported yet; remember the tally.
            pending[event_idx] = accumulated
        else:
            # Every worker reported: the store event is fully complete.
            pending.pop(event_idx, None)
            self._process_store_event(event_idx)
def _process_store_event(self, event_idx: int) -> None:
    """Finalize a store event once all workers have reported completion."""
    transfer = self._store_event_to_blocks.pop(event_idx)
    self._process_store_completion(transfer.gpu_block_ids, transfer.cpu_block_ids)
    logger.debug(
        "Store event %d completed: cached %d blocks to CPU",
        event_idx,
        len(transfer.cpu_block_ids),
    )
    if self._lazy_mode:
        # Lazy mode keeps no per-request store state; nothing more to do.
        return
    # Eager mode: detach this event from every request that was waiting on it.
    for req_id in self._store_event_to_reqs.pop(event_idx, []):
        state = self._reqs_to_store.get(req_id)
        if state is None:
            continue
        state.store_events.discard(event_idx)
        # A finished request with no outstanding stores can be cleaned up now.
        if state.finished and not state.store_events:
            self._cleanup_store_request(req_id)
def _process_store_completion(
    self, gpu_block_ids: list[int], cpu_block_ids: list[int]
) -> None:
    """Publish stored CPU blocks in the cache map and drop transfer refs.

    Block hashes were already stamped on the CPU blocks when they were
    allocated, so this only registers them in the hash map (making them
    discoverable by the load path) and then releases the ref counts that
    protected both endpoints of the copy.
    """
    assert len(cpu_block_ids) == len(gpu_block_ids)
    pool = self.cpu_block_pool
    cpu_blocks = [pool.blocks[bid] for bid in cpu_block_ids]
    for blk in cpu_blocks:
        block_hash = blk.block_hash
        assert block_hash is not None
        pool.cached_block_hash_to_block.insert(block_hash, blk)
    # Dropping the refs turns both sides into reusable prefix-cache entries.
    pool.free_blocks(cpu_blocks)
    gpu_pool = self._gpu_block_pool
    assert gpu_pool is not None
    gpu_pool.free_blocks(gpu_pool.blocks[bid] for bid in gpu_block_ids)
def has_pending_stores(self) -> bool:
    """Report whether any store transfers are still in flight."""
    return len(self._store_event_to_blocks) > 0
def request_finished(
    self,
    request: "Request",
    block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
    """Always returns (False, None): GPU blocks are protected by ref_cnt,
    so the scheduler may free them immediately while transfers finish
    asynchronously. Cleanup is either done now or deferred to the
    corresponding transfer-completion path."""
    req_id = request.request_id
    # Loads: if a transfer is still running, defer cleanup to completion.
    load_state = self._reqs_to_load.get(req_id)
    if load_state is not None:
        if load_state.load_event is None:
            self._cleanup_load_request(req_id)
        else:
            load_state.finished = True  # cleanup happens when the load lands
    # Stores (eager mode only): same defer-or-clean decision.
    if not self._lazy_mode:
        store_state = self._reqs_to_store.get(req_id)
        if store_state is not None:
            if store_state.store_events:
                store_state.finished = True  # cleanup on last store event
            else:
                self._cleanup_store_request(req_id)
    return False, None
def request_finished_all_groups(
self,
request: "Request",
block_ids: tuple[list[int], ...],
) -> tuple[bool, dict[str, Any] | None]:
return self.request_finished(request, block_ids=[])
def _cleanup_load_request(self, req_id: str) -> None:
    """Drop every load-side resource held on behalf of ``req_id``.

    Shared between request_finished() and update_connector_output():
    removes the request from _reqs_to_load, unlinks it from its load
    event mapping, and releases the CPU/GPU touch refs taken to protect
    the blocks during the async copy.
    """
    state = self._reqs_to_load.pop(req_id, None)
    if state is None:
        return
    event = state.load_event
    if event is not None:
        # Detach only this request from the shared event mapping; the
        # mapping itself is dropped once it goes empty.
        reqs = self._load_event_to_reqs.get(event)
        if reqs is not None:
            with contextlib.suppress(ValueError):
                reqs.remove(req_id)
            if not reqs:
                self._load_event_to_reqs.pop(event, None)
    meta = state.transfer_meta
    if meta is not None:
        # Release the touch refs taken on both endpoints of the copy:
        # CPU source blocks first, then GPU destination blocks.
        self.cpu_block_pool.free_blocks(
            self.cpu_block_pool.blocks[bid] for bid in meta.cpu_block_ids
        )
        gpu_pool = self._gpu_block_pool
        assert gpu_pool is not None
        gpu_pool.free_blocks(
            gpu_pool.blocks[bid] for bid in meta.gpu_block_ids
        )
def _cleanup_store_request(self, req_id: str) -> None:
"""Release store metadata for a request.
Metadata-only cleanup but no block freeing. Job completion handles
block caching and GPU ref freeing via _process_store_completion().
"""
state = self._reqs_to_store.pop(req_id, None)
if state is None:
return
for event_idx in list(state.store_events):
if (reqs := self._store_event_to_reqs.get(event_idx)) is not None:
with contextlib.suppress(ValueError):
reqs.remove(req_id)
if not reqs:
self._store_event_to_reqs.pop(event_idx, None)
state.store_events.clear()
def take_events(self) -> Iterable[KVCacheEvent]:
    """Drain pending KV cache events by delegating to the CPU block pool."""
    events = self.cpu_block_pool.take_events()
    return events

View File

@@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Metadata for SimpleCPUOffloadConnector."""
from dataclasses import dataclass, field
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata,
KVConnectorWorkerMetadata,
)
INVALID_JOB_ID = -1
@dataclass
class SimpleCPUOffloadMetadata(KVConnectorMetadata):
    """
    Metadata passed from scheduler to worker for CPU offload operations.

    The worker receives flat block lists keyed by a monotonic event_idx.
    Job->req_id translation is handled by the scheduler-side manager
    (via inverse maps), so the worker never knows about request identities.
    """

    # Load event per step. INVALID_JOB_ID means no blocks to load this step.
    load_event: int = INVALID_JOB_ID
    # Parallel block-id lists for this step's load; entries pair up
    # element-wise (GPU side / CPU side of each block copy).
    load_gpu_blocks: list[int] = field(default_factory=list)
    load_cpu_blocks: list[int] = field(default_factory=list)
    # Reverse map: load_event->req_ids, for tracking requests with finished load events
    load_event_to_reqs: dict[int, list[str]] = field(default_factory=dict)
    # Store event per step. INVALID_JOB_ID means no blocks to store this step.
    store_event: int = INVALID_JOB_ID
    # Parallel block-id lists for this step's store (GPU source / CPU dest).
    store_gpu_blocks: list[int] = field(default_factory=list)
    store_cpu_blocks: list[int] = field(default_factory=list)
    # Whether any requests were preempted this step and need flush pending transfers.
    need_flush: bool = False
@dataclass
class SimpleCPUOffloadWorkerMetadata(KVConnectorWorkerMetadata):
    """Worker -> Scheduler metadata for completed store events.

    Each worker reports {event_idx: 1} for newly completed stores.
    ``aggregate()`` sums counts across workers within a step.
    The scheduler-side manager accumulates across steps and processes
    a store completion only when count reaches ``world_size``.
    """

    completed_store_events: dict[int, int]

    def aggregate(
        self, other: "KVConnectorWorkerMetadata"
    ) -> "KVConnectorWorkerMetadata":
        """Merge two workers' reports by summing per-event counts."""
        assert isinstance(other, SimpleCPUOffloadWorkerMetadata)
        # Start from the other side's counts and fold ours in; addition is
        # commutative, so the result matches either merge order.
        counts = dict(other.completed_store_events)
        for event_idx, num in self.completed_store_events.items():
            counts[event_idx] = counts.get(event_idx, 0) + num
        return SimpleCPUOffloadWorkerMetadata(completed_store_events=counts)

View File

@@ -0,0 +1,305 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Worker-side handler for SimpleCPUOffloadConnector."""
from typing import TYPE_CHECKING
import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.simple_kv_offload.copy_backend import DmaCopyBackend
from vllm.v1.simple_kv_offload.cuda_mem_ops import pin_tensor
from vllm.v1.simple_kv_offload.metadata import (
SimpleCPUOffloadMetadata,
SimpleCPUOffloadWorkerMetadata,
)
if TYPE_CHECKING:
from vllm.v1.kv_cache_interface import KVCacheConfig
logger = init_logger(__name__)
class SimpleCPUOffloadWorker:
    """Worker-side handler for CPU offloading transfers."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        kv_cache_config: "KVCacheConfig | None",
        cpu_capacity_bytes: int,
    ):
        self.vllm_config = vllm_config
        self.kv_cache_config = kv_cache_config
        # Byte budget for the CPU pool; determines num_cpu_blocks once the
        # per-block size is known in register_kv_caches().
        self.cpu_capacity_bytes = cpu_capacity_bytes
        # Populated lazily in register_kv_caches().
        self.gpu_kv_caches: dict[str, torch.Tensor] | None = None
        self.cpu_kv_caches: dict[str, torch.Tensor] | None = None
        self.device: torch.device | None = None
        self.num_cpu_blocks: int = 0
        # CUDA streams for the async transfers
        self.load_stream: torch.cuda.Stream | None = None
        self.store_stream: torch.cuda.Stream | None = None
        self._backend = DmaCopyBackend()
        # Ordered (event_idx, Event). Events pre-allocated on main thread.
        self._load_events: list[tuple[int, torch.Event]] = []
        self._store_events: list[tuple[int, torch.Event]] = []
        # High-water marks: highest event_idx completed per stream.
        # When the event list is empty, the hwm covers all prior events.
        self._load_hwm: int = -1
        self._store_hwm: int = -1
        # Metadata for the current step
        self._connector_metadata: SimpleCPUOffloadMetadata | None = None
        # Pending event index sets, populated in bind_connector_metadata
        self._pending_load_event_indices: set[int] = set()
        self._pending_store_event_indices: set[int] = set()
        # Completed store events to report via build_connector_worker_meta
        self._completed_store_events: dict[int, int] = {}

    def register_kv_caches(
        self,
        kv_caches: dict[str, torch.Tensor],
    ) -> None:
        """Register GPU KV caches and allocate pinned CPU tensors.

        The worker will infer the underlying raw storage from the kv_caches.

        Args:
            kv_caches: Per-layer GPU KV caches. Values are either a single
                tensor (attention layers) or a list of tensors (Mamba layers
                in hybrid models). All values are included for offloading
                by resolving to their underlying raw storage.
        """
        if not kv_caches:
            logger.warning("No KV caches to offload.")
            return

        # Resolve each entry to a representative tensor for storage
        # deduplication. For attention layers the value is already a tensor;
        # for Mamba layers it is a list of tensors that all share the same
        # underlying raw storage, so we take the first one.
        def _repr_tensor(v: torch.Tensor | list[torch.Tensor]) -> torch.Tensor:
            assert isinstance(v, torch.Tensor | list)
            return v if isinstance(v, torch.Tensor) else v[0]

        any_tensor = _repr_tensor(next(iter(kv_caches.values())))
        self.device = any_tensor.device
        assert self.kv_cache_config is not None
        num_blocks = self.kv_cache_config.num_blocks
        # Deduplicate: multiple layers may share the same backing storage.
        seen_ptrs: dict[int, tuple[str, torch.Tensor]] = {}
        for name, value in kv_caches.items():
            tensor = _repr_tensor(value)
            ptr = tensor.untyped_storage().data_ptr()
            if ptr not in seen_ptrs:
                seen_ptrs[ptr] = (name, tensor)
        # Build [num_blocks, block_bytes] int8 views from each unique
        # storage so that stride(0) gives block_bytes for the copy op.
        #
        # The physical layout varies across attention backends:
        # FlashAttn/ROCm: (2, num_blocks, ...) -> K/V outermost, 2 segments
        # FlashInfer/MLA: (num_blocks, ...) -> blocks outermost, 1 segment
        # We derive page_size_bytes = storage.nbytes() // num_blocks, then
        # classify dims: any dim whose byte-stride exceeds page_size_bytes
        # must be an outer segment dim (e.g. the K/V dim of size 2). A less
        # hacky way is to update the interface with the layout.
        unique_gpu_caches: dict[str, torch.Tensor] = {}
        for name, tensor in seen_ptrs.values():
            storage = tensor.untyped_storage()
            # Alias the whole storage as a flat int8 tensor (zero-copy view).
            raw = torch.empty(0, dtype=torch.int8, device=self.device).set_(
                storage, 0, (storage.nbytes(),)
            )
            el = tensor.element_size()
            page_size_bytes = storage.nbytes() // num_blocks
            outer_dims = [
                d for d in range(tensor.ndim) if tensor.stride(d) * el > page_size_bytes
            ]
            if not outer_dims:
                # Blocks are outermost: one contiguous segment per storage.
                unique_gpu_caches[name] = raw.view(num_blocks, -1)
            else:
                # One entry per outer segment (e.g. separate K and V planes).
                seg_stride = tensor.stride(outer_dims[0]) * el
                for idx in range(tensor.shape[outer_dims[0]]):
                    offset = idx * seg_stride
                    chunk = raw[offset : offset + seg_stride]
                    unique_gpu_caches[f"{name}.{idx}"] = chunk.view(num_blocks, -1)
        # Compute per-tensor bytes_per_block. Tensors may have different
        # page_size_bytes (e.g., UniformTypeKVCacheSpecs with varying head_size).
        per_tensor_bpb = [
            t.stride(0) * t.element_size() for t in unique_gpu_caches.values()
        ]
        total_bytes_per_block = sum(per_tensor_bpb)
        # At least one CPU block even under a tiny byte budget.
        self.num_cpu_blocks = max(1, self.cpu_capacity_bytes // total_bytes_per_block)
        logger.info(
            "SimpleCPUOffloadWorker: %d unique GPU KV tensors, "
            "allocating %d CPU blocks (%.2f GB)",
            len(unique_gpu_caches),
            self.num_cpu_blocks,
            (self.num_cpu_blocks * total_bytes_per_block) / (1024**3),
        )
        pin_memory = is_pin_memory_available()
        if not pin_memory:
            logger.warning(
                "Pinned memory not available. CPU offload performance may be degraded."
            )
        self.gpu_kv_caches = unique_gpu_caches
        self.cpu_kv_caches = {}
        for name, gpu_tensor in unique_gpu_caches.items():
            cpu_shape = (self.num_cpu_blocks,) + gpu_tensor.shape[1:]
            # Allocate non-pinned first, then pin via cudaHostRegister to
            # bypass PyTorch's CUDACachingHostAllocator which rounds up to
            # the next power of 2 (e.g. 100 GB -> 128 GB).
            tensor = torch.zeros(cpu_shape, dtype=gpu_tensor.dtype, device="cpu")
            if pin_memory:
                pin_tensor(tensor)
            self.cpu_kv_caches[name] = tensor
        # Use lowest priority so KV cache I/O yields to compute streams.
        low_pri, _ = torch.cuda.Stream.priority_range()
        self.load_stream = torch.cuda.Stream(priority=low_pri)
        self.store_stream = torch.cuda.Stream(priority=low_pri)
        # Initialize copy backend with caches and streams.
        self._backend.init(
            self.gpu_kv_caches,
            self.cpu_kv_caches,
            self.device,
            self.load_stream,
            self.store_stream,
        )

    def bind_connector_metadata(self, metadata: SimpleCPUOffloadMetadata) -> None:
        """Stash this step's metadata and record any new event indices."""
        self._connector_metadata = metadata
        # Negative event ids (INVALID_JOB_ID) mean "no transfer this step".
        if metadata.load_event >= 0:
            self._pending_load_event_indices.add(metadata.load_event)
        if metadata.store_event >= 0:
            self._pending_store_event_indices.add(metadata.store_event)

    def clear_connector_metadata(self) -> None:
        """Drop the per-step metadata once the step is over."""
        self._connector_metadata = None

    def start_load_kv(self) -> None:
        # NOTE: we defer launching both load and store to get_finished(),
        # which runs after model execution. This hides the CPU-side
        # block copy op overhead (~5ms) behind GPU compute.
        pass

    def wait_for_save(self) -> None:
        # No-op: stores are asynchronous; completion is reported to the
        # scheduler via build_connector_worker_meta().
        pass

    def get_finished(
        self,
        finished_req_ids: set[str],
    ) -> tuple[set[str] | None, set[str] | None]:
        """Submit transfers and report completed events to the scheduler.

        Called after model execution. The manager only schedules stores for
        blocks whose KV data is confirmed computed, so we launch both loads
        and stores immediately — no deferral or cross-stream sync needed.

        Returns:
            tuple of (finished_sending, finished_recving).
            - finished_sending: always None (stores use worker metadata).
            - finished_recving: req_ids whose loads have completed.
        """
        # (1) Submit transfers
        metadata = self._connector_metadata
        if metadata is not None:
            # Launch loads (CPU->GPU).
            if metadata.load_cpu_blocks:
                self._backend.launch_copy(
                    metadata.load_cpu_blocks,
                    metadata.load_gpu_blocks,
                    is_store=False,
                    event_idx=metadata.load_event,
                    events_list=self._load_events,
                )
            # Launch stores (GPU->CPU).
            if metadata.store_gpu_blocks:
                self._backend.launch_copy(
                    metadata.store_gpu_blocks,
                    metadata.store_cpu_blocks,
                    is_store=True,
                    event_idx=metadata.store_event,
                    events_list=self._store_events,
                )
        # (2) Track completed transfer events
        finished_recving: set[str] = set()
        if self._pending_load_event_indices:
            load_wm = self._poll_stream_events(is_store=False)
            # Every pending load event at or below the high-water mark is done.
            for j in [j for j in self._pending_load_event_indices if j <= load_wm]:
                self._pending_load_event_indices.discard(j)
                req_ids = (
                    metadata.load_event_to_reqs.get(j) if metadata is not None else None
                )
                if req_ids:
                    finished_recving.update(req_ids)
        if self._pending_store_event_indices:
            store_wm = self._poll_stream_events(is_store=True)
            for j in [j for j in self._pending_store_event_indices if j <= store_wm]:
                self._pending_store_event_indices.discard(j)
                # This worker contributes a count of 1 per completed event.
                self._completed_store_events[j] = 1
        return None, finished_recving or None

    def build_connector_worker_meta(self) -> SimpleCPUOffloadWorkerMetadata | None:
        """Return completed store events since the last call."""
        if not self._completed_store_events:
            return None
        meta = SimpleCPUOffloadWorkerMetadata(
            completed_store_events=self._completed_store_events,
        )
        # Reset so each completed event is reported exactly once.
        self._completed_store_events = {}
        return meta

    def handle_preemptions(
        self, kv_connector_metadata: SimpleCPUOffloadMetadata
    ) -> None:
        """Sync all in-flight transfers before preempted blocks are reused."""
        if not kv_connector_metadata.need_flush:
            return
        self._flush_and_sync_all()

    def _flush_and_sync_all(self) -> None:
        """Synchronize all in-flight transfer events."""
        # Blocking waits; events are ordered, so the last synchronized
        # event_idx becomes the new high-water mark for each stream.
        for event_idx, event in self._load_events:
            event.synchronize()
            self._load_hwm = event_idx
        self._load_events.clear()
        for event_idx, event in self._store_events:
            event.synchronize()
            self._store_hwm = event_idx
        self._store_events.clear()

    def _poll_stream_events(self, is_store: bool) -> int:
        """Non-blocking poll for completed events and return the high-water mark."""
        events = self._store_events if is_store else self._load_events
        hwm = self._store_hwm if is_store else self._load_hwm
        while events:
            event_idx, event = events[0]
            if not event.query():
                # Events recorded on a stream complete in order; the first
                # pending event bounds everything behind it.
                break
            hwm = event_idx
            events.pop(0)
        if is_store:
            self._store_hwm = hwm
        else:
            self._load_hwm = hwm
        return hwm

View File

@@ -818,7 +818,6 @@ class SpecDecodeBaseProposer:
def prepare_next_token_ids_padded(
self,
seq_lens_cpu: torch.Tensor,
sampled_token_ids: torch.Tensor,
requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch,
@@ -833,7 +832,7 @@ class SpecDecodeBaseProposer:
"""
# Precompute get_token_id for when there is no valid next token
num_reqs = gpu_input_batch.num_reqs
seq_lens_list = seq_lens_cpu[:num_reqs].tolist()
seq_lens_list = (gpu_input_batch.num_tokens_no_spec[:num_reqs] - 1).tolist()
self.backup_next_token_ids.np[:num_reqs] = np.array(
[
requests[gpu_input_batch.req_ids[i]].get_token_id(seq_lens_list[i])

View File

@@ -286,7 +286,6 @@ class ExtractHiddenStatesProposer:
def prepare_next_token_ids_padded(
self,
seq_lens: torch.Tensor,
sampled_token_ids: torch.Tensor,
requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch,
@@ -303,7 +302,7 @@ class ExtractHiddenStatesProposer:
device = sampled_token_ids.device
# Compute backup tokens for discarded / invalid requests
seq_lens_list = seq_lens[:num_reqs].tolist()
seq_lens_list = (gpu_input_batch.num_tokens_no_spec[:num_reqs] - 1).tolist()
backup_tokens_gpu = torch.tensor(
[
requests[gpu_input_batch.req_ids[i]].get_token_id(seq_lens_list[i])

View File

@@ -108,6 +108,15 @@ class CPUWorker(Worker):
if ret:
logger.info(ret)
# After the thread binding, changing thread num is not allowed
def skip_set_num_threads(x: int):
logger.warning(
"CPU backend doesn't allow to use "
"`torch.set_num_threads` after the thread binding, skip it."
)
torch.set_num_threads = skip_set_num_threads
# Note: unique identifier for creating allreduce shared memory
os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(":")[-1]
# Initialize the distributed environment.

View File

@@ -208,7 +208,7 @@ from .utils import (
if TYPE_CHECKING:
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.encoder_cudagraph import EncoderCudaGraphManager
from vllm.v1.worker.gpu.mm.encoder_cudagraph import EncoderCudaGraphManager
logger = init_logger(__name__)
@@ -1933,9 +1933,24 @@ class GPUModelRunner(
# _update_states_after_model_execute for hybrid models).
if self.num_accepted_tokens_event is not None:
self.num_accepted_tokens_event.synchronize()
self.num_accepted_tokens.np[:num_reqs] = (
self.input_batch.num_accepted_tokens_cpu[:num_reqs]
)
# Async mode: condense() reordered indices, use prev_positions mapping
if self.use_async_scheduling and prev_req_id_to_index:
prev_idx = self.prev_positions.np[:num_reqs]
new_mask = prev_idx < 0
self.num_accepted_tokens.np[:num_reqs] = (
self.input_batch.num_accepted_tokens_cpu[
np.where(new_mask, 0, prev_idx)
]
)
self.num_accepted_tokens.np[:num_reqs][new_mask] = 1
self.input_batch.num_accepted_tokens_cpu[:num_reqs] = (
self.num_accepted_tokens.np[:num_reqs]
)
else:
# Non-async mode: use values directly
self.num_accepted_tokens.np[:num_reqs] = (
self.input_batch.num_accepted_tokens_cpu[:num_reqs]
)
self.num_accepted_tokens.np[num_reqs:].fill(1)
self.num_accepted_tokens.copy_to_gpu()
else:
@@ -4211,7 +4226,6 @@ class GPUModelRunner(
assert spec_decode_common_attn_metadata is not None
next_token_ids, valid_sampled_tokens_count = (
self.drafter.prepare_next_token_ids_padded(
self.optimistic_seq_lens_cpu,
sampled_token_ids,
self.requests,
self.input_batch,
@@ -4578,7 +4592,6 @@ class GPUModelRunner(
)
next_token_ids, valid_sampled_tokens_count = (
self.drafter.prepare_next_token_ids_padded(
self.optimistic_seq_lens_cpu,
sampled_token_ids,
self.requests,
self.input_batch,
@@ -4617,7 +4630,6 @@ class GPUModelRunner(
)
next_token_ids, valid_sampled_tokens_count = (
self.drafter.prepare_next_token_ids_padded(
self.optimistic_seq_lens_cpu,
sampled_token_ids,
self.requests,
self.input_batch,
@@ -5969,7 +5981,9 @@ class GPUModelRunner(
SupportsEncoderCudaGraph,
supports_encoder_cudagraph,
)
from vllm.v1.worker.encoder_cudagraph import EncoderCudaGraphManager
from vllm.v1.worker.gpu.mm.encoder_cudagraph import (
EncoderCudaGraphManager,
)
raw_model = self.get_model()
if supports_encoder_cudagraph(raw_model):