Compare commits
13 Commits
v0.19.1rc0
...
v0.19.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2a69949bda | ||
|
|
8adcf8c40a | ||
|
|
cfad6a509c | ||
|
|
c284a6671c | ||
|
|
3a30a1a6a8 | ||
|
|
29982d48b3 | ||
|
|
1dbbafd3f3 | ||
|
|
0ee3b7fc3d | ||
|
|
268bed9cf3 | ||
|
|
bcc0fdd0f3 | ||
|
|
69b8bd4b33 | ||
|
|
12449f9492 | ||
|
|
b92312dfd7 |
@@ -1,9 +1,10 @@
|
||||
#!/bin/bash
|
||||
set -euox pipefail
|
||||
export VLLM_CPU_CI_ENV=0
|
||||
export VLLM_CPU_KVCACHE_SPACE=1 # avoid OOM
|
||||
|
||||
echo "--- PP+TP"
|
||||
vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -pp=2 &
|
||||
vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -pp=2 --max-model-len=4096 &
|
||||
server_pid=$!
|
||||
timeout 600 bash -c "until curl localhost:8000/v1/models > /dev/null 2>&1; do sleep 1; done" || exit 1
|
||||
vllm bench serve \
|
||||
@@ -23,7 +24,7 @@ if [ "$failed_req" -ne 0 ]; then
|
||||
fi
|
||||
|
||||
echo "--- DP+TP"
|
||||
vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -dp=2 &
|
||||
vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -dp=2 --max-model-len=4096 &
|
||||
server_pid=$!
|
||||
timeout 600 bash -c "until curl localhost:8000/v1/models > /dev/null 2>&1; do sleep 1; done" || exit 1
|
||||
vllm bench serve \
|
||||
|
||||
@@ -244,12 +244,12 @@ response = client.chat.completions.create(
|
||||
|
||||
Some models, such as [Qwen3](https://qwen.readthedocs.io/en/latest/getting_started/quickstart.html#thinking-budget), [DeepSeek](https://www.alibabacloud.com/help/en/model-studio/deep-thinking), and [Nemotron3](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16), support a thinking budget that limits the maximum number of tokens used for reasoning.
|
||||
|
||||
Token counting starts from `think_start_str`. Once the reasoning token count reaches the configured `thinking_token_budget`, vLLM forces the model to produce `think_end_str`, effectively terminating the reasoning block.
|
||||
Token counting starts from `reasoning_start_str`. Once the reasoning token count reaches the configured `thinking_token_budget`, vLLM forces the model to produce `reasoning_end_str`, effectively terminating the reasoning block.
|
||||
|
||||
To use this feature:
|
||||
|
||||
- `--reasoning-parser` enables reasoning extraction.
|
||||
- `--reasoning-config` defines the reasoning boundary tokens (e.g., `think_start_str`, `think_end_str`).
|
||||
- `--reasoning-config` defines the reasoning boundary tokens (e.g., `reasoning_start_str`, `reasoning_end_str`).
|
||||
- `thinking_token_budget` (a sampling parameter) sets the per-request reasoning token limit.
|
||||
|
||||
If `thinking_token_budget` is not specified, no explicit reasoning limit is applied beyond normal generation constraints such as `max_tokens`.
|
||||
@@ -257,20 +257,20 @@ If `thinking_token_budget` is not specified, no explicit reasoning limit is appl
|
||||
`--reasoning-config` accepts a JSON object corresponding to
|
||||
[ReasoningConfig][vllm.config.ReasoningConfig] with the following fields:
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------------------|----------------|--------------------------------------------------|
|
||||
| `think_start_str` | `str \| null` | String that marks the start of reasoning content |
|
||||
| `think_end_str` | `str \| null` | String that marks the end of reasoning content |
|
||||
| Field | Type | Description |
|
||||
|-----------------------|----------------|--------------------------------------------------|
|
||||
| `reasoning_start_str` | `str \| null` | String that marks the start of reasoning content |
|
||||
| `reasoning_end_str` | `str \| null` | String that marks the end of reasoning content |
|
||||
|
||||
!!! note
|
||||
`think_end_str` can include a transition phrase before the think end token. For example, setting `think_end_str` to `"I have to give the solution based on the thinking directly now.</think>"` instructs the model to emit that phrase when the budget is exhausted, making the reasoning termination more natural.
|
||||
`reasoning_end_str` can include a transition phrase before the reasoning end token. For example, setting `reasoning_end_str` to `"I have to give the solution based on the reasoning directly now.</think>"` instructs the model to emit that phrase when the budget is exhausted, making the reasoning termination more natural.
|
||||
|
||||
### Online Serving
|
||||
|
||||
```bash
|
||||
vllm serve Qwen/Qwen3-0.6B \
|
||||
--reasoning-parser qwen3 \
|
||||
--reasoning-config '{"think_start_str": "<think>", "think_end_str": "I have to give the solution based on the thinking directly now.</think>"}'
|
||||
--reasoning-config '{"reasoning_start_str": "<think>", "reasoning_end_str": "I have to give the solution based on the reasoning directly now.</think>"}'
|
||||
```
|
||||
|
||||
Then make a request with `thinking_token_budget` to limit the reasoning tokens:
|
||||
@@ -298,8 +298,8 @@ from vllm.config import ReasoningConfig
|
||||
llm = LLM(
|
||||
model="Qwen/Qwen3-0.6B",
|
||||
reasoning_config=ReasoningConfig(
|
||||
think_start_str="<think>",
|
||||
think_end_str="I have to give the solution based on the thinking directly now.</think>",
|
||||
reasoning_start_str="<think>",
|
||||
reasoning_end_str="I have to give the solution based on the thinking directly now.</think>",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -7,3 +7,4 @@ server_args: >-
|
||||
--max-model-len 4096
|
||||
--data-parallel-size 2
|
||||
--enable-expert-parallel
|
||||
--max-num-seqs 512
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import types
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -11,6 +9,8 @@ from vllm.model_executor.models.bert import (
|
||||
BertMLMHead,
|
||||
SPLADESparsePooler,
|
||||
)
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Functional test: SPLADE formula correctness (no HF download needed)
|
||||
@@ -38,8 +38,12 @@ def test_splade_pooler_matches_reference_formula(B, T, H, V):
|
||||
],
|
||||
dtype=torch.long,
|
||||
)
|
||||
meta = types.SimpleNamespace(
|
||||
prompt_lens=prompt_lens_tenser, prompt_token_ids=token_ids
|
||||
meta = PoolingMetadata(
|
||||
prompt_lens=prompt_lens_tenser,
|
||||
prompt_token_ids=token_ids,
|
||||
prompt_token_ids_cpu=token_ids,
|
||||
pooling_params=[PoolingParams(task="embed")] * B,
|
||||
pooling_states=[PoolingStates() for _ in range(B)],
|
||||
)
|
||||
|
||||
# MLM head (prefer BertMLMHead, fallback to Linear if unavailable)
|
||||
|
||||
@@ -394,6 +394,22 @@ VLM_TEST_SETTINGS = {
|
||||
vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}},
|
||||
patch_hf_runner=model_utils.gemma3_patch_hf_runner,
|
||||
),
|
||||
"gemma4": VLMTestInfo(
|
||||
models=["google/gemma-4-E2B-it"],
|
||||
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
|
||||
prompt_formatter=lambda img_prompt: f"<bos><start_of_turn>user\n{img_prompt}<end_of_turn>\n<start_of_turn>model\n", # noqa: E501
|
||||
single_image_prompts=IMAGE_ASSETS.prompts(
|
||||
{
|
||||
"stop_sign": "What's the content in the center of the image?",
|
||||
"cherry_blossom": "What is the season?",
|
||||
}
|
||||
),
|
||||
multi_image_prompt="Describe the two images in detail.",
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
auto_cls=AutoModelForImageTextToText,
|
||||
vllm_runner_kwargs={"limit_mm_per_prompt": {"image": 4}},
|
||||
),
|
||||
"granite_vision": VLMTestInfo(
|
||||
models=["ibm-granite/granite-vision-3.3-2b"],
|
||||
test_type=(VLMTestType.IMAGE),
|
||||
|
||||
@@ -15,6 +15,10 @@ from vllm.entrypoints.pooling.score.utils import compute_maxsim_score
|
||||
MODEL_NAME = "ModernVBERT/colmodernvbert-merged"
|
||||
COLBERT_DIM = 128
|
||||
DTYPE = "half"
|
||||
# Fixme:
|
||||
# Update colmodernvbert code to support the latest HF version
|
||||
# and remove revision set.
|
||||
REVISION = "4a0a9f3ac7a7992fec410bfa8e3d080ac9a5bcee"
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
@@ -26,6 +30,7 @@ def test_colmodernvbert_text_token_embed(vllm_runner):
|
||||
"""Text query produces per-token embeddings with shape (seq_len, 128)."""
|
||||
with vllm_runner(
|
||||
MODEL_NAME,
|
||||
revision=REVISION,
|
||||
runner="pooling",
|
||||
dtype=DTYPE,
|
||||
enforce_eager=True,
|
||||
@@ -49,6 +54,7 @@ def test_colmodernvbert_text_relevance_ordering(vllm_runner):
|
||||
|
||||
with vllm_runner(
|
||||
MODEL_NAME,
|
||||
revision=REVISION,
|
||||
runner="pooling",
|
||||
dtype=DTYPE,
|
||||
enforce_eager=True,
|
||||
@@ -66,6 +72,7 @@ def test_colmodernvbert_text_late_interaction(vllm_runner):
|
||||
|
||||
with vllm_runner(
|
||||
MODEL_NAME,
|
||||
revision=REVISION,
|
||||
runner="pooling",
|
||||
dtype=DTYPE,
|
||||
enforce_eager=True,
|
||||
@@ -92,6 +99,7 @@ def test_colmodernvbert_image_token_embed(vllm_runner, image_assets):
|
||||
"""Image input produces per-token embeddings including vision tokens."""
|
||||
with vllm_runner(
|
||||
MODEL_NAME,
|
||||
revision=REVISION,
|
||||
runner="pooling",
|
||||
dtype=DTYPE,
|
||||
enforce_eager=True,
|
||||
|
||||
44
tests/models/multimodal/processing/test_gemma4.py
Normal file
44
tests/models/multimodal/processing/test_gemma4.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
from ....conftest import ImageTestAssets
|
||||
from ...utils import build_model_context
|
||||
|
||||
# TODO: to be updated to "google/gemma-4-e2b-it" once the models are available
|
||||
GEMMA4_MODEL_ID = "google/gemma-4-E2B-it"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", [GEMMA4_MODEL_ID])
|
||||
def test_limit_mm_per_prompt(
|
||||
image_assets: ImageTestAssets,
|
||||
model_id: str,
|
||||
):
|
||||
"""Test that limit_mm_per_prompt accurately restricts multiple images."""
|
||||
# We only allow 1 image
|
||||
ctx = build_model_context(
|
||||
model_id,
|
||||
mm_processor_kwargs={},
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
|
||||
|
||||
# Provide 2 images in the prompt
|
||||
prompt = "<image><image>"
|
||||
# image_assets usually has multiple images
|
||||
images = [asset.pil_image for asset in image_assets][:2]
|
||||
if len(images) < 2:
|
||||
images = [images[0], images[0]]
|
||||
|
||||
mm_data = {"image": images}
|
||||
|
||||
# Expect ValueError when exceeding limit
|
||||
with pytest.raises(ValueError, match="At most 1 image"):
|
||||
processor(
|
||||
prompt,
|
||||
mm_items=processor.info.parse_mm_data(mm_data),
|
||||
hf_processor_mm_kwargs={},
|
||||
)
|
||||
@@ -277,6 +277,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"google/gemma-2-9b", extras={"tiny": "google/gemma-2-2b-it"}
|
||||
),
|
||||
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
|
||||
"Gemma4ForCausalLM": _HfExamplesInfo(
|
||||
"google/gemma-4-E2B-it",
|
||||
min_transformers_version="5.0.0",
|
||||
),
|
||||
"Gemma3nForCausalLM": _HfExamplesInfo("google/gemma-3n-E2B-it"),
|
||||
"GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"),
|
||||
"Glm4ForCausalLM": _HfExamplesInfo("zai-org/GLM-4-9B-0414"),
|
||||
@@ -636,6 +640,7 @@ _LATE_INTERACTION_EXAMPLE_MODELS = {
|
||||
# [Multimodal]
|
||||
"ColModernVBertForRetrieval": _HfExamplesInfo(
|
||||
"ModernVBERT/colmodernvbert-merged",
|
||||
revision="4a0a9f3ac7a7992fec410bfa8e3d080ac9a5bcee",
|
||||
),
|
||||
"ColPaliForRetrieval": _HfExamplesInfo("vidore/colpali-v1.3-hf"),
|
||||
"ColQwen3": _HfExamplesInfo(
|
||||
@@ -804,6 +809,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
),
|
||||
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
|
||||
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
|
||||
"Gemma4ForConditionalGeneration": _HfExamplesInfo(
|
||||
"google/gemma-4-E2B-it",
|
||||
min_transformers_version="5.5.0",
|
||||
),
|
||||
"Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it"),
|
||||
"GlmAsrForConditionalGeneration": _HfExamplesInfo(
|
||||
"zai-org/GLM-ASR-Nano-2512",
|
||||
|
||||
@@ -239,6 +239,17 @@ def test_video_media_io_backend_env_var_fallback(monkeypatch: pytest.MonkeyPatch
|
||||
assert metadata_missing["video_backend"] == "test_video_backend_override_2"
|
||||
|
||||
|
||||
def _make_jpeg_b64_frames(n: int, width: int = 8, height: int = 8) -> list[str]:
|
||||
"""Return *n* tiny base64-encoded JPEG frames."""
|
||||
frames: list[str] = []
|
||||
for i in range(n):
|
||||
img = Image.new("RGB", (width, height), color=(i % 256, 0, 0))
|
||||
buf = io.BytesIO()
|
||||
img.save(buf, format="JPEG")
|
||||
frames.append(pybase64.b64encode(buf.getvalue()).decode("ascii"))
|
||||
return frames
|
||||
|
||||
|
||||
def test_load_base64_jpeg_returns_metadata():
|
||||
"""Regression test: load_base64 with video/jpeg must return metadata.
|
||||
|
||||
@@ -248,16 +259,8 @@ def test_load_base64_jpeg_returns_metadata():
|
||||
"""
|
||||
|
||||
num_test_frames = 3
|
||||
frame_width, frame_height = 8, 8
|
||||
|
||||
# Build a few tiny JPEG frames and base64-encode them
|
||||
b64_frames = []
|
||||
for i in range(num_test_frames):
|
||||
img = Image.new("RGB", (frame_width, frame_height), color=(i * 80, 0, 0))
|
||||
buf = io.BytesIO()
|
||||
img.save(buf, format="JPEG")
|
||||
b64_frames.append(pybase64.b64encode(buf.getvalue()).decode("ascii"))
|
||||
|
||||
b64_frames = _make_jpeg_b64_frames(num_test_frames)
|
||||
data = ",".join(b64_frames)
|
||||
|
||||
imageio = ImageMediaIO()
|
||||
@@ -287,3 +290,52 @@ def test_load_base64_jpeg_returns_metadata():
|
||||
# Default fps=1 → duration == num_frames
|
||||
assert metadata["fps"] == 1.0
|
||||
assert metadata["duration"] == float(num_test_frames)
|
||||
|
||||
|
||||
def test_load_base64_jpeg_enforces_num_frames_limit():
|
||||
"""Frames beyond num_frames must be truncated in the video/jpeg path.
|
||||
|
||||
Without the limit an attacker can send thousands of base64 JPEG frames
|
||||
in a single request and exhaust server memory (OOM).
|
||||
"""
|
||||
num_frames_limit = 4
|
||||
sent_frames = 20
|
||||
|
||||
b64_frames = _make_jpeg_b64_frames(sent_frames)
|
||||
data = ",".join(b64_frames)
|
||||
|
||||
imageio = ImageMediaIO()
|
||||
videoio = VideoMediaIO(imageio, num_frames=num_frames_limit)
|
||||
frames, metadata = videoio.load_base64("video/jpeg", data)
|
||||
|
||||
assert frames.shape[0] == num_frames_limit
|
||||
assert metadata["total_num_frames"] == num_frames_limit
|
||||
assert metadata["frames_indices"] == list(range(num_frames_limit))
|
||||
|
||||
|
||||
def test_load_base64_jpeg_no_limit_when_num_frames_negative():
|
||||
"""When num_frames is -1, all frames should be loaded without truncation."""
|
||||
sent_frames = 10
|
||||
|
||||
b64_frames = _make_jpeg_b64_frames(sent_frames)
|
||||
data = ",".join(b64_frames)
|
||||
|
||||
imageio = ImageMediaIO()
|
||||
videoio = VideoMediaIO(imageio, num_frames=-1)
|
||||
frames, metadata = videoio.load_base64("video/jpeg", data)
|
||||
|
||||
assert frames.shape[0] == sent_frames
|
||||
assert metadata["total_num_frames"] == sent_frames
|
||||
assert metadata["frames_indices"] == list(range(sent_frames))
|
||||
|
||||
|
||||
def test_load_base64_jpeg_raises_on_zero_num_frames():
|
||||
"""num_frames=0 is invalid and should raise ValueError."""
|
||||
b64_frames = _make_jpeg_b64_frames(3)
|
||||
data = ",".join(b64_frames)
|
||||
|
||||
imageio = ImageMediaIO()
|
||||
videoio = VideoMediaIO(imageio, num_frames=0)
|
||||
|
||||
with pytest.raises(ValueError, match="num_frames must be greater than 0 or -1"):
|
||||
videoio.load_base64("video/jpeg", data)
|
||||
|
||||
196
tests/reasoning/test_gemma4_reasoning_parser.py
Normal file
196
tests/reasoning/test_gemma4_reasoning_parser.py
Normal file
@@ -0,0 +1,196 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.reasoning.utils import run_reasoning_extraction
|
||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||
|
||||
# Using mistral tokenizer as a generic mock since the actual model is not on HF
|
||||
from vllm.tokenizers.registry import get_tokenizer
|
||||
|
||||
parser_name = "gemma4"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def generic_tokenizer():
|
||||
return get_tokenizer("google/gemma-4-E2B-it")
|
||||
|
||||
|
||||
INVALID_SIMPLE_NONSTREAMING = {
|
||||
"output": "This is a reasoning section<channel|>This is the rest",
|
||||
"reasoning": "This is a reasoning section",
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
INVALID_SIMPLE_STREAMING = {
|
||||
"output": "This is a reasoning section<channel|>This is the rest",
|
||||
"reasoning": None,
|
||||
"content": "This is a reasoning sectionThis is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
INVALID_COMPLETE_NONSTREAMING = {
|
||||
"output": "This is a reasoning section<channel|>",
|
||||
"reasoning": "This is a reasoning section",
|
||||
"content": None,
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
INVALID_COMPLETE_STREAMING = {
|
||||
"output": "This is a reasoning section<channel|>",
|
||||
"reasoning": None,
|
||||
"content": "This is a reasoning section",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
NO_CONTENT = {
|
||||
"output": "<|channel>This is reasoning",
|
||||
"reasoning": "This is reasoning",
|
||||
"content": None,
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
NO_REASONING = {
|
||||
"output": "This is content",
|
||||
"reasoning": None,
|
||||
"content": "This is content",
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
REASONING_WITH_CHANNEL = {
|
||||
"output": "<|channel>This is a reasoning section<channel|>This is the rest",
|
||||
"reasoning": "This is a reasoning section",
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
COMPLETE_REASONING_WITH_CHANNEL = {
|
||||
"output": "<|channel>This is a reasoning section<channel|>",
|
||||
"reasoning": "This is a reasoning section",
|
||||
"content": None,
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
MULTIPLE_LINES_WITH_CHANNEL = {
|
||||
"output": "<|channel>This\nThat<channel|>This is the rest\nThat",
|
||||
"reasoning": "This\nThat",
|
||||
"content": "This is the rest\nThat",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
CHANNEL_NO_END = {
|
||||
"output": "<|channel>This is a reasoning section",
|
||||
"reasoning": "This is a reasoning section",
|
||||
"content": None,
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
EMPTY = {
|
||||
"output": "",
|
||||
"reasoning": None,
|
||||
"content": "",
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
NEW_LINE_NONSTREAMING = {
|
||||
"output": (
|
||||
"Before\n<|channel>This is a reasoning section<channel|>\nThis is the rest"
|
||||
),
|
||||
"reasoning": "This is a reasoning section",
|
||||
"content": "\nThis is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
NEW_LINE_STREAMING = {
|
||||
"output": (
|
||||
"Before\n<|channel>This is a reasoning section<channel|>\nThis is the rest"
|
||||
),
|
||||
"reasoning": "This is a reasoning section",
|
||||
"content": "Before\n\nThis is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
|
||||
TEST_CASES = [
|
||||
pytest.param(False, INVALID_SIMPLE_NONSTREAMING, id="invalid_simple"),
|
||||
pytest.param(True, INVALID_SIMPLE_STREAMING, id="invalid_simple_streaming"),
|
||||
pytest.param(False, INVALID_COMPLETE_NONSTREAMING, id="invalid_complete"),
|
||||
pytest.param(True, INVALID_COMPLETE_STREAMING, id="invalid_complete_streaming"),
|
||||
pytest.param(False, NO_CONTENT, id="no_content"),
|
||||
pytest.param(False, NO_REASONING, id="no_reasoning"),
|
||||
pytest.param(False, REASONING_WITH_CHANNEL, id="reasoning"),
|
||||
pytest.param(True, REASONING_WITH_CHANNEL, id="reasoning_streaming"),
|
||||
pytest.param(False, COMPLETE_REASONING_WITH_CHANNEL, id="complete_reasoning"),
|
||||
pytest.param(
|
||||
True, COMPLETE_REASONING_WITH_CHANNEL, id="complete_reasoning_streaming"
|
||||
),
|
||||
pytest.param(False, MULTIPLE_LINES_WITH_CHANNEL, id="multiple_lines"),
|
||||
pytest.param(True, MULTIPLE_LINES_WITH_CHANNEL, id="multiple_lines_streaming"),
|
||||
pytest.param(False, CHANNEL_NO_END, id="no_end"),
|
||||
pytest.param(True, CHANNEL_NO_END, id="no_end_streaming"),
|
||||
pytest.param(False, EMPTY, id="empty"),
|
||||
pytest.param(False, NEW_LINE_NONSTREAMING, id="new_line"),
|
||||
pytest.param(True, NEW_LINE_STREAMING, id="new_line_streaming"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
|
||||
def test_gemma4_reasoning(
|
||||
streaming: bool,
|
||||
param_dict: dict,
|
||||
generic_tokenizer,
|
||||
):
|
||||
output = param_dict["output"]
|
||||
|
||||
# Resolve token IDs dynamically from the real tokenizer
|
||||
vocab = generic_tokenizer.get_vocab()
|
||||
start_token_id = vocab["<|channel>"]
|
||||
end_token_id = vocab["<channel|>"]
|
||||
|
||||
index_start = output.find("<|channel>")
|
||||
len_start = len("<|channel>")
|
||||
index_end = output.find("<channel|>")
|
||||
len_end = len("<channel|>")
|
||||
|
||||
output_tokens = []
|
||||
|
||||
def _encode(text: str) -> list[int]:
|
||||
if not text:
|
||||
return []
|
||||
# Handle both raw transformers and vLLM wrappers
|
||||
enc = getattr(generic_tokenizer, "tokenizer", generic_tokenizer)
|
||||
try:
|
||||
return enc.encode(text, add_special_tokens=False)
|
||||
except TypeError:
|
||||
return enc.encode(text)
|
||||
|
||||
if index_start != -1:
|
||||
output_before = output[:index_start]
|
||||
output_tokens += _encode(output_before)
|
||||
output_tokens += [start_token_id]
|
||||
|
||||
if index_end != -1:
|
||||
output_middle = output[index_start + len_start : index_end]
|
||||
output_after = output[index_end + len_end :]
|
||||
output_tokens += _encode(output_middle)
|
||||
output_tokens += [end_token_id]
|
||||
output_tokens += _encode(output_after)
|
||||
else:
|
||||
output_middle = output[index_start + len_start :]
|
||||
output_tokens += _encode(output_middle)
|
||||
elif index_end != -1:
|
||||
output_before = output[:index_end]
|
||||
output_after = output[index_end + len_end :]
|
||||
output_tokens += _encode(output_before)
|
||||
output_tokens += [end_token_id]
|
||||
output_tokens += _encode(output_after)
|
||||
else:
|
||||
output_tokens += _encode(output)
|
||||
|
||||
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)(
|
||||
generic_tokenizer
|
||||
)
|
||||
|
||||
# We use the generic run_reasoning_extraction from utils
|
||||
# Use decode per token to get standard spaces instead of
|
||||
# SentencePiece space characters
|
||||
output_token_strings = [generic_tokenizer.decode([t]) for t in output_tokens]
|
||||
reasoning, content = run_reasoning_extraction(
|
||||
parser, output_token_strings, streaming=streaming
|
||||
)
|
||||
|
||||
assert reasoning == param_dict["reasoning"]
|
||||
assert content == param_dict["content"]
|
||||
|
||||
# Test is_reasoning_end
|
||||
is_reasoning_end = parser.is_reasoning_end(output_tokens)
|
||||
assert is_reasoning_end == param_dict["is_reasoning_end"]
|
||||
504
tests/tool_parsers/test_gemma4_tool_parser.py
Normal file
504
tests/tool_parsers/test_gemma4_tool_parser.py
Normal file
@@ -0,0 +1,504 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
|
||||
from vllm.tool_parsers.gemma4_tool_parser import (
|
||||
TOOL_CALL_END,
|
||||
TOOL_CALL_START,
|
||||
Gemma4ToolParser,
|
||||
_parse_gemma4_args,
|
||||
_parse_gemma4_array,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tokenizer():
|
||||
tokenizer = MagicMock()
|
||||
tokenizer.encode.return_value = [1, 2, 3]
|
||||
# Include the tool call start token in the vocab for the parser
|
||||
tokenizer.get_vocab.return_value = {TOOL_CALL_START: 48, TOOL_CALL_END: 49}
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def parser(mock_tokenizer):
|
||||
return Gemma4ToolParser(mock_tokenizer)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request():
|
||||
request = MagicMock(spec=ChatCompletionRequest)
|
||||
request.tools = []
|
||||
request.tool_choice = "auto"
|
||||
return request
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests for _parse_gemma4_args (shared parser logic)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseGemma4Args:
|
||||
def test_empty_string(self):
|
||||
assert _parse_gemma4_args("") == {}
|
||||
|
||||
def test_whitespace_only(self):
|
||||
assert _parse_gemma4_args(" ") == {}
|
||||
|
||||
def test_single_string_value(self):
|
||||
result = _parse_gemma4_args('location:<|"|>Paris<|"|>')
|
||||
assert result == {"location": "Paris"}
|
||||
|
||||
def test_string_value_with_comma(self):
|
||||
result = _parse_gemma4_args('location:<|"|>Paris, France<|"|>')
|
||||
assert result == {"location": "Paris, France"}
|
||||
|
||||
def test_multiple_string_values(self):
|
||||
result = _parse_gemma4_args(
|
||||
'location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>'
|
||||
)
|
||||
assert result == {"location": "San Francisco", "unit": "celsius"}
|
||||
|
||||
def test_integer_value(self):
|
||||
result = _parse_gemma4_args("count:42")
|
||||
assert result == {"count": 42}
|
||||
|
||||
def test_float_value(self):
|
||||
result = _parse_gemma4_args("score:3.14")
|
||||
assert result == {"score": 3.14}
|
||||
|
||||
def test_boolean_true(self):
|
||||
result = _parse_gemma4_args("flag:true")
|
||||
assert result == {"flag": True}
|
||||
|
||||
def test_boolean_false(self):
|
||||
result = _parse_gemma4_args("flag:false")
|
||||
assert result == {"flag": False}
|
||||
|
||||
def test_mixed_types(self):
|
||||
result = _parse_gemma4_args(
|
||||
'name:<|"|>test<|"|>,count:42,active:true,score:3.14'
|
||||
)
|
||||
assert result == {
|
||||
"name": "test",
|
||||
"count": 42,
|
||||
"active": True,
|
||||
"score": 3.14,
|
||||
}
|
||||
|
||||
def test_nested_object(self):
|
||||
result = _parse_gemma4_args('nested:{inner:<|"|>value<|"|>}')
|
||||
assert result == {"nested": {"inner": "value"}}
|
||||
|
||||
def test_array_of_strings(self):
|
||||
result = _parse_gemma4_args('items:[<|"|>a<|"|>,<|"|>b<|"|>]')
|
||||
assert result == {"items": ["a", "b"]}
|
||||
|
||||
def test_unterminated_string(self):
|
||||
"""Unterminated strings should take everything after the delimiter."""
|
||||
result = _parse_gemma4_args('key:<|"|>unterminated')
|
||||
assert result == {"key": "unterminated"}
|
||||
|
||||
def test_empty_value(self):
|
||||
"""Key with no value after colon."""
|
||||
result = _parse_gemma4_args("key:")
|
||||
assert result == {"key": ""}
|
||||
|
||||
|
||||
class TestParseGemma4Array:
|
||||
def test_string_array(self):
|
||||
result = _parse_gemma4_array('<|"|>a<|"|>,<|"|>b<|"|>')
|
||||
assert result == ["a", "b"]
|
||||
|
||||
def test_empty_array(self):
|
||||
result = _parse_gemma4_array("")
|
||||
assert result == []
|
||||
|
||||
def test_bare_values(self):
|
||||
result = _parse_gemma4_array("42,true,3.14")
|
||||
assert result == [42, True, 3.14]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Non-streaming extraction tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExtractToolCalls:
|
||||
def test_no_tool_calls(self, parser, mock_request):
|
||||
model_output = "Hello, how can I help you today?"
|
||||
result = parser.extract_tool_calls(model_output, mock_request)
|
||||
|
||||
assert result.tools_called is False
|
||||
assert result.tool_calls == []
|
||||
assert result.content == model_output
|
||||
|
||||
def test_single_tool_call(self, parser, mock_request):
|
||||
model_output = (
|
||||
'<|tool_call>call:get_weather{location:<|"|>London<|"|>}<tool_call|>'
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, mock_request)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].function.name == "get_weather"
|
||||
args = json.loads(result.tool_calls[0].function.arguments)
|
||||
assert args == {"location": "London"}
|
||||
|
||||
def test_multiple_arguments(self, parser, mock_request):
|
||||
model_output = (
|
||||
"<|tool_call>call:get_weather{"
|
||||
'location:<|"|>San Francisco<|"|>,'
|
||||
'unit:<|"|>celsius<|"|>}'
|
||||
"<tool_call|>"
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, mock_request)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].function.name == "get_weather"
|
||||
args = json.loads(result.tool_calls[0].function.arguments)
|
||||
assert args == {"location": "San Francisco", "unit": "celsius"}
|
||||
|
||||
def test_text_before_tool_call(self, parser, mock_request):
|
||||
model_output = (
|
||||
"Let me check the weather for you. "
|
||||
'<|tool_call>call:get_weather{location:<|"|>Paris<|"|>}'
|
||||
"<tool_call|>"
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, mock_request)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert result.content == "Let me check the weather for you."
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].function.name == "get_weather"
|
||||
|
||||
def test_multiple_tool_calls(self, parser, mock_request):
|
||||
model_output = (
|
||||
'<|tool_call>call:get_weather{location:<|"|>London<|"|>}'
|
||||
"<tool_call|>"
|
||||
'<|tool_call>call:get_time{location:<|"|>London<|"|>}'
|
||||
"<tool_call|>"
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, mock_request)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert len(result.tool_calls) == 2
|
||||
assert result.tool_calls[0].function.name == "get_weather"
|
||||
assert result.tool_calls[1].function.name == "get_time"
|
||||
|
||||
def test_nested_arguments(self, parser, mock_request):
|
||||
model_output = (
|
||||
"<|tool_call>call:complex_function{"
|
||||
'nested:{inner:<|"|>value<|"|>},'
|
||||
'list:[<|"|>a<|"|>,<|"|>b<|"|>]}'
|
||||
"<tool_call|>"
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, mock_request)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].function.name == "complex_function"
|
||||
args = json.loads(result.tool_calls[0].function.arguments)
|
||||
assert args == {"nested": {"inner": "value"}, "list": ["a", "b"]}
|
||||
|
||||
def test_tool_call_with_number_and_boolean(self, parser, mock_request):
|
||||
model_output = (
|
||||
"<|tool_call>call:set_status{"
|
||||
"is_active:true,"
|
||||
"count:42,"
|
||||
"score:3.14}"
|
||||
"<tool_call|>"
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, mock_request)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].function.name == "set_status"
|
||||
args = json.loads(result.tool_calls[0].function.arguments)
|
||||
assert args == {"is_active": True, "count": 42, "score": 3.14}
|
||||
|
||||
def test_incomplete_tool_call(self, parser, mock_request):
|
||||
model_output = '<|tool_call>call:get_weather{location:<|"|>London'
|
||||
result = parser.extract_tool_calls(model_output, mock_request)
|
||||
|
||||
# Incomplete — no <tool_call|> end marker, regex won't match
|
||||
assert result.tools_called is False
|
||||
assert result.content == model_output
|
||||
|
||||
def test_hyphenated_function_name(self, parser, mock_request):
|
||||
"""Ensure function names with hyphens are parsed correctly."""
|
||||
model_output = (
|
||||
'<|tool_call>call:get-weather{location:<|"|>London<|"|>}<tool_call|>'
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, mock_request)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert result.tool_calls[0].function.name == "get-weather"
|
||||
|
||||
def test_dotted_function_name(self, parser, mock_request):
|
||||
"""Ensure function names with dots are parsed correctly."""
|
||||
model_output = (
|
||||
'<|tool_call>call:weather.get{location:<|"|>London<|"|>}<tool_call|>'
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, mock_request)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert result.tool_calls[0].function.name == "weather.get"
|
||||
|
||||
def test_no_arguments(self, parser, mock_request):
|
||||
"""Tool calls with empty arguments."""
|
||||
model_output = "<|tool_call>call:get_status{}<tool_call|>"
|
||||
result = parser.extract_tool_calls(model_output, mock_request)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert result.tool_calls[0].function.name == "get_status"
|
||||
args = json.loads(result.tool_calls[0].function.arguments)
|
||||
assert args == {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Streaming extraction tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStreamingExtraction:
|
||||
"""Tests for the streaming tool call extraction.
|
||||
|
||||
These simulate the token-by-token streaming that vLLM performs,
|
||||
feeding incremental text to extract_tool_calls_streaming() and
|
||||
verifying that the accumulated argument deltas form valid JSON.
|
||||
"""
|
||||
|
||||
def _simulate_streaming(
|
||||
self, parser: Gemma4ToolParser, mock_request: Any, chunks: list[str]
|
||||
) -> list[tuple[Any, str]]:
|
||||
"""Feed chunks through the streaming parser and collect results.
|
||||
|
||||
Returns a list of (delta_message, accumulated_text) tuples.
|
||||
"""
|
||||
results: list[tuple[Any, str]] = []
|
||||
previous_text: str = ""
|
||||
previous_token_ids: list[int] = []
|
||||
|
||||
for chunk in chunks:
|
||||
current_text = previous_text + chunk
|
||||
# Use token ID 48 for tool_call start, 49 for end, 0 otherwise
|
||||
delta_token_ids: list[int] = []
|
||||
if TOOL_CALL_START in chunk:
|
||||
delta_token_ids.append(48)
|
||||
elif TOOL_CALL_END in chunk:
|
||||
delta_token_ids.append(49)
|
||||
else:
|
||||
delta_token_ids.append(0)
|
||||
|
||||
current_token_ids = previous_token_ids + delta_token_ids
|
||||
|
||||
delta = parser.extract_tool_calls_streaming(
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=chunk,
|
||||
previous_token_ids=tuple(previous_token_ids),
|
||||
current_token_ids=tuple(current_token_ids),
|
||||
delta_token_ids=tuple(delta_token_ids),
|
||||
request=mock_request,
|
||||
)
|
||||
results.append((delta, current_text))
|
||||
previous_text = current_text
|
||||
previous_token_ids = list(current_token_ids)
|
||||
|
||||
return results
|
||||
|
||||
def _collect_arguments(self, results):
|
||||
"""Collect all argument deltas from streaming results into one string."""
|
||||
args_text = ""
|
||||
for delta, _ in results:
|
||||
if delta and delta.tool_calls:
|
||||
for tc in delta.tool_calls:
|
||||
func = tc.function if isinstance(tc.function, dict) else tc.function
|
||||
if isinstance(func, dict):
|
||||
arg = func.get("arguments", "")
|
||||
else:
|
||||
arg = getattr(func, "arguments", "") or ""
|
||||
if arg:
|
||||
args_text += arg
|
||||
return args_text
|
||||
|
||||
def _collect_function_name(self, results):
|
||||
"""Extract the function name from streaming results."""
|
||||
for delta, _ in results:
|
||||
if delta and delta.tool_calls:
|
||||
for tc in delta.tool_calls:
|
||||
func = tc.function if isinstance(tc.function, dict) else tc.function
|
||||
if isinstance(func, dict):
|
||||
name = func.get("name")
|
||||
else:
|
||||
name = getattr(func, "name", None)
|
||||
if name:
|
||||
return name
|
||||
return None
|
||||
|
||||
def test_basic_streaming_single_tool(self, parser, mock_request):
|
||||
"""Simulate the exact streaming scenario from the bug report.
|
||||
|
||||
Model generates:
|
||||
<|tool_call>call:get_weather{location:<|"|>Paris, France<|"|>}<tool_call|>
|
||||
|
||||
Expected: arguments should be valid JSON {"location": "Paris, France"}
|
||||
"""
|
||||
chunks = [
|
||||
"<|tool_call>",
|
||||
"call:get_weather{",
|
||||
'location:<|"|>Paris',
|
||||
", France",
|
||||
'<|"|>}',
|
||||
"<tool_call|>",
|
||||
]
|
||||
|
||||
results = self._simulate_streaming(parser, mock_request, chunks)
|
||||
|
||||
# Verify function name
|
||||
name = self._collect_function_name(results)
|
||||
assert name == "get_weather", f"Expected 'get_weather', got '{name}'"
|
||||
|
||||
# Verify arguments form valid JSON
|
||||
args_text = self._collect_arguments(results)
|
||||
assert args_text, "No arguments were streamed"
|
||||
parsed_args = json.loads(args_text)
|
||||
assert parsed_args == {"location": "Paris, France"}
|
||||
|
||||
def test_streaming_multi_arg(self, parser, mock_request):
|
||||
"""Streaming with multiple arguments."""
|
||||
chunks = [
|
||||
"<|tool_call>",
|
||||
"call:get_weather{",
|
||||
'location:<|"|>Tokyo<|"|>,',
|
||||
'unit:<|"|>celsius<|"|>}',
|
||||
"<tool_call|>",
|
||||
]
|
||||
|
||||
results = self._simulate_streaming(parser, mock_request, chunks)
|
||||
|
||||
name = self._collect_function_name(results)
|
||||
assert name == "get_weather"
|
||||
|
||||
args_text = self._collect_arguments(results)
|
||||
assert args_text
|
||||
parsed_args = json.loads(args_text)
|
||||
assert parsed_args == {"location": "Tokyo", "unit": "celsius"}
|
||||
|
||||
def test_streaming_no_extra_brace(self, parser, mock_request):
|
||||
"""Verify the closing } is NOT leaked into arguments (Bug #2)."""
|
||||
chunks = [
|
||||
"<|tool_call>",
|
||||
"call:get_weather{",
|
||||
'location:<|"|>London<|"|>}',
|
||||
"<tool_call|>",
|
||||
]
|
||||
|
||||
results = self._simulate_streaming(parser, mock_request, chunks)
|
||||
args_text = self._collect_arguments(results)
|
||||
assert args_text
|
||||
|
||||
# The args text must be valid JSON (no extra })
|
||||
parsed = json.loads(args_text)
|
||||
assert parsed == {"location": "London"}
|
||||
|
||||
# Specifically assert no double-brace
|
||||
assert args_text.count("}") <= 1, (
|
||||
f"Arguments contain extra closing brace: {args_text!r}"
|
||||
)
|
||||
|
||||
def test_streaming_no_unquoted_keys(self, parser, mock_request):
|
||||
"""Verify keys are properly quoted in JSON (Bug #1)."""
|
||||
chunks = [
|
||||
"<|tool_call>",
|
||||
"call:get_weather{",
|
||||
'location:<|"|>Paris<|"|>}',
|
||||
"<tool_call|>",
|
||||
]
|
||||
|
||||
results = self._simulate_streaming(parser, mock_request, chunks)
|
||||
args_text = self._collect_arguments(results)
|
||||
|
||||
# Must start with { and contain quoted key
|
||||
assert args_text.lstrip().startswith("{"), (
|
||||
f"Arguments don't start with '{{': {args_text!r}"
|
||||
)
|
||||
assert '"location"' in args_text, (
|
||||
f"Key 'location' not properly quoted: {args_text!r}"
|
||||
)
|
||||
|
||||
def test_streaming_name_no_call_prefix(self, parser, mock_request):
|
||||
"""Verify function name has no 'call:' prefix."""
|
||||
chunks = [
|
||||
"<|tool_call>",
|
||||
"call:get_weather{",
|
||||
'location:<|"|>Paris<|"|>}',
|
||||
"<tool_call|>",
|
||||
]
|
||||
|
||||
results = self._simulate_streaming(parser, mock_request, chunks)
|
||||
name = self._collect_function_name(results)
|
||||
assert name == "get_weather"
|
||||
assert not name.startswith("call:"), f"Name has 'call:' prefix: {name!r}"
|
||||
|
||||
def test_streaming_text_before_tool_call(self, parser, mock_request):
|
||||
"""Text before tool call should be emitted as content."""
|
||||
chunks = [
|
||||
"Let me check ",
|
||||
"the weather. ",
|
||||
"<|tool_call>",
|
||||
"call:get_weather{",
|
||||
'location:<|"|>London<|"|>}',
|
||||
"<tool_call|>",
|
||||
]
|
||||
|
||||
results = self._simulate_streaming(parser, mock_request, chunks)
|
||||
|
||||
# First chunks should be content
|
||||
content_parts = []
|
||||
for delta, _ in results:
|
||||
if delta and delta.content:
|
||||
content_parts.append(delta.content)
|
||||
|
||||
assert "".join(content_parts).strip().startswith("Let me check")
|
||||
|
||||
def test_streaming_numeric_args(self, parser, mock_request):
|
||||
"""Streaming with numeric and boolean argument values."""
|
||||
chunks = [
|
||||
"<|tool_call>",
|
||||
"call:set_config{",
|
||||
"count:42,",
|
||||
"active:true}",
|
||||
"<tool_call|>",
|
||||
]
|
||||
|
||||
results = self._simulate_streaming(parser, mock_request, chunks)
|
||||
args_text = self._collect_arguments(results)
|
||||
if args_text:
|
||||
parsed_args = json.loads(args_text)
|
||||
assert parsed_args["count"] == 42
|
||||
assert parsed_args["active"] is True
|
||||
|
||||
def test_streaming_empty_args(self, parser, mock_request):
|
||||
"""Tool call with no arguments."""
|
||||
chunks = [
|
||||
"<|tool_call>",
|
||||
"call:get_status{}",
|
||||
"<tool_call|>",
|
||||
]
|
||||
|
||||
results = self._simulate_streaming(parser, mock_request, chunks)
|
||||
name = self._collect_function_name(results)
|
||||
assert name == "get_status"
|
||||
@@ -42,6 +42,7 @@ from vllm.v1.attention.backends.mla.flashmla_sparse import (
|
||||
FlashMLASparseBackend,
|
||||
triton_convert_req_index_to_global_index,
|
||||
)
|
||||
from vllm.v1.attention.backends.mla.indexer import split_indexer_prefill_chunks
|
||||
from vllm.v1.attention.backends.utils import split_prefill_chunks
|
||||
from vllm.v1.attention.ops import flashmla
|
||||
|
||||
@@ -716,6 +717,81 @@ def test_split_prefill_chunks(seq_lens, max_buf, expected):
|
||||
assert out == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"seq_lens,query_lens,workspace_size,max_logits_bytes,expected",
|
||||
[
|
||||
# Logits constraint triggers split (M*N exceeds budget)
|
||||
# req0: M=10, N=100 -> 1000 elems (4000 bytes) - fits in 5000
|
||||
# req1: adding M=10, N=100 -> new_M=20, new_N=200 -> 4000 elems > 1250
|
||||
(
|
||||
torch.tensor([100, 100, 100]),
|
||||
torch.tensor([10, 10, 10]),
|
||||
1000, # workspace allows all
|
||||
5000, # 1250 float32 elems -> forces split
|
||||
[
|
||||
(slice(0, 1), slice(0, 10)),
|
||||
(slice(1, 2), slice(0, 10)),
|
||||
(slice(2, 3), slice(0, 10)),
|
||||
],
|
||||
),
|
||||
# Both constraints satisfied - all fit in one chunk
|
||||
(
|
||||
torch.tensor([10, 10, 10]),
|
||||
torch.tensor([5, 5, 5]),
|
||||
100,
|
||||
10000, # 2500 elems, M*N = 15*30 = 450 < 2500
|
||||
[(slice(0, 3), slice(0, 15))],
|
||||
),
|
||||
# Workspace constraint triggers first
|
||||
(
|
||||
torch.tensor([50, 50, 50]),
|
||||
torch.tensor([1, 1, 1]),
|
||||
50, # workspace only fits one at a time
|
||||
1000000, # logits budget is huge
|
||||
[
|
||||
(slice(0, 1), slice(0, 1)),
|
||||
(slice(1, 2), slice(0, 1)),
|
||||
(slice(2, 3), slice(0, 1)),
|
||||
],
|
||||
),
|
||||
# Greedy filling: first two fit, third doesn't
|
||||
# req0: M=5, N=10 -> 50 elems
|
||||
# req0+1: M=10, N=20 -> 200 elems <= 250
|
||||
# req0+1+2: M=15, N=30 -> 450 elems > 250
|
||||
(
|
||||
torch.tensor([10, 10, 10]),
|
||||
torch.tensor([5, 5, 5]),
|
||||
100,
|
||||
1000, # 250 elems
|
||||
[(slice(0, 2), slice(0, 10)), (slice(2, 3), slice(0, 5))],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_split_indexer_prefill_chunks(
|
||||
seq_lens, query_lens, workspace_size, max_logits_bytes, expected
|
||||
):
|
||||
out = split_indexer_prefill_chunks(
|
||||
seq_lens,
|
||||
query_lens,
|
||||
workspace_size,
|
||||
max_logits_bytes,
|
||||
)
|
||||
assert out == expected
|
||||
|
||||
|
||||
def test_split_indexer_prefill_chunks_single_request_overflow():
|
||||
"""Test that single request exceeding budget is sub-chunked on query dim."""
|
||||
seq_lens = torch.tensor([1000, 50])
|
||||
query_lens = torch.tensor([100, 5])
|
||||
|
||||
out = split_indexer_prefill_chunks(seq_lens, query_lens, 2000, 1000)
|
||||
# max_logits_elems = 250, N=1000 -> max_q = 1 -> 100 query sub-chunks
|
||||
expected = [(slice(0, 1), slice(i, i + 1)) for i in range(100)]
|
||||
# req1: M=5, N=50 -> 250 elems fits budget
|
||||
expected.append((slice(1, 2), slice(0, 5)))
|
||||
assert out == expected
|
||||
|
||||
|
||||
def test_triton_convert_returns_valid_counts():
|
||||
"""Test that return_valid_counts correctly counts non-negative indices."""
|
||||
device = torch.device("cuda")
|
||||
|
||||
@@ -20,7 +20,7 @@ def server():
|
||||
"--reasoning-parser",
|
||||
"qwen3",
|
||||
"--reasoning-config",
|
||||
'{"think_start_str": "<think>", "think_end_str": "</think>"}',
|
||||
'{"reasoning_start_str": "<think>", "reasoning_end_str": "</think>"}',
|
||||
"--max-model-len",
|
||||
"2048",
|
||||
"--enforce-eager",
|
||||
|
||||
@@ -103,8 +103,8 @@ class LogitsProcsRequestParams:
|
||||
class MockReasoningConfig:
|
||||
"""Mock reasoning config for testing ThinkingTokenBudgetLogitsProcessor."""
|
||||
|
||||
think_start_token_ids = [THINK_START_TOKEN_ID]
|
||||
think_end_token_ids = [THINK_END_TOKEN_ID]
|
||||
reasoning_start_token_ids = [THINK_START_TOKEN_ID]
|
||||
reasoning_end_token_ids = [THINK_END_TOKEN_ID]
|
||||
|
||||
|
||||
def _generate_fake_sampling_metadata(
|
||||
@@ -491,7 +491,7 @@ def _thinking_budget_validate(
|
||||
|
||||
# Find if thinking has started in output tokens
|
||||
thinking_started = False
|
||||
start_tokens = tb_processor.think_start_token_ids
|
||||
start_tokens = tb_processor.reasoning_start_token_ids
|
||||
|
||||
if len(start_tokens) > 0:
|
||||
for i in range(len(output_tokens) - len(start_tokens) + 1):
|
||||
@@ -518,7 +518,7 @@ def _thinking_budget_validate(
|
||||
)
|
||||
|
||||
# Validate that only end tokens are allowed
|
||||
end_tokens = tb_processor.think_end_token_ids
|
||||
end_tokens = tb_processor.reasoning_end_token_ids
|
||||
if len(end_tokens) > 0:
|
||||
expected_end_token_id = end_tokens[
|
||||
min(state["end_count"], len(end_tokens) - 1)
|
||||
|
||||
0
tests/v1/simple_kv_offload/__init__.py
Normal file
0
tests/v1/simple_kv_offload/__init__.py
Normal file
193
tests/v1/simple_kv_offload/test_integration.py
Normal file
193
tests/v1/simple_kv_offload/test_integration.py
Normal file
@@ -0,0 +1,193 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Integration tests for SimpleCPUOffloadConnector with real models."""
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import LLM, SamplingParams, TokensPrompt
|
||||
from vllm.config import KVTransferConfig
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip("Requires CUDA", allow_module_level=True)
|
||||
|
||||
# Small models for default CI / local runs (accuracy only).
|
||||
SMALL_MODELS = [
|
||||
"meta-llama/Llama-3.2-1B-Instruct",
|
||||
"google/gemma-3-1b-it",
|
||||
]
|
||||
|
||||
# Large models for optional perf runs only (slow to load and execute).
|
||||
PERF_MODELS = [
|
||||
"meta-llama/Llama-3.1-8B",
|
||||
"openai/gpt-oss-20b",
|
||||
]
|
||||
|
||||
|
||||
def _make_llm(model: str, lazy: bool, cpu_bytes_to_use: int) -> LLM:
|
||||
kv_transfer_config = KVTransferConfig(
|
||||
kv_connector="SimpleCPUOffloadConnector",
|
||||
kv_role="kv_both",
|
||||
kv_connector_extra_config={
|
||||
"cpu_bytes_to_use": cpu_bytes_to_use,
|
||||
"lazy_offload": lazy,
|
||||
},
|
||||
)
|
||||
return LLM(
|
||||
model=model,
|
||||
kv_cache_memory_bytes=40 << 30, # 40 GiB
|
||||
disable_hybrid_kv_cache_manager=False,
|
||||
enable_prefix_caching=True,
|
||||
kv_transfer_config=kv_transfer_config,
|
||||
)
|
||||
|
||||
|
||||
def _flush_gpu_cache(llm: LLM, sampling_params: SamplingParams, seed: int = 0):
|
||||
"""Generate enough filler requests to allocate the entire GPU KV cache.
|
||||
|
||||
This pushes all prior blocks through the free queue so that the lazy
|
||||
cursor offloads them to CPU before they are evicted.
|
||||
"""
|
||||
cache_config = llm.llm_engine.vllm_config.cache_config
|
||||
num_gpu_blocks = cache_config.num_gpu_blocks
|
||||
block_size = cache_config.block_size
|
||||
# Use 1.2x GPU capacity to give the lazy cursor enough scheduling steps
|
||||
# to walk past all target blocks near the tail of the free queue.
|
||||
total_tokens_needed = int(num_gpu_blocks * block_size * 1.5)
|
||||
|
||||
# Use token-id prompts so each filler is unique (no prefix sharing).
|
||||
# Split into multiple requests to stay under max_model_len.
|
||||
max_tokens_per_req = 4096
|
||||
num_fillers = (total_tokens_needed + max_tokens_per_req - 1) // max_tokens_per_req
|
||||
batch_size = 10
|
||||
for i in range(0, num_fillers, batch_size):
|
||||
batch_end = min(i + batch_size, num_fillers)
|
||||
filler_prompts = []
|
||||
for j in range(i, batch_end):
|
||||
ids = [seed * num_fillers + j + 1] * max_tokens_per_req
|
||||
filler_prompts.append(TokensPrompt(prompt_token_ids=ids))
|
||||
llm.generate(filler_prompts, sampling_params, use_tqdm=False)
|
||||
|
||||
|
||||
def _accuracy_test(llm: LLM, lazy: bool = False):
|
||||
"""Verify that CPU-loaded KV produces correct output."""
|
||||
sampling_params = SamplingParams(max_tokens=1, temperature=0)
|
||||
prompt = "hi " * 2000 + "Let's count to ten. One, two, three, "
|
||||
|
||||
# Cold run — populate GPU cache and trigger CPU offload
|
||||
cold_output = llm.generate(prompt, sampling_params, use_tqdm=False)[0]
|
||||
|
||||
# CPU hit runs
|
||||
test_count = 10
|
||||
success_count = 0
|
||||
expected = cold_output.outputs[0].text
|
||||
for i in range(test_count):
|
||||
if lazy:
|
||||
_flush_gpu_cache(llm, sampling_params, seed=i)
|
||||
time.sleep(2) # let engine core drain pending transfers
|
||||
|
||||
# Reset GPU prefix cache so next run must load from CPU
|
||||
if not llm.reset_prefix_cache():
|
||||
print(f"GPU prefix cache reset failed for iteration {i}")
|
||||
|
||||
output = llm.generate(prompt, sampling_params, use_tqdm=False)[0]
|
||||
if output.outputs[0].text == expected:
|
||||
success_count += 1
|
||||
|
||||
assert success_count >= 0.5 * test_count, (
|
||||
f"Accuracy too low: {success_count}/{test_count} matched '{expected}'"
|
||||
)
|
||||
|
||||
|
||||
def _latency_test(llm: LLM, lazy: bool = False):
|
||||
"""Verify CPU cache hit is faster than cold compute."""
|
||||
sampling_params = SamplingParams(max_tokens=1, seed=42)
|
||||
prompt_token_ids = [0] * 10001
|
||||
|
||||
num_times_cpu_better = 0
|
||||
num_tests = 10
|
||||
for i in range(num_tests):
|
||||
prompt_token_ids[0] = i
|
||||
prompts = [TokensPrompt(prompt_token_ids=prompt_token_ids)]
|
||||
|
||||
# Cold
|
||||
time.sleep(2) # let engine core drain pending transfers
|
||||
if not llm.reset_prefix_cache():
|
||||
print(f"GPU prefix cache reset failed for iteration {i}")
|
||||
start = time.time()
|
||||
llm.generate(prompts, sampling_params, use_tqdm=False)
|
||||
cold_time = time.time() - start
|
||||
|
||||
if lazy:
|
||||
_flush_gpu_cache(llm, sampling_params, seed=i)
|
||||
else:
|
||||
# Eager mode: GPU hit ensures store completion is processed.
|
||||
llm.generate(prompts, sampling_params, use_tqdm=False)
|
||||
|
||||
time.sleep(2) # let engine core drain pending transfers
|
||||
if not llm.reset_prefix_cache():
|
||||
print(f"GPU prefix cache reset failed for iteration {i}")
|
||||
|
||||
# CPU hit
|
||||
start = time.time()
|
||||
llm.generate(prompts, sampling_params, use_tqdm=False)
|
||||
cpu_time = time.time() - start
|
||||
|
||||
if cpu_time < cold_time:
|
||||
num_times_cpu_better += 1
|
||||
|
||||
assert num_times_cpu_better >= 0.8 * num_tests, (
|
||||
f"CPU hit only faster {num_times_cpu_better}/{num_tests} times"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.optional
|
||||
@pytest.mark.slow_test
|
||||
@pytest.mark.parametrize("model", SMALL_MODELS)
|
||||
def test_simple_cpu_offload_accuracy(model: str):
|
||||
"""Store to CPU, reset GPU, load from CPU; verify output matches baseline."""
|
||||
llm = _make_llm(model, False, 1 << 30) # 1GB
|
||||
try:
|
||||
_accuracy_test(llm, lazy=False)
|
||||
finally:
|
||||
del llm
|
||||
|
||||
|
||||
@pytest.mark.optional
|
||||
@pytest.mark.slow_test
|
||||
@pytest.mark.parametrize("model", PERF_MODELS)
|
||||
def test_simple_cpu_offload_perf_latency(model: str):
|
||||
"""CPU KV hit should beat cold prefill on long context (large models only)."""
|
||||
llm = _make_llm(model, False, 10 << 30) # 10GB
|
||||
try:
|
||||
_latency_test(llm, lazy=False)
|
||||
finally:
|
||||
del llm
|
||||
|
||||
|
||||
@pytest.mark.optional
|
||||
@pytest.mark.slow_test
|
||||
@pytest.mark.parametrize("model", SMALL_MODELS)
|
||||
def test_simple_cpu_offload_accuracy_lazy(model: str):
|
||||
"""Lazy mode: flush GPU cache to trigger CPU offload, then verify hit."""
|
||||
# CPU must be larger than GPU KV cache to avoid evicting offloaded blocks.
|
||||
llm = _make_llm(model, True, 80 << 30) # 80GB
|
||||
try:
|
||||
_accuracy_test(llm, lazy=True)
|
||||
finally:
|
||||
del llm
|
||||
|
||||
|
||||
@pytest.mark.optional
|
||||
@pytest.mark.slow_test
|
||||
@pytest.mark.parametrize("model", PERF_MODELS)
|
||||
def test_simple_cpu_offload_perf_latency_lazy(model: str):
|
||||
"""Lazy mode: CPU KV hit should beat cold prefill (large models only)."""
|
||||
# CPU must be larger than GPU KV cache to avoid evicting offloaded blocks.
|
||||
llm = _make_llm(model, True, 80 << 30) # 80GB
|
||||
try:
|
||||
_latency_test(llm, lazy=True)
|
||||
finally:
|
||||
del llm
|
||||
1137
tests/v1/simple_kv_offload/test_scheduler.py
Normal file
1137
tests/v1/simple_kv_offload/test_scheduler.py
Normal file
File diff suppressed because it is too large
Load Diff
147
tests/v1/spec_decode/test_backup_token_async_spec.py
Normal file
147
tests/v1/spec_decode/test_backup_token_async_spec.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Regression tests for the backup token fix in prepare_next_token_ids_padded.
|
||||
|
||||
Fixes #38098: with async scheduling, seq_lens_cpu is inflated by unaccepted
|
||||
draft token placeholders, causing get_token_id() to return -1.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
class _FakeRequest:
|
||||
def __init__(self, prompt_tokens: list[int], output_tokens: list[int]):
|
||||
self.num_prompt_tokens = len(prompt_tokens)
|
||||
self._prompt = prompt_tokens
|
||||
self._output = output_tokens
|
||||
|
||||
@property
|
||||
def num_tokens(self) -> int:
|
||||
return self.num_prompt_tokens + len(self._output)
|
||||
|
||||
def get_token_id(self, idx: int) -> int:
|
||||
if idx < self.num_prompt_tokens:
|
||||
return self._prompt[idx]
|
||||
out_idx = idx - self.num_prompt_tokens
|
||||
if out_idx < len(self._output):
|
||||
return self._output[out_idx]
|
||||
return -1 # out of range
|
||||
|
||||
|
||||
class _FakeInputBatch:
|
||||
def __init__(
|
||||
self,
|
||||
req_ids: list[str],
|
||||
num_tokens_no_spec: list[int],
|
||||
vocab_size: int = 32000,
|
||||
):
|
||||
self.req_ids = req_ids
|
||||
self.num_reqs = len(req_ids)
|
||||
self.vocab_size = vocab_size
|
||||
self.num_tokens_no_spec = np.array(num_tokens_no_spec, dtype=np.int64)
|
||||
|
||||
|
||||
def _make_requests(
|
||||
req_ids: list[str],
|
||||
prompt_lens: list[int],
|
||||
output_lens: list[int],
|
||||
) -> dict[str, _FakeRequest]:
|
||||
requests = {}
|
||||
for rid, plen, olen in zip(req_ids, prompt_lens, output_lens):
|
||||
requests[rid] = _FakeRequest(list(range(plen)), list(range(1000, 1000 + olen)))
|
||||
return requests
|
||||
|
||||
|
||||
def _backup_buggy(
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
requests: dict[str, _FakeRequest],
|
||||
batch: _FakeInputBatch,
|
||||
) -> list[int]:
|
||||
"""Old logic: uses seq_lens_cpu directly (may be inflated)."""
|
||||
n = batch.num_reqs
|
||||
return [
|
||||
requests[batch.req_ids[i]].get_token_id(int(seq_lens_cpu[i])) for i in range(n)
|
||||
]
|
||||
|
||||
|
||||
def _backup_fixed(
|
||||
requests: dict[str, _FakeRequest],
|
||||
batch: _FakeInputBatch,
|
||||
) -> list[int]:
|
||||
"""New logic: uses num_tokens_no_spec - 1 (last committed token)."""
|
||||
n = batch.num_reqs
|
||||
idx = (batch.num_tokens_no_spec[:n] - 1).tolist()
|
||||
return [requests[batch.req_ids[i]].get_token_id(int(idx[i])) for i in range(n)]
|
||||
|
||||
|
||||
class TestBackupTokenAsyncSpec:
|
||||
def test_no_inflation_fixed_returns_last_token(self):
|
||||
req_ids = ["r0", "r1"]
|
||||
requests = _make_requests(req_ids, [3, 3], [2, 2])
|
||||
batch = _FakeInputBatch(req_ids, [5, 5])
|
||||
# idx = 5-1 = 4 → output[1] = 1001
|
||||
assert _backup_fixed(requests, batch) == [1001, 1001]
|
||||
|
||||
def test_inflation_buggy_returns_placeholder(self):
|
||||
req_ids = ["r0", "r1"]
|
||||
requests = _make_requests(req_ids, [3, 3], [2, 2])
|
||||
batch = _FakeInputBatch(req_ids, [5, 5])
|
||||
# inflated by 3 spec tokens → idx 8 is out of range
|
||||
seq_lens = torch.tensor([8, 8], dtype=torch.int64)
|
||||
assert _backup_buggy(seq_lens, requests, batch) == [-1, -1]
|
||||
|
||||
def test_inflation_fixed_returns_correct_token(self):
|
||||
req_ids = ["r0", "r1"]
|
||||
requests = _make_requests(req_ids, [3, 3], [2, 2])
|
||||
batch = _FakeInputBatch(req_ids, [5, 5])
|
||||
assert _backup_fixed(requests, batch) == [1001, 1001]
|
||||
|
||||
def test_mixed_inflation_per_request(self):
|
||||
req_ids = ["r0", "r1", "r2"]
|
||||
requests = {
|
||||
"r0": _FakeRequest([0, 1], [1000, 1001, 1002]),
|
||||
"r1": _FakeRequest([0, 1, 2, 3], [2000]),
|
||||
"r2": _FakeRequest([0], [3000, 3001, 3002, 3003]),
|
||||
}
|
||||
batch = _FakeInputBatch(req_ids, [5, 5, 5])
|
||||
seq_lens = torch.tensor([7, 9, 5], dtype=torch.int64)
|
||||
|
||||
assert _backup_buggy(seq_lens, requests, batch) == [-1, -1, -1]
|
||||
assert _backup_fixed(requests, batch) == [1002, 2000, 3003]
|
||||
|
||||
def test_prefill_only_request(self):
|
||||
"""No output tokens yet — backup should be the last prompt token."""
|
||||
req_ids = ["r0"]
|
||||
requests = {"r0": _FakeRequest([10, 20, 30], [])}
|
||||
batch = _FakeInputBatch(req_ids, [3])
|
||||
# idx = 3-1 = 2 → prompt[2] = 30
|
||||
assert _backup_fixed(requests, batch) == [30]
|
||||
|
||||
@pytest.mark.parametrize("num_spec_tokens", [1, 2, 3, 4, 5])
|
||||
def test_various_spec_token_counts(self, num_spec_tokens: int):
|
||||
req_ids = ["r0"]
|
||||
requests = {"r0": _FakeRequest([0, 1, 2], list(range(1000, 1005)))}
|
||||
batch = _FakeInputBatch(req_ids, [8])
|
||||
# idx = 8-1 = 7 → output[4] = 1004
|
||||
assert _backup_fixed(requests, batch) == [1004]
|
||||
|
||||
def test_buggy_code_was_always_off_by_one(self):
|
||||
"""The original code used seq_len as index, which is always one past
|
||||
the end of output_token_ids even without async inflation."""
|
||||
req_ids = ["r0"]
|
||||
requests = {"r0": _FakeRequest([0, 1, 2], [1000, 1001])}
|
||||
batch = _FakeInputBatch(req_ids, [5])
|
||||
|
||||
# no inflation: seq_len == num_tokens == 5 → idx 5 is out of range
|
||||
seq_lens = torch.tensor([5], dtype=torch.int64)
|
||||
assert _backup_buggy(seq_lens, requests, batch) == [-1]
|
||||
assert _backup_fixed(requests, batch) == [1001]
|
||||
|
||||
# with inflation: still -1, fixed still correct
|
||||
seq_lens_inf = torch.tensor([8], dtype=torch.int64)
|
||||
assert _backup_buggy(seq_lens_inf, requests, batch) == [-1]
|
||||
assert _backup_fixed(requests, batch) == [1001]
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@@ -111,16 +112,14 @@ def test_prepare_next_token_ids():
|
||||
|
||||
num_requests = 4
|
||||
num_speculative_tokens = 4
|
||||
batch_spec = BatchSpec(
|
||||
seq_lens=[num_speculative_tokens + 1] * num_requests,
|
||||
query_lens=[num_speculative_tokens + 1] * num_requests,
|
||||
)
|
||||
|
||||
req_ids = [f"req_{i + 1}" for i in range(num_requests)]
|
||||
mock_input_batch = mock.MagicMock(spec=InputBatch)
|
||||
mock_input_batch.req_ids = req_ids
|
||||
mock_input_batch.num_reqs = num_requests
|
||||
mock_input_batch.vocab_size = 100
|
||||
mock_input_batch.num_tokens_no_spec = np.array(
|
||||
[num_speculative_tokens + 1] * num_requests
|
||||
)
|
||||
|
||||
mock_num_scheduled_tokens = {req_id: 0 for req_id in req_ids}
|
||||
mock_requests = {}
|
||||
@@ -165,19 +164,12 @@ def test_prepare_next_token_ids():
|
||||
|
||||
assert torch.equal(next_token_ids_from_cpu, expected_next_token_ids_tensor)
|
||||
|
||||
common_attn_metadata = create_common_attn_metadata(
|
||||
batch_spec,
|
||||
block_size=BLOCK_SIZE,
|
||||
device=device,
|
||||
)
|
||||
|
||||
expected_valid_sampled_tokens_count = torch.tensor(
|
||||
[2, 5, 0, 0], dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
next_token_ids_from_padded, valid_sampled_tokens_count = (
|
||||
proposer.prepare_next_token_ids_padded(
|
||||
common_attn_metadata.seq_lens_cpu,
|
||||
sampled_token_ids_tensor,
|
||||
mock_requests,
|
||||
mock_input_batch,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@@ -132,16 +133,12 @@ def test_prepare_next_token_ids_padded():
|
||||
device = torch.device(current_platform.device_type)
|
||||
|
||||
num_requests = 4
|
||||
batch_spec = BatchSpec(
|
||||
seq_lens=[5] * num_requests,
|
||||
query_lens=[5] * num_requests,
|
||||
)
|
||||
|
||||
req_ids = [f"req_{i + 1}" for i in range(num_requests)]
|
||||
mock_input_batch = mock.MagicMock(spec=InputBatch)
|
||||
mock_input_batch.req_ids = req_ids
|
||||
mock_input_batch.num_reqs = num_requests
|
||||
mock_input_batch.vocab_size = 100
|
||||
mock_input_batch.num_tokens_no_spec = np.array([5] * num_requests)
|
||||
|
||||
mock_requests = {}
|
||||
for req_id in req_ids:
|
||||
@@ -174,12 +171,6 @@ def test_prepare_next_token_ids_padded():
|
||||
|
||||
proposer = _create_proposer(num_speculative_tokens=1)
|
||||
|
||||
common_attn_metadata = create_common_attn_metadata(
|
||||
batch_spec,
|
||||
block_size=16,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# valid_sampled_tokens_count tracks if token is valid (not -1 and in vocab range)
|
||||
# It doesn't depend on whether the request is discarded
|
||||
expected_valid_sampled_tokens_count = torch.tensor(
|
||||
@@ -187,7 +178,6 @@ def test_prepare_next_token_ids_padded():
|
||||
)
|
||||
|
||||
next_token_ids, valid_sampled_tokens_count = proposer.prepare_next_token_ids_padded(
|
||||
common_attn_metadata.seq_lens_cpu,
|
||||
sampled_token_ids,
|
||||
mock_requests,
|
||||
mock_input_batch,
|
||||
|
||||
@@ -12,7 +12,7 @@ from vllm.tokenizers import cached_tokenizer_from_config
|
||||
class ReasoningConfig:
|
||||
"""Configuration for reasoning models.
|
||||
|
||||
Set `think_start_str` and `think_end_str` to the strings that delimit
|
||||
Set `reasoning_start_str` and `reasoning_end_str` to the strings that delimit
|
||||
the reasoning block (e.g. `"<think>"` and `"</think>"`). The
|
||||
corresponding token IDs are derived automatically via
|
||||
`initialize_token_ids` and are not intended to be set directly.
|
||||
@@ -20,53 +20,55 @@ class ReasoningConfig:
|
||||
|
||||
# NOTE: These parameters are temporary, the intent is to derive them
|
||||
# automatically from the reasoning parser in a future version.
|
||||
think_start_str: str = "<think>"
|
||||
reasoning_start_str: str = "<think>"
|
||||
"""String that indicates the start of reasoning."""
|
||||
think_end_str: str = "</think>"
|
||||
reasoning_end_str: str = "</think>"
|
||||
"""String that indicates the end of reasoning content."""
|
||||
|
||||
_think_start_token_ids: list[int] | None = field(
|
||||
_reasoning_start_token_ids: list[int] | None = field(
|
||||
default=None, init=False, repr=False
|
||||
)
|
||||
"""Private backing field for `think_start_token_ids`. Set by
|
||||
"""Private backing field for `reasoning_start_token_ids`. Set by
|
||||
`initialize_token_ids`. Not intended to be configured directly."""
|
||||
_think_end_token_ids: list[int] | None = field(default=None, init=False, repr=False)
|
||||
"""Private backing field for `think_end_token_ids`. Set by
|
||||
_reasoning_end_token_ids: list[int] | None = field(
|
||||
default=None, init=False, repr=False
|
||||
)
|
||||
"""Private backing field for `reasoning_end_token_ids`. Set by
|
||||
`initialize_token_ids`. Not intended to be configured directly."""
|
||||
|
||||
@property
|
||||
def think_start_token_ids(self) -> list[int] | None:
|
||||
"""Token IDs derived from `think_start_str`. Set automatically by
|
||||
def reasoning_start_token_ids(self) -> list[int] | None:
|
||||
"""Token IDs derived from `reasoning_start_str`. Set automatically by
|
||||
`initialize_token_ids`. Not intended to be configured directly."""
|
||||
return self._think_start_token_ids
|
||||
return self._reasoning_start_token_ids
|
||||
|
||||
@property
|
||||
def think_end_token_ids(self) -> list[int] | None:
|
||||
"""Token IDs derived from `think_end_str`. Set automatically by
|
||||
def reasoning_end_token_ids(self) -> list[int] | None:
|
||||
"""Token IDs derived from `reasoning_end_str`. Set automatically by
|
||||
`initialize_token_ids`. Not intended to be configured directly."""
|
||||
return self._think_end_token_ids
|
||||
return self._reasoning_end_token_ids
|
||||
|
||||
def initialize_token_ids(self, model_config: ModelConfig) -> None:
|
||||
"""Initialize reasoning token IDs from strings using the tokenizer."""
|
||||
if (
|
||||
self._think_start_token_ids is not None
|
||||
and self._think_end_token_ids is not None
|
||||
self._reasoning_start_token_ids is not None
|
||||
and self._reasoning_end_token_ids is not None
|
||||
):
|
||||
return
|
||||
|
||||
tokenizer = cached_tokenizer_from_config(model_config=model_config)
|
||||
|
||||
self._think_start_token_ids = tokenizer.encode(
|
||||
self.think_start_str, add_special_tokens=False
|
||||
self._reasoning_start_token_ids = tokenizer.encode(
|
||||
self.reasoning_start_str, add_special_tokens=False
|
||||
)
|
||||
self._think_end_token_ids = tokenizer.encode(
|
||||
self.think_end_str, add_special_tokens=False
|
||||
self._reasoning_end_token_ids = tokenizer.encode(
|
||||
self.reasoning_end_str, add_special_tokens=False
|
||||
)
|
||||
|
||||
if not self._think_start_token_ids or not self._think_end_token_ids:
|
||||
if not self._reasoning_start_token_ids or not self._reasoning_end_token_ids:
|
||||
raise ValueError(
|
||||
f"ReasoningConfig: failed to tokenize reasoning strings: "
|
||||
f"think_start_str='{self.think_start_str}', "
|
||||
f"think_end_str='{self.think_end_str}'. "
|
||||
f"reasoning_start_str='{self.reasoning_start_str}', "
|
||||
f"reasoning_end_str='{self.reasoning_end_str}'. "
|
||||
"Ensure the strings are valid tokens in the model's vocabulary."
|
||||
)
|
||||
|
||||
@@ -657,7 +657,11 @@ class VllmConfig:
|
||||
)
|
||||
|
||||
if kv_offloading_backend == "native":
|
||||
self.kv_transfer_config.kv_connector = "OffloadingConnector"
|
||||
if envs.VLLM_USE_SIMPLE_KV_OFFLOAD:
|
||||
config_connector = "SimpleCPUOffloadConnector"
|
||||
else:
|
||||
config_connector = "OffloadingConnector"
|
||||
self.kv_transfer_config.kv_connector = config_connector
|
||||
self.kv_transfer_config.kv_connector_extra_config.update(
|
||||
{"cpu_bytes_to_use": kv_offloading_size * (1 << 30)}
|
||||
)
|
||||
|
||||
@@ -202,6 +202,7 @@ KVConnectorFactory.register_connector(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector",
|
||||
"DecodeBenchConnector",
|
||||
)
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"MooncakeConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector",
|
||||
@@ -213,3 +214,9 @@ KVConnectorFactory.register_connector(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.flexkv_connector",
|
||||
"FlexKVConnectorV1",
|
||||
)
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"SimpleCPUOffloadConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.simple_cpu_offload_connector",
|
||||
"SimpleCPUOffloadConnector",
|
||||
)
|
||||
|
||||
@@ -0,0 +1,247 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""SimpleCPUOffloadConnector: minimal CPU KV cache offloading."""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import KVCacheEvent
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
SupportsHMA,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
from vllm.v1.simple_kv_offload.manager import (
|
||||
SimpleCPUOffloadScheduler,
|
||||
)
|
||||
from vllm.v1.simple_kv_offload.metadata import (
|
||||
SimpleCPUOffloadMetadata,
|
||||
)
|
||||
from vllm.v1.simple_kv_offload.worker import (
|
||||
SimpleCPUOffloadWorker,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.attention.backend import AttentionMetadata
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Default CPU capacity: 8 GB
|
||||
DEFAULT_CPU_CAPACITY_BYTES = 8 * (1024**3)
|
||||
|
||||
|
||||
class SimpleCPUOffloadConnector(KVConnectorBase_V1, SupportsHMA):
|
||||
"""CPU KV cache offloading with custom kernel transfers and BlockPool LRU."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: "KVCacheConfig | None" = None,
|
||||
):
|
||||
super().__init__(vllm_config, role, kv_cache_config)
|
||||
|
||||
enable_prefix_caching = vllm_config.cache_config.enable_prefix_caching
|
||||
extra_config = self._kv_transfer_config.kv_connector_extra_config or {}
|
||||
|
||||
cpu_capacity_bytes = int(
|
||||
extra_config.get("cpu_bytes_to_use", DEFAULT_CPU_CAPACITY_BYTES)
|
||||
)
|
||||
# cpu_bytes_to_use is server-wide for compatibility;
|
||||
# cpu_bytes_to_use_per_rank overrides for per-rank capacity.
|
||||
world_size = vllm_config.parallel_config.world_size
|
||||
cpu_capacity_per_rank = cpu_capacity_bytes // world_size
|
||||
if "cpu_bytes_to_use_per_rank" in extra_config:
|
||||
explicit = int(extra_config["cpu_bytes_to_use_per_rank"])
|
||||
if explicit != cpu_capacity_per_rank:
|
||||
logger.warning(
|
||||
"cpu_bytes_to_use_per_rank (%.2f GB) != "
|
||||
"cpu_bytes_to_use/world_size (%.2f GB). Using per-rank value.",
|
||||
explicit / (1024**3),
|
||||
cpu_capacity_per_rank / (1024**3),
|
||||
)
|
||||
cpu_capacity_per_rank = explicit
|
||||
|
||||
lazy_offload = bool(extra_config.get("lazy_offload", False))
|
||||
|
||||
self.scheduler_manager: SimpleCPUOffloadScheduler | None = None
|
||||
self.worker_handler: SimpleCPUOffloadWorker | None = None
|
||||
|
||||
if not enable_prefix_caching:
|
||||
logger.warning(
|
||||
"Detected prefix caching disabled, disabling CPU offload "
|
||||
"since it requires prefix caching."
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"SimpleCPUOffloadConnector: role=%s, "
|
||||
"per_rank=%.2f GB, world_size=%d, mode=%s",
|
||||
role.name,
|
||||
cpu_capacity_per_rank / (1024**3),
|
||||
world_size,
|
||||
"lazy" if lazy_offload else "eager",
|
||||
)
|
||||
|
||||
if role == KVConnectorRole.SCHEDULER:
|
||||
self.scheduler_manager = SimpleCPUOffloadScheduler(
|
||||
vllm_config,
|
||||
kv_cache_config,
|
||||
cpu_capacity_per_rank,
|
||||
lazy_offload=lazy_offload,
|
||||
)
|
||||
elif role == KVConnectorRole.WORKER:
|
||||
self.worker_handler = SimpleCPUOffloadWorker(
|
||||
vllm_config, kv_cache_config, cpu_capacity_per_rank
|
||||
)
|
||||
|
||||
# --- Worker-side methods ---
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]) -> None:
|
||||
if self.worker_handler is not None:
|
||||
self.worker_handler.register_kv_caches(kv_caches)
|
||||
|
||||
def bind_connector_metadata(
|
||||
self,
|
||||
connector_metadata: KVConnectorMetadata,
|
||||
) -> None:
|
||||
super().bind_connector_metadata(connector_metadata)
|
||||
if self.worker_handler is not None:
|
||||
assert isinstance(connector_metadata, SimpleCPUOffloadMetadata)
|
||||
self.worker_handler.bind_connector_metadata(connector_metadata)
|
||||
|
||||
def clear_connector_metadata(self) -> None:
|
||||
super().clear_connector_metadata()
|
||||
if self.worker_handler is not None:
|
||||
self.worker_handler.clear_connector_metadata()
|
||||
|
||||
def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata) -> None:
|
||||
if self.worker_handler is not None:
|
||||
assert isinstance(kv_connector_metadata, SimpleCPUOffloadMetadata)
|
||||
self.worker_handler.handle_preemptions(kv_connector_metadata)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
|
||||
pass # Launch loads ops in get_finished() after launching model execution
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
pass # Always load asynchronously and deferred to get_finished()
|
||||
|
||||
def save_kv_layer(
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
pass # Always save asynchronously and deferred to get_finished()
|
||||
|
||||
def wait_for_save(self) -> None:
|
||||
pass # All stores are driven by get_finished() and no wait needed
|
||||
|
||||
def get_finished(
|
||||
self,
|
||||
finished_req_ids: set[str],
|
||||
) -> tuple[set[str] | None, set[str] | None]:
|
||||
if self.worker_handler is not None:
|
||||
return self.worker_handler.get_finished(finished_req_ids)
|
||||
return None, None
|
||||
|
||||
def build_connector_worker_meta(self):
|
||||
if self.worker_handler is not None:
|
||||
return self.worker_handler.build_connector_worker_meta()
|
||||
return None
|
||||
|
||||
# --- Scheduler-side methods ---
|
||||
|
||||
# NOTE: New API only for SimpleCPUOffloadConnector.
|
||||
def bind_gpu_block_pool(self, gpu_block_pool: "BlockPool") -> None:
|
||||
if self.scheduler_manager is not None:
|
||||
self.scheduler_manager.bind_gpu_block_pool(gpu_block_pool)
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int | None, bool]:
|
||||
if self.scheduler_manager is not None:
|
||||
return self.scheduler_manager.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens
|
||||
)
|
||||
return 0, False
|
||||
|
||||
def update_state_after_alloc(
|
||||
self,
|
||||
request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int,
|
||||
) -> None:
|
||||
if self.scheduler_manager is not None:
|
||||
self.scheduler_manager.update_state_after_alloc(
|
||||
request, blocks, num_external_tokens
|
||||
)
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> KVConnectorMetadata:
|
||||
if self.scheduler_manager is not None:
|
||||
return self.scheduler_manager.build_connector_meta(scheduler_output)
|
||||
return SimpleCPUOffloadMetadata()
|
||||
|
||||
def update_connector_output(
|
||||
self,
|
||||
connector_output: KVConnectorOutput,
|
||||
) -> None:
|
||||
if self.scheduler_manager is not None:
|
||||
self.scheduler_manager.update_connector_output(connector_output)
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
if self.scheduler_manager is not None:
|
||||
return self.scheduler_manager.request_finished(request, block_ids)
|
||||
return False, None
|
||||
|
||||
def request_finished_all_groups(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: tuple[list[int], ...],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
if self.scheduler_manager is not None:
|
||||
return self.scheduler_manager.request_finished_all_groups(
|
||||
request, block_ids
|
||||
)
|
||||
return False, None
|
||||
|
||||
# NOTE: New API only for SimpleCPUOffloadConnector.
|
||||
def has_pending_transfers(self) -> bool:
|
||||
if self.scheduler_manager is not None:
|
||||
return self.scheduler_manager.has_pending_stores()
|
||||
return False
|
||||
|
||||
def take_events(self) -> Iterable[KVCacheEvent]:
|
||||
if self.scheduler_manager is not None:
|
||||
return self.scheduler_manager.take_events()
|
||||
return []
|
||||
|
||||
def reset_cache(self) -> bool | None:
|
||||
raise NotImplementedError(
|
||||
"SimpleCPUOffloadConnector does not support reset_cache(). "
|
||||
"reset_prefix_cache() requires synchronizing all pending "
|
||||
"CPU offload transfers before clearing GPU prefix cache blocks, "
|
||||
"which is not yet implemented."
|
||||
)
|
||||
11
vllm/envs.py
11
vllm/envs.py
@@ -54,6 +54,7 @@ if TYPE_CHECKING:
|
||||
VLLM_ZENTORCH_WEIGHT_PREPACK: bool = True
|
||||
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
|
||||
VLLM_XLA_CHECK_RECOMPILATION: bool = False
|
||||
VLLM_SPARSE_INDEXER_MAX_LOGITS_MB: int = 512
|
||||
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto"
|
||||
VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
|
||||
VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True
|
||||
@@ -842,6 +843,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
),
|
||||
# Enable SPMD mode for TPU backend.
|
||||
"VLLM_XLA_USE_SPMD": lambda: bool(int(os.getenv("VLLM_XLA_USE_SPMD", "0"))),
|
||||
# Maximum size (in MB) for logits tensor in sparse MLA indexer prefill chunks.
|
||||
# Bounds the [M, N] float32 logits tensor to prevent CUDA OOM.
|
||||
# Default: 512 MB
|
||||
"VLLM_SPARSE_INDEXER_MAX_LOGITS_MB": lambda: int(
|
||||
os.getenv("VLLM_SPARSE_INDEXER_MAX_LOGITS_MB", "512")
|
||||
),
|
||||
# If set, the OpenAI API server will stay alive even after the underlying
|
||||
# AsyncLLMEngine errors and stops serving requests
|
||||
"VLLM_KEEP_ALIVE_ON_ENGINE_DEATH": lambda: bool(
|
||||
@@ -1655,6 +1662,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_XPU_ENABLE_XPU_GRAPH": lambda: bool(
|
||||
int(os.getenv("VLLM_XPU_ENABLE_XPU_GRAPH", "0"))
|
||||
),
|
||||
# Enable simple KV offload.
|
||||
"VLLM_USE_SIMPLE_KV_OFFLOAD": lambda: bool(
|
||||
int(os.getenv("VLLM_USE_SIMPLE_KV_OFFLOAD", "0"))
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from .dual_chunk_rope import DualChunkRotaryEmbedding
|
||||
from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
|
||||
from .dynamic_ntk_scaling_rope import DynamicNTKScalingRotaryEmbedding
|
||||
from .fope import FourierRotaryEmbedding
|
||||
from .gemma4_rope import Gemma4RotaryEmbedding
|
||||
from .linear_scaling_rope import LinearScalingRotaryEmbedding
|
||||
from .llama3_rope import Llama3RotaryEmbedding
|
||||
from .llama4_vision_rope import Llama4VisionRotaryEmbedding
|
||||
@@ -134,6 +135,17 @@ def get_rope(
|
||||
is_neox_style,
|
||||
dtype,
|
||||
)
|
||||
elif scaling_type == "proportional":
|
||||
# Proportional RoPE is used by Gemma4 for global (full) attention.
|
||||
# Gemma4 uses a sparse/fractional RoPE with cross-mixing between halves.
|
||||
rotary_emb = Gemma4RotaryEmbedding(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
)
|
||||
elif scaling_type == "llama3":
|
||||
scaling_factor = rope_parameters["factor"]
|
||||
low_freq_factor = rope_parameters["low_freq_factor"]
|
||||
|
||||
84
vllm/model_executor/layers/rotary_embedding/gemma4_rope.py
Normal file
84
vllm/model_executor/layers/rotary_embedding/gemma4_rope.py
Normal file
@@ -0,0 +1,84 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Gemma4-specific Rotary Positional Embeddings (proportional scaling).
|
||||
|
||||
Gemma4 uses "proportional" RoPE which computes inv_freq frequencies scaled
|
||||
by head_dim (not rotary_dim), and zero-pads for non-rotated dimensions when
|
||||
partial_rotary_factor < 1. The actual rotation uses standard neox-style
|
||||
rotate_half, matching HF transformers' apply_rotary_pos_emb.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from .base import RotaryEmbedding
|
||||
|
||||
|
||||
class Gemma4RotaryEmbedding(RotaryEmbedding):
|
||||
"""Gemma4 proportional RoPE.
|
||||
|
||||
Extends RotaryEmbedding (which provides standard neox-style rotation
|
||||
via ops.rotary_embedding CUDA kernel) but overrides the inv_freq
|
||||
computation to match HF's _compute_proportional_rope_parameters:
|
||||
- Frequency exponents use head_dim (not rotary_dim) as denominator
|
||||
- Non-rotated dims are zero-padded (cos=1, sin=0 = identity rotation)
|
||||
|
||||
When partial_rotary_factor=1.0 (the default for some variants), ALL dims are
|
||||
rotated and this is equivalent to standard RotaryEmbedding with
|
||||
head_dim-scaled frequencies.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: float,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
# Number of rotation angle pairs (from partial_rotary_factor)
|
||||
self.rope_angles = rotary_dim // 2
|
||||
# Non-rotated angle pairs per half
|
||||
self.nope_angles = (head_size // 2) - self.rope_angles
|
||||
|
||||
# Important: set rotary_dim = head_size so the base class's
|
||||
# forward_static applies rotation to ALL dims of the cos/sin cache.
|
||||
# The non-rotated dims will have cos=1, sin=0 (identity) thanks
|
||||
# to our _compute_inv_freq zero-padding.
|
||||
super().__init__(
|
||||
head_size,
|
||||
head_size, # rotary_dim = head_size (full application)
|
||||
max_position_embeddings,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
)
|
||||
|
||||
def _compute_inv_freq(self, base: float) -> torch.Tensor:
|
||||
"""Compute frequencies matching HF proportional RoPE.
|
||||
|
||||
Key difference from base: exponent denominator is head_size (not
|
||||
rotary_dim), and non-rotated dims are zero-padded.
|
||||
"""
|
||||
# HF formula: base ** (arange(0, 2*rope_angles, 2) / head_dim)
|
||||
freq_exponents = (
|
||||
torch.arange(0, 2 * self.rope_angles, 2, dtype=torch.float) / self.head_size
|
||||
)
|
||||
inv_freq = 1.0 / (base**freq_exponents)
|
||||
|
||||
# Zero-pad for non-rotated dims (identity rotation: cos=1, sin=0)
|
||||
if self.nope_angles > 0:
|
||||
inv_freq = torch.cat(
|
||||
[
|
||||
inv_freq,
|
||||
torch.zeros(self.nope_angles, dtype=torch.float),
|
||||
]
|
||||
)
|
||||
return inv_freq
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
|
||||
s += f", rope_angles={self.rope_angles}, nope_angles={self.nope_angles}"
|
||||
s += f", max_position_embeddings={self.max_position_embeddings}"
|
||||
s += f", base={self.base}, is_neox_style={self.is_neox_style}"
|
||||
return s
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
@@ -51,6 +52,14 @@ def sparse_attn_indexer(
|
||||
((total_seq_lens, head_dim), torch.float8_e4m3fn),
|
||||
((total_seq_lens, 4), torch.uint8),
|
||||
)
|
||||
|
||||
# Dummy allocation to simulate for peak logits tensor memory during inference.
|
||||
# FP8 elements so elements == bytes
|
||||
max_logits_elems = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024
|
||||
_ = torch.empty(
|
||||
max_logits_elems, dtype=torch.uint8, device=hidden_states.device
|
||||
)
|
||||
|
||||
return sparse_attn_indexer_fake(
|
||||
hidden_states,
|
||||
k_cache_prefix,
|
||||
@@ -101,13 +110,16 @@ def sparse_attn_indexer(
|
||||
for chunk in prefill_metadata.chunks:
|
||||
k_fp8 = k_fp8_full[: chunk.total_seq_lens]
|
||||
k_scale = k_scale_full[: chunk.total_seq_lens]
|
||||
ops.cp_gather_indexer_k_quant_cache(
|
||||
kv_cache,
|
||||
k_fp8,
|
||||
k_scale,
|
||||
chunk.block_table,
|
||||
chunk.cu_seq_lens,
|
||||
)
|
||||
|
||||
if not chunk.skip_kv_gather:
|
||||
ops.cp_gather_indexer_k_quant_cache(
|
||||
kv_cache,
|
||||
k_fp8,
|
||||
k_scale,
|
||||
chunk.block_table,
|
||||
chunk.cu_seq_lens,
|
||||
)
|
||||
|
||||
logits = fp8_mqa_logits(
|
||||
q_fp8[chunk.token_start : chunk.token_end],
|
||||
(k_fp8, k_scale.view(torch.float32).flatten()),
|
||||
|
||||
@@ -14,6 +14,7 @@ from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentio
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,6 +58,58 @@ class Gemma3TextModelConfig(VerifyAndUpdateConfig):
|
||||
hf_config.is_causal = not hf_config.use_bidirectional_attention
|
||||
|
||||
|
||||
class Gemma4Config(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
"""Force unified attention backend for models with heterogeneous
|
||||
head dimensions.
|
||||
|
||||
Some Gemma4 variants use different head dimensions for
|
||||
sliding window (head_dim) vs full attention (global_head_dim) layers.
|
||||
When global_head_dim > 256, FlashAttention rejects those layers
|
||||
(head_size <= 256 kernel limit), causing vLLM to select a different
|
||||
backend for each layer type. This mixed-backend execution produces
|
||||
numerical divergence and output corruption.
|
||||
|
||||
The fix detects heterogeneous head dimensions from the model config
|
||||
and forces TRITON_ATTN (which has no head_size ceiling) for all
|
||||
layers when the user hasn't explicitly chosen a backend.
|
||||
|
||||
TODO: Heterogeneous head_sizes (head_dim != global_head_dim)
|
||||
require NixlConnector changes to support per-layer KV transfer
|
||||
with different head dimensions for prefill-decode disaggregation.
|
||||
"""
|
||||
hf_text_config = vllm_config.model_config.hf_text_config
|
||||
head_dim = getattr(hf_text_config, "head_dim", None)
|
||||
global_head_dim = getattr(hf_text_config, "global_head_dim", None)
|
||||
|
||||
# Only force Triton when head dimensions actually differ AND the
|
||||
# larger one exceeds FlashAttention's kernel limit (head_size <= 256).
|
||||
# This avoids unnecessary backend forcing on smaller models where
|
||||
# the config carries global_head_dim but all layers can still use
|
||||
# the same FA backend.
|
||||
max_head_dim = max(head_dim or 0, global_head_dim or 0)
|
||||
if (
|
||||
head_dim is not None
|
||||
and global_head_dim is not None
|
||||
and head_dim != global_head_dim
|
||||
and max_head_dim > 256
|
||||
and vllm_config.attention_config.backend is None
|
||||
):
|
||||
from vllm.v1.attention.backends.registry import (
|
||||
AttentionBackendEnum,
|
||||
)
|
||||
|
||||
vllm_config.attention_config.backend = AttentionBackendEnum.TRITON_ATTN
|
||||
logger.info(
|
||||
"Gemma4 model has heterogeneous head dimensions "
|
||||
"(head_dim=%d, global_head_dim=%d). Forcing TRITON_ATTN "
|
||||
"backend to prevent mixed-backend numerical divergence.",
|
||||
head_dim,
|
||||
global_head_dim,
|
||||
)
|
||||
|
||||
|
||||
class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
@@ -668,6 +721,8 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
||||
"Ernie4_5_VLMoeForConditionalGeneration": Ernie4_5_VLMoeForConditionalGenerationConfig, # noqa: E501
|
||||
"FalconMambaForCausalLM": MambaModelConfig,
|
||||
"Gemma3TextModel": Gemma3TextModelConfig,
|
||||
"Gemma4ForCausalLM": Gemma4Config,
|
||||
"Gemma4ForConditionalGeneration": Gemma4Config,
|
||||
"GptOssForCausalLM": GptOssForCausalLMConfig,
|
||||
"GteModel": SnowflakeGteNewModelConfig,
|
||||
"GteNewForSequenceClassification": GteNewModelConfig,
|
||||
|
||||
1239
vllm/model_executor/models/gemma4.py
Normal file
1239
vllm/model_executor/models/gemma4.py
Normal file
File diff suppressed because it is too large
Load Diff
1341
vllm/model_executor/models/gemma4_mm.py
Normal file
1341
vllm/model_executor/models/gemma4_mm.py
Normal file
File diff suppressed because it is too large
Load Diff
292
vllm/model_executor/models/gemma4_utils.py
Normal file
292
vllm/model_executor/models/gemma4_utils.py
Normal file
@@ -0,0 +1,292 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||
|
||||
"""Gemma4 output parsing utilities for offline inference.
|
||||
|
||||
Standalone functions that parse decoded model text to extract structured
|
||||
thinking content and tool calls from Gemma4 models. These are pure-Python
|
||||
utilities with zero heavy dependencies — they work on raw decoded strings
|
||||
from any inference backend (vLLM, HuggingFace, TGI, etc.).
|
||||
|
||||
Usage with vLLM offline inference::
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.model_executor.models.gemma4_utils import (
|
||||
parse_output,
|
||||
parse_tool_calls,
|
||||
)
|
||||
|
||||
llm = LLM(model="google/gemma-4-it")
|
||||
outputs = llm.generate(prompt, SamplingParams(...))
|
||||
text = tokenizer.decode(outputs[0].outputs[0].token_ids, skip_special_tokens=False)
|
||||
|
||||
# Extract thinking / answer (works with or without enable_thinking)
|
||||
result = parse_output(text)
|
||||
print(result["thinking"]) # chain-of-thought or None
|
||||
print(result["answer"]) # final answer
|
||||
|
||||
# Extract tool calls
|
||||
tool_calls = parse_tool_calls(text)
|
||||
for tc in tool_calls:
|
||||
print(f"{tc['name']}({tc['arguments']})")
|
||||
|
||||
Ported from ``transformers.models.gemma4.utils_gemma4`` so that vLLM users
|
||||
do not need a transformers dependency for output parsing.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import regex as re
|
||||
|
||||
# ---- Thinking Mode Utility ----
|
||||
|
||||
# Thinking delimiter tokens as they appear in decoded text.
|
||||
# Gemma4 uses <|channel> (start) and <channel|> (end) as thinking delimiters.
|
||||
_THINKING_START_TAG = "<|channel>"
|
||||
_THINKING_END_TAG = "<channel|>"
|
||||
|
||||
# Sentinel tokens that may appear in decoded output.
|
||||
_TURN_END_TAG = "<turn|>"
|
||||
|
||||
|
||||
def parse_thinking_output(text: str) -> dict[str, str | None]:
|
||||
"""Parse decoded Gemma4 model output.
|
||||
|
||||
Use this on **all** Gemma4 output regardless of whether thinking mode
|
||||
was enabled. It handles three cases:
|
||||
|
||||
1. **Thinking enabled, tags present** — splits on ``<|channel>``/
|
||||
``<channel|>`` to separate chain-of-thought from the answer and
|
||||
strips the ``thought\\n`` role label.
|
||||
2. **Thinking disabled, spurious label** — strips the bare
|
||||
``thought\\n`` prefix that some Gemma4 models emit even
|
||||
without thinking mode.
|
||||
3. **Clean output** — returns the text unchanged.
|
||||
|
||||
The answer text is always cleaned of trailing sentinel tokens
|
||||
(``<turn|>``, ``<eos>``, etc.).
|
||||
|
||||
Args:
|
||||
text: Decoded model output text (from ``tokenizer.decode(...)``).
|
||||
|
||||
Returns:
|
||||
A dict with keys:
|
||||
- ``"thinking"``: The chain-of-thought text, or ``None`` if no
|
||||
thinking delimiters were found.
|
||||
- ``"answer"``: The final answer text.
|
||||
|
||||
Example::
|
||||
|
||||
>>> from vllm.model_executor.models.gemma4_utils import parse_thinking_output
|
||||
>>> output_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
|
||||
>>> result = parse_thinking_output(output_text)
|
||||
>>> print(result["thinking"]) # chain-of-thought reasoning or None
|
||||
>>> print(result["answer"]) # final answer
|
||||
"""
|
||||
if _THINKING_END_TAG in text:
|
||||
parts = text.split(_THINKING_END_TAG, 1)
|
||||
thinking_block = parts[0]
|
||||
answer = _clean_answer(parts[1])
|
||||
|
||||
# Extract thinking content: strip the start tag if present
|
||||
if _THINKING_START_TAG in thinking_block:
|
||||
thinking = thinking_block.split(_THINKING_START_TAG, 1)[1]
|
||||
else:
|
||||
thinking = thinking_block
|
||||
|
||||
# Strip the "thought\n" channel role label the model emits inside
|
||||
# <|channel>thought\n...<channel|> (analogous to "user\n" in
|
||||
# <|turn>user\n...<turn|>).
|
||||
thinking = _strip_thought_label(thinking.strip())
|
||||
thinking = thinking.strip()
|
||||
|
||||
return {"thinking": thinking, "answer": answer}
|
||||
|
||||
# No thinking delimiters found.
|
||||
# Strip spurious "thought\n" role label that some Gemma4 models sometimes
|
||||
# emit even without thinking mode enabled, then clean trailing tokens.
|
||||
answer = _strip_thought_label(text)
|
||||
answer = _clean_answer(answer)
|
||||
return {"thinking": None, "answer": answer}
|
||||
|
||||
|
||||
def _strip_thought_label(text: str) -> str:
|
||||
"""Strip the spurious ``thought\\n`` label from the start of text.
|
||||
|
||||
Only strips when ``thought`` appears as the very first word followed by
|
||||
a newline — preserving the word ``thought`` in any other context.
|
||||
"""
|
||||
if text.startswith("thought\n"):
|
||||
return text[len("thought\n") :]
|
||||
return text
|
||||
|
||||
|
||||
def _clean_answer(text: str) -> str:
|
||||
"""Clean trailing sentinel tokens from the answer text.
|
||||
|
||||
Strips ``<turn|>``, ``<eos>``, and surrounding whitespace that the
|
||||
model appends at the end of its response.
|
||||
"""
|
||||
text = text.strip()
|
||||
# Strip trailing <turn|> (Gemma4 turn-end marker)
|
||||
if text.endswith(_TURN_END_TAG):
|
||||
text = text[: -len(_TURN_END_TAG)].rstrip()
|
||||
# Strip trailing <eos> if present
|
||||
if text.endswith("<eos>"):
|
||||
text = text[:-5].rstrip()
|
||||
return text
|
||||
|
||||
|
||||
# ---- Tool Call Parsing Utility ----
|
||||
#
|
||||
# NOTE: For the OpenAI-compatible API server tool parser (streaming +
|
||||
# non-streaming), see vllm/tool_parsers/gemma4_tool_parser.py.
|
||||
# This module provides offline inference utilities for direct user import.
|
||||
|
||||
# Tool call delimiter tokens as they appear in decoded text.
|
||||
# Standard format: <|tool_call>call:name{args}<tool_call|>
|
||||
_TOOL_CALL_START_TAG = "<|tool_call>"
|
||||
_TOOL_CALL_END_TAG = "<tool_call|>"
|
||||
_TOOL_RESPONSE_START_TAG = "<|tool_response>"
|
||||
|
||||
# Gemma4 escape token as it appears in decoded text.
|
||||
_ESCAPE_TOKEN = '<|"|>'
|
||||
|
||||
|
||||
def _parse_tool_arguments(args_str: str) -> dict[str, str]:
|
||||
"""Parse tool call arguments from the Gemma4 compact format.
|
||||
|
||||
Handles the ``key:<|"|>value<|"|>`` format used by Gemma4, with fallback
|
||||
to heuristic key-value extraction. Also tolerates the slightly different
|
||||
``key: "value"`` format (space + plain quotes) that some chat templates
|
||||
produce.
|
||||
|
||||
Args:
|
||||
args_str: Raw argument string from inside ``call:name{...}``.
|
||||
|
||||
Returns:
|
||||
Dictionary of argument name → value.
|
||||
"""
|
||||
if not args_str or not args_str.strip():
|
||||
return {}
|
||||
|
||||
# Replace Gemma4 escape tokens with standard quotes.
|
||||
cleaned = args_str.replace(_ESCAPE_TOKEN, '"')
|
||||
|
||||
# Try JSON parsing first (handles nested values, arrays, etc.).
|
||||
try:
|
||||
parsed = json.loads("{" + cleaned + "}")
|
||||
# Ensure all values are strings for consistency.
|
||||
return {k: str(v) if not isinstance(v, str) else v for k, v in parsed.items()}
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
# Fallback: extract key:"value" pairs (allow optional space after colon).
|
||||
arguments = {}
|
||||
for key, value in re.findall(r'(\w+):\s*"([^"]*)"', cleaned):
|
||||
arguments[key] = value
|
||||
|
||||
if not arguments:
|
||||
# Last resort: extract key:value pairs (unquoted).
|
||||
for key, value in re.findall(r"(\w+):\s*([^,}]+)", args_str):
|
||||
arguments[key] = value.strip().strip('"').replace(_ESCAPE_TOKEN, "")
|
||||
|
||||
return arguments
|
||||
|
||||
|
||||
def parse_tool_calls(text: str, *, strict: bool = False) -> list[dict]:
|
||||
"""Parse tool calls from decoded Gemma4 model output.
|
||||
|
||||
Uses a tiered parsing strategy to handle known output variations in
|
||||
Gemma4 models, which may emit
|
||||
non-standard tool call formats.
|
||||
|
||||
Parsing tiers:
|
||||
1. **Standard**: ``<|tool_call>call:name{args}<tool_call|>``
|
||||
(special token IDs 48/49 in decoded text)
|
||||
2. **Fallback** (when ``strict=False``): bare ``call:name{args}``
|
||||
patterns, including ``<call>name{args}`` (fragmented tokens from
|
||||
multimodal inputs)
|
||||
|
||||
Args:
|
||||
text: Decoded model output text (from ``tokenizer.decode(...,
|
||||
skip_special_tokens=False)``).
|
||||
strict: If ``True``, only match the standard ``<|tool_call>`` format.
|
||||
If ``False`` (default), also try fallback patterns for
|
||||
known Gemma4 output variations.
|
||||
|
||||
Returns:
|
||||
A list of dicts, each with keys:
|
||||
- ``"name"``: The tool function name (e.g. ``"get_weather"``).
|
||||
- ``"arguments"``: A dict of argument name → value.
|
||||
|
||||
Example::
|
||||
|
||||
>>> from vllm.model_executor.models.gemma4_utils import (
|
||||
... parse_tool_calls
|
||||
... )
|
||||
>>> output = tokenizer.decode(outputs[0], skip_special_tokens=False)
|
||||
>>> tool_calls = parse_tool_calls(output)
|
||||
>>> for tc in tool_calls:
|
||||
... print(f"Call: {tc['name']}({tc['arguments']})")
|
||||
"""
|
||||
results = []
|
||||
|
||||
# Tier 1: Standard format with special tokens.
|
||||
# <|tool_call>call:name{args}<tool_call|>
|
||||
# Note: Some Gemma4 models emit <turn|> instead of <tool_call|>.
|
||||
standard_pattern = r"<\|tool_call\>call:(\w+)\{(.*?)\}(?:<tool_call\|>|<turn\|>)"
|
||||
for match in re.finditer(standard_pattern, text, re.DOTALL):
|
||||
name, args_str = match.group(1), match.group(2)
|
||||
results.append(
|
||||
{
|
||||
"name": name,
|
||||
"arguments": _parse_tool_arguments(args_str),
|
||||
}
|
||||
)
|
||||
|
||||
if results or strict:
|
||||
return results
|
||||
|
||||
# Tier 2: Fallback for known Gemma4 output variations.
|
||||
# Matches: <call>name{args}, call:name{args}, or bare call:name{args}<eos>
|
||||
fallback_pattern = r"(?:<call>|(?:^|\s)call:)(\w+)\{(.*?)\}"
|
||||
for match in re.finditer(fallback_pattern, text, re.DOTALL):
|
||||
name, args_str = match.group(1), match.group(2)
|
||||
results.append(
|
||||
{
|
||||
"name": name,
|
||||
"arguments": _parse_tool_arguments(args_str),
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def has_tool_response_tag(text: str) -> bool:
|
||||
"""Check if model output properly ends with a tool response tag.
|
||||
|
||||
Some Gemma4 models sometimes emit ``<eos>`` instead of
|
||||
``<|tool_response>`` after a tool call. This helper detects
|
||||
whether the model used the proper termination, so callers can
|
||||
decide whether to inject ``<|tool_response>`` into the next prompt.
|
||||
|
||||
Args:
|
||||
text: Decoded model output text.
|
||||
|
||||
Returns:
|
||||
``True`` if the output ends with ``<|tool_response>``
|
||||
(proper behavior), ``False`` otherwise.
|
||||
|
||||
Example::
|
||||
|
||||
>>> from vllm.model_executor.models.gemma4_utils import (
|
||||
... has_tool_response_tag
|
||||
... )
|
||||
>>> if not has_tool_response_tag(model_output):
|
||||
... # Model used <eos> instead — inject <|tool_response> manually
|
||||
... next_prompt = "<|tool_response>" + tool_result
|
||||
"""
|
||||
stripped = text.rstrip()
|
||||
return stripped.endswith(_TOOL_RESPONSE_START_TAG)
|
||||
@@ -111,6 +111,7 @@ _TEXT_GENERATION_MODELS = {
|
||||
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
|
||||
"Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
|
||||
"Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
|
||||
"Gemma4ForCausalLM": ("gemma4", "Gemma4ForCausalLM"),
|
||||
"Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
|
||||
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
|
||||
"Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
|
||||
@@ -381,6 +382,7 @@ _MULTIMODAL_MODELS = {
|
||||
"gemma3n_mm",
|
||||
"Gemma3nForConditionalGeneration",
|
||||
),
|
||||
"Gemma4ForConditionalGeneration": ("gemma4_mm", "Gemma4ForConditionalGeneration"),
|
||||
"GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
|
||||
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
|
||||
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
|
||||
|
||||
@@ -233,8 +233,15 @@ class AutoWeightsLoader:
|
||||
):
|
||||
"""
|
||||
Add tensor names that are not in the model params that may be in the
|
||||
safetensors, e.g., batch normalization stats.
|
||||
safetensors, e.g., batch normalization stats and registered buffers.
|
||||
"""
|
||||
# Add persistent registered buffers.
|
||||
# Non-persistent buffers are excluded, matching PyTorch state_dict().
|
||||
non_persistent = getattr(module, "_non_persistent_buffers_set", set())
|
||||
for buf_name, buf in module.named_buffers(recurse=False):
|
||||
if buf_name not in child_params and buf_name not in non_persistent:
|
||||
child_params[buf_name] = buf
|
||||
|
||||
if isinstance(
|
||||
module,
|
||||
(
|
||||
|
||||
@@ -80,8 +80,15 @@ class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]):
|
||||
"image/jpeg",
|
||||
)
|
||||
|
||||
if self.num_frames > 0:
|
||||
frame_parts = data.split(",", self.num_frames)[: self.num_frames]
|
||||
elif self.num_frames == 0:
|
||||
raise ValueError("num_frames must be greater than 0 or -1")
|
||||
else:
|
||||
frame_parts = data.split(",")
|
||||
|
||||
frames = np.stack(
|
||||
[np.asarray(load_frame(frame_data)) for frame_data in data.split(",")]
|
||||
[np.asarray(load_frame(frame_data)) for frame_data in frame_parts]
|
||||
)
|
||||
total = int(frames.shape[0])
|
||||
fps = float(self.kwargs.get("fps", 1))
|
||||
|
||||
@@ -32,6 +32,10 @@ _REASONING_PARSERS_TO_REGISTER = {
|
||||
"ernie45_reasoning_parser",
|
||||
"Ernie45ReasoningParser",
|
||||
),
|
||||
"gemma4": (
|
||||
"gemma4_reasoning_parser",
|
||||
"Gemma4ReasoningParser",
|
||||
),
|
||||
"glm45": (
|
||||
"deepseek_v3_reasoning_parser",
|
||||
"DeepSeekV3ReasoningWithThinkingParser",
|
||||
|
||||
193
vllm/reasoning/gemma4_reasoning_parser.py
Normal file
193
vllm/reasoning/gemma4_reasoning_parser.py
Normal file
@@ -0,0 +1,193 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
|
||||
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionRequest,
|
||||
)
|
||||
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
|
||||
|
||||
# Role label that Gemma4 emits at the start of the thinking channel.
|
||||
# The model generates: <|channel>thought\n...reasoning...<channel|>
|
||||
# This prefix must be stripped to expose only the actual reasoning content.
|
||||
_THOUGHT_PREFIX = "thought\n"
|
||||
|
||||
|
||||
class Gemma4ReasoningParser(BaseThinkingReasoningParser):
|
||||
"""
|
||||
Reasoning parser for Google Gemma4 thinking models.
|
||||
|
||||
Gemma4 uses <|channel>...<channel|> tokens to delimit reasoning/thinking
|
||||
content within its output. Thinking mode is activated by passing
|
||||
``enable_thinking=True`` in the chat template kwargs, which injects a
|
||||
system turn containing <|think|> (token 98) to trigger chain-of-thought
|
||||
reasoning.
|
||||
|
||||
Output pattern when thinking is enabled::
|
||||
|
||||
<|channel>thought
|
||||
...chain of thought reasoning...<channel|>
|
||||
Final answer text here.
|
||||
|
||||
The ``thought\\n`` role label inside the channel delimiters is a
|
||||
structural artefact (analogous to ``user\\n`` in ``<|turn>user\\n...``).
|
||||
This parser strips it so that downstream consumers see only the
|
||||
actual reasoning text, consistent with the offline parser
|
||||
(``vllm.reasoning.gemma4_utils._strip_thought_label``).
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
|
||||
super().__init__(tokenizer, *args, **kwargs)
|
||||
# Instance state for streaming prefix stripping.
|
||||
# Tracks only the reasoning text received from the base parser,
|
||||
# independent of current_text (which may contain pre-reasoning
|
||||
# content and lacks special token text due to
|
||||
# skip_special_tokens=True).
|
||||
self._reasoning_text: str = ""
|
||||
self._prefix_stripped: bool = False
|
||||
|
||||
@property
|
||||
def start_token(self) -> str:
|
||||
"""The token that starts reasoning content."""
|
||||
return "<|channel>"
|
||||
|
||||
@property
|
||||
def end_token(self) -> str:
|
||||
"""The token that ends reasoning content."""
|
||||
return "<channel|>"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Non-streaming path
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def extract_reasoning(
|
||||
self,
|
||||
model_output: str,
|
||||
request: "ChatCompletionRequest | ResponsesRequest",
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""Extract reasoning, stripping the ``thought\\n`` role label."""
|
||||
if self.start_token not in model_output and self.end_token not in model_output:
|
||||
# Default to content history if no tags are present
|
||||
# (or if they were stripped)
|
||||
return None, model_output
|
||||
|
||||
reasoning, content = super().extract_reasoning(model_output, request)
|
||||
if reasoning is not None:
|
||||
reasoning = _strip_thought_label(reasoning)
|
||||
return reasoning, content
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Streaming path
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def extract_reasoning_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
) -> DeltaMessage | None:
|
||||
"""Extract streaming reasoning, stripping ``thought\\n`` from the
|
||||
first reasoning delta(s).
|
||||
|
||||
The ``thought\\n`` prefix may arrive as a single delta or split
|
||||
across multiple deltas (e.g. ``"thought"`` then ``"\\n"``). We
|
||||
buffer early reasoning tokens until we can determine whether the
|
||||
prefix is present, then emit the buffered content minus the
|
||||
prefix.
|
||||
|
||||
Unlike the previous implementation which reconstructed accumulated
|
||||
reasoning from ``current_text``, this uses instance state
|
||||
(``_reasoning_text``) to track only the reasoning content returned
|
||||
by the base parser. This is necessary because
|
||||
``skip_special_tokens=True`` (the vLLM default) causes the
|
||||
``<|channel>`` delimiter to be invisible in ``current_text``,
|
||||
making it impossible to separate pre-reasoning content from
|
||||
reasoning content via string matching.
|
||||
"""
|
||||
result = super().extract_reasoning_streaming(
|
||||
previous_text,
|
||||
current_text,
|
||||
delta_text,
|
||||
previous_token_ids,
|
||||
current_token_ids,
|
||||
delta_token_ids,
|
||||
)
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
if result.reasoning is None:
|
||||
return result
|
||||
|
||||
# Accumulate ONLY the reasoning text from base parser results.
|
||||
# This is immune to pre-reasoning content pollution.
|
||||
self._reasoning_text += result.reasoning
|
||||
|
||||
# Once the prefix has been handled, all subsequent reasoning
|
||||
# deltas pass through unchanged.
|
||||
if self._prefix_stripped:
|
||||
return result
|
||||
|
||||
# ---- Prefix stripping logic ----
|
||||
|
||||
# Case 1: We've accumulated enough to confirm the prefix is
|
||||
# present. Strip it and pass through the remainder.
|
||||
if self._reasoning_text.startswith(_THOUGHT_PREFIX):
|
||||
prefix_len = len(_THOUGHT_PREFIX)
|
||||
# How much reasoning was accumulated before this delta?
|
||||
prev_reasoning_len = len(self._reasoning_text) - len(result.reasoning)
|
||||
if prev_reasoning_len >= prefix_len:
|
||||
# Prefix was already consumed by prior deltas; this
|
||||
# delta is entirely real content — pass through.
|
||||
self._prefix_stripped = True
|
||||
return result
|
||||
else:
|
||||
# Part or all of the prefix is in this delta.
|
||||
chars_of_prefix_in_delta = prefix_len - prev_reasoning_len
|
||||
stripped = result.reasoning[chars_of_prefix_in_delta:]
|
||||
if stripped:
|
||||
self._prefix_stripped = True
|
||||
result.reasoning = stripped
|
||||
return result
|
||||
else:
|
||||
# This entire delta was prefix — suppress it.
|
||||
# Don't set _prefix_stripped yet; there may be more
|
||||
# prefix chars to consume in the next delta.
|
||||
if len(self._reasoning_text) >= prefix_len:
|
||||
self._prefix_stripped = True
|
||||
return None
|
||||
|
||||
# Case 2: Accumulated text is a strict prefix of
|
||||
# _THOUGHT_PREFIX (e.g. we've only seen "thou" so far).
|
||||
# Buffer by suppressing — we can't yet tell if this will
|
||||
# become the full prefix or diverge.
|
||||
if _THOUGHT_PREFIX.startswith(self._reasoning_text):
|
||||
return None
|
||||
|
||||
# Case 3: Accumulated text doesn't match the thought prefix
|
||||
# at all. This means prior deltas were buffered (suppressed
|
||||
# by Case 2) but the text diverged. Re-emit the full
|
||||
# accumulated text to avoid data loss.
|
||||
self._prefix_stripped = True
|
||||
result.reasoning = self._reasoning_text
|
||||
return result
|
||||
|
||||
|
||||
def _strip_thought_label(text: str) -> str:
|
||||
"""Remove the ``thought\\n`` role label from the beginning of text.
|
||||
|
||||
Mirrors ``vllm.reasoning.gemma4_utils._strip_thought_label`` from the
|
||||
offline parser.
|
||||
"""
|
||||
if text.startswith(_THOUGHT_PREFIX):
|
||||
return text[len(_THOUGHT_PREFIX) :]
|
||||
return text
|
||||
130
vllm/reasoning/gemma4_utils.py
Normal file
130
vllm/reasoning/gemma4_utils.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||
|
||||
"""Gemma4 thinking/reasoning output parsing utilities for offline inference.
|
||||
|
||||
Standalone functions that parse decoded model text to extract structured
|
||||
thinking content from Gemma4 models. These are pure-Python utilities with
|
||||
zero heavy dependencies — they work on raw decoded strings from any
|
||||
inference backend (vLLM, HuggingFace, TGI, etc.).
|
||||
|
||||
For the OpenAI-compatible API reasoning parser (streaming +
|
||||
non-streaming), see ``vllm.reasoning.gemma4_reasoning_parser``.
|
||||
For tool call parsing, see ``vllm.tool_parsers.gemma4_utils``.
|
||||
|
||||
Usage with vLLM offline inference::
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.reasoning.gemma4_utils import parse_thinking_output
|
||||
|
||||
llm = LLM(model="google/gemma-4-it")
|
||||
outputs = llm.generate(prompt, SamplingParams(...))
|
||||
text = tokenizer.decode(outputs[0].outputs[0].token_ids, skip_special_tokens=False)
|
||||
|
||||
# Extract thinking / answer (works with or without enable_thinking)
|
||||
result = parse_thinking_output(text)
|
||||
print(result["thinking"]) # chain-of-thought or None
|
||||
print(result["answer"]) # final answer
|
||||
|
||||
Ported from ``transformers.models.gemma4.utils_gemma4`` so that vLLM users
|
||||
do not need a transformers dependency for output parsing.
|
||||
"""
|
||||
|
||||
# ---- Thinking Mode Utility ----
|
||||
|
||||
# Thinking delimiter tokens as they appear in decoded text.
|
||||
# Gemma4 uses <|channel> (start) and <channel|> (end) as thinking delimiters.
|
||||
_THINKING_START_TAG = "<|channel>"
|
||||
_THINKING_END_TAG = "<channel|>"
|
||||
|
||||
# Sentinel tokens that may appear in decoded output.
|
||||
_TURN_END_TAG = "<turn|>"
|
||||
|
||||
|
||||
def parse_thinking_output(text: str) -> dict[str, str | None]:
|
||||
"""Parse decoded Gemma4 model output.
|
||||
|
||||
Use this on **all** Gemma4 output regardless of whether thinking mode
|
||||
was enabled. It handles three cases:
|
||||
|
||||
1. **Thinking enabled, tags present** — splits on ``<|channel>``/
|
||||
``<channel|>`` to separate chain-of-thought from the answer and
|
||||
strips the ``thought\\n`` role label.
|
||||
2. **Thinking disabled, spurious label** — strips the bare
|
||||
``thought\\n`` prefix that some Gemma4 models emit even
|
||||
without thinking mode.
|
||||
3. **Clean output** — returns the text unchanged.
|
||||
|
||||
The answer text is always cleaned of trailing sentinel tokens
|
||||
(``<turn|>``, ``<eos>``, etc.).
|
||||
|
||||
Args:
|
||||
text: Decoded model output text (from ``tokenizer.decode(...)``).
|
||||
|
||||
Returns:
|
||||
A dict with keys:
|
||||
- ``"thinking"``: The chain-of-thought text, or ``None`` if no
|
||||
thinking delimiters were found.
|
||||
- ``"answer"``: The final answer text.
|
||||
|
||||
Example::
|
||||
|
||||
>>> from vllm.reasoning.gemma4_utils import parse_thinking_output
|
||||
>>> output_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
|
||||
>>> result = parse_thinking_output(output_text)
|
||||
>>> print(result["thinking"]) # chain-of-thought reasoning or None
|
||||
>>> print(result["answer"]) # final answer
|
||||
"""
|
||||
if _THINKING_END_TAG in text:
|
||||
parts = text.split(_THINKING_END_TAG, 1)
|
||||
thinking_block = parts[0]
|
||||
answer = _clean_answer(parts[1])
|
||||
|
||||
# Extract thinking content: strip the start tag if present
|
||||
if _THINKING_START_TAG in thinking_block:
|
||||
thinking = thinking_block.split(_THINKING_START_TAG, 1)[1]
|
||||
else:
|
||||
thinking = thinking_block
|
||||
|
||||
# Strip the "thought\n" channel role label the model emits inside
|
||||
# <|channel>thought\n...<channel|> (analogous to "user\n" in
|
||||
# <|turn>user\n...<turn|>).
|
||||
thinking = _strip_thought_label(thinking.strip())
|
||||
thinking = thinking.strip()
|
||||
|
||||
return {"thinking": thinking, "answer": answer}
|
||||
|
||||
# No thinking delimiters found.
|
||||
# Strip spurious "thought\n" role label that some Gemma4 models sometimes
|
||||
# emit even without thinking mode enabled, then clean trailing tokens.
|
||||
answer = _strip_thought_label(text)
|
||||
answer = _clean_answer(answer)
|
||||
return {"thinking": None, "answer": answer}
|
||||
|
||||
|
||||
def _strip_thought_label(text: str) -> str:
|
||||
"""Strip the spurious ``thought\\n`` label from the start of text.
|
||||
|
||||
Only strips when ``thought`` appears as the very first word followed by
|
||||
a newline — preserving the word ``thought`` in any other context.
|
||||
"""
|
||||
if text.startswith("thought\n"):
|
||||
return text[len("thought\n") :]
|
||||
return text
|
||||
|
||||
|
||||
def _clean_answer(text: str) -> str:
|
||||
"""Clean trailing sentinel tokens from the answer text.
|
||||
|
||||
Strips ``<turn|>``, ``<eos>``, and surrounding whitespace that the
|
||||
model appends at the end of its response.
|
||||
"""
|
||||
text = text.strip()
|
||||
# Strip trailing <turn|> (Gemma4 turn-end marker)
|
||||
if text.endswith(_TURN_END_TAG):
|
||||
text = text[: -len(_TURN_END_TAG)].rstrip()
|
||||
# Strip trailing <eos> if present
|
||||
if text.endswith("<eos>"):
|
||||
text = text[:-5].rstrip()
|
||||
return text
|
||||
@@ -154,6 +154,10 @@ _TOOL_PARSERS_TO_REGISTER = {
|
||||
"functiongemma_tool_parser",
|
||||
"FunctionGemmaToolParser",
|
||||
),
|
||||
"gemma4": (
|
||||
"gemma4_tool_parser",
|
||||
"Gemma4ToolParser",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
|
||||
724
vllm/tool_parsers/gemma4_tool_parser.py
Normal file
724
vllm/tool_parsers/gemma4_tool_parser.py
Normal file
@@ -0,0 +1,724 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Tool call parser for Google Gemma4 models.
|
||||
|
||||
Gemma4 uses a custom serialization format (not JSON) for tool calls::
|
||||
|
||||
<|tool_call>call:func_name{key:<|"|>value<|"|>,num:42}<tool_call|>
|
||||
|
||||
Strings are delimited by ``<|"|>`` (token 52), keys are unquoted, and
|
||||
multiple tool calls are concatenated without separators.
|
||||
|
||||
Used when ``--enable-auto-tool-choice --tool-call-parser gemma4`` are set.
|
||||
|
||||
For offline inference tool call parsing (direct ``tokenizer.decode()`` output),
|
||||
see ``vllm.tool_parsers.gemma4_utils.parse_tool_calls``.
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionRequest,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.entrypoints.openai.responses.protocol import (
|
||||
ResponsesRequest,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser
|
||||
from vllm.tool_parsers.utils import find_common_prefix
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Gemma4 special tokens for tool calls
|
||||
TOOL_CALL_START = "<|tool_call>"
|
||||
TOOL_CALL_END = "<tool_call|>"
|
||||
STRING_DELIM = '<|"|>'
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Gemma4 argument parser (used by both streaming and non-streaming paths)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_gemma4_value(value_str: str) -> object:
|
||||
"""Parse a single Gemma4 value (after key:) into a Python object."""
|
||||
value_str = value_str.strip()
|
||||
if not value_str:
|
||||
return value_str
|
||||
|
||||
# Boolean
|
||||
if value_str == "true":
|
||||
return True
|
||||
if value_str == "false":
|
||||
return False
|
||||
|
||||
# Number (int or float)
|
||||
try:
|
||||
if "." in value_str:
|
||||
return float(value_str)
|
||||
return int(value_str)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Bare string (no <|"|> delimiters — shouldn't happen but be safe)
|
||||
return value_str
|
||||
|
||||
|
||||
def _parse_gemma4_args(args_str: str) -> dict:
|
||||
"""Parse Gemma4's custom key:value format into a Python dict.
|
||||
|
||||
Format examples::
|
||||
|
||||
location:<|"|>Tokyo<|"|>
|
||||
location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>
|
||||
count:42,flag:true
|
||||
nested:{inner_key:<|"|>val<|"|>}
|
||||
items:[<|"|>a<|"|>,<|"|>b<|"|>]
|
||||
|
||||
Returns a dict ready for ``json.dumps()``.
|
||||
"""
|
||||
if not args_str or not args_str.strip():
|
||||
return {}
|
||||
|
||||
result: dict = {}
|
||||
i = 0
|
||||
n = len(args_str)
|
||||
|
||||
while i < n:
|
||||
# Skip whitespace and commas
|
||||
while i < n and args_str[i] in (" ", ",", "\n", "\t"):
|
||||
i += 1
|
||||
if i >= n:
|
||||
break
|
||||
|
||||
# Parse key (unquoted, ends at ':')
|
||||
key_start = i
|
||||
while i < n and args_str[i] != ":":
|
||||
i += 1
|
||||
if i >= n:
|
||||
break
|
||||
key = args_str[key_start:i].strip()
|
||||
i += 1 # skip ':'
|
||||
|
||||
# Parse value
|
||||
if i >= n:
|
||||
result[key] = ""
|
||||
break
|
||||
|
||||
# Skip whitespace after ':'
|
||||
while i < n and args_str[i] in (" ", "\n", "\t"):
|
||||
i += 1
|
||||
if i >= n:
|
||||
result[key] = ""
|
||||
break
|
||||
|
||||
# String value: <|"|>...<|"|>
|
||||
if args_str[i:].startswith(STRING_DELIM):
|
||||
i += len(STRING_DELIM)
|
||||
val_start = i
|
||||
end_pos = args_str.find(STRING_DELIM, i)
|
||||
if end_pos == -1:
|
||||
# Unterminated string — take rest
|
||||
result[key] = args_str[val_start:]
|
||||
break
|
||||
result[key] = args_str[val_start:end_pos]
|
||||
i = end_pos + len(STRING_DELIM)
|
||||
|
||||
# Nested object: {...}
|
||||
elif args_str[i] == "{":
|
||||
depth = 1
|
||||
obj_start = i + 1
|
||||
i += 1
|
||||
while i < n and depth > 0:
|
||||
if args_str[i:].startswith(STRING_DELIM):
|
||||
# Skip over string contents to avoid counting { inside strings
|
||||
i += len(STRING_DELIM)
|
||||
next_delim = args_str.find(STRING_DELIM, i)
|
||||
i = n if next_delim == -1 else next_delim + len(STRING_DELIM)
|
||||
continue
|
||||
if args_str[i] == "{":
|
||||
depth += 1
|
||||
elif args_str[i] == "}":
|
||||
depth -= 1
|
||||
i += 1
|
||||
result[key] = _parse_gemma4_args(args_str[obj_start : i - 1])
|
||||
|
||||
# Array: [...]
|
||||
elif args_str[i] == "[":
|
||||
depth = 1
|
||||
arr_start = i + 1
|
||||
i += 1
|
||||
while i < n and depth > 0:
|
||||
if args_str[i:].startswith(STRING_DELIM):
|
||||
i += len(STRING_DELIM)
|
||||
next_delim = args_str.find(STRING_DELIM, i)
|
||||
i = n if next_delim == -1 else next_delim + len(STRING_DELIM)
|
||||
continue
|
||||
if args_str[i] == "[":
|
||||
depth += 1
|
||||
elif args_str[i] == "]":
|
||||
depth -= 1
|
||||
i += 1
|
||||
arr_content = args_str[arr_start : i - 1]
|
||||
result[key] = _parse_gemma4_array(arr_content)
|
||||
|
||||
# Bare value (number, boolean, etc.)
|
||||
else:
|
||||
val_start = i
|
||||
while i < n and args_str[i] not in (",", "}", "]"):
|
||||
i += 1
|
||||
result[key] = _parse_gemma4_value(args_str[val_start:i])
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _parse_gemma4_array(arr_str: str) -> list:
|
||||
"""Parse a Gemma4 array content string into a Python list."""
|
||||
items: list = []
|
||||
i = 0
|
||||
n = len(arr_str)
|
||||
|
||||
while i < n:
|
||||
while i < n and arr_str[i] in (" ", ",", "\n", "\t"):
|
||||
i += 1
|
||||
if i >= n:
|
||||
break
|
||||
|
||||
# String element
|
||||
if arr_str[i:].startswith(STRING_DELIM):
|
||||
i += len(STRING_DELIM)
|
||||
end_pos = arr_str.find(STRING_DELIM, i)
|
||||
if end_pos == -1:
|
||||
items.append(arr_str[i:])
|
||||
break
|
||||
items.append(arr_str[i:end_pos])
|
||||
i = end_pos + len(STRING_DELIM)
|
||||
|
||||
# Nested object
|
||||
elif arr_str[i] == "{":
|
||||
depth = 1
|
||||
obj_start = i + 1
|
||||
i += 1
|
||||
while i < n and depth > 0:
|
||||
if arr_str[i:].startswith(STRING_DELIM):
|
||||
i += len(STRING_DELIM)
|
||||
nd = arr_str.find(STRING_DELIM, i)
|
||||
i = nd + len(STRING_DELIM) if nd != -1 else n
|
||||
continue
|
||||
if arr_str[i] == "{":
|
||||
depth += 1
|
||||
elif arr_str[i] == "}":
|
||||
depth -= 1
|
||||
i += 1
|
||||
items.append(_parse_gemma4_args(arr_str[obj_start : i - 1]))
|
||||
|
||||
# Nested array
|
||||
elif arr_str[i] == "[":
|
||||
depth = 1
|
||||
sub_start = i + 1
|
||||
i += 1
|
||||
while i < n and depth > 0:
|
||||
if arr_str[i] == "[":
|
||||
depth += 1
|
||||
elif arr_str[i] == "]":
|
||||
depth -= 1
|
||||
i += 1
|
||||
items.append(_parse_gemma4_array(arr_str[sub_start : i - 1]))
|
||||
|
||||
# Bare value
|
||||
else:
|
||||
val_start = i
|
||||
while i < n and arr_str[i] not in (",", "]"):
|
||||
i += 1
|
||||
items.append(_parse_gemma4_value(arr_str[val_start:i]))
|
||||
|
||||
return items
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parser
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Gemma4ToolParser(ToolParser):
|
||||
"""
|
||||
Tool call parser for Google Gemma4 models.
|
||||
|
||||
Handles the Gemma4 function call format::
|
||||
|
||||
<|tool_call>call:func_name{key:<|"|>value<|"|>}<tool_call|>
|
||||
|
||||
Used when ``--enable-auto-tool-choice --tool-call-parser gemma4``
|
||||
are set.
|
||||
|
||||
Streaming strategy: **accumulate-then-parse-then-diff**
|
||||
|
||||
Instead of trying to convert Gemma4's custom format to JSON
|
||||
token-by-token (which fails because Gemma4 uses bare keys, custom
|
||||
delimiters, and structural braces that differ from JSON), this parser:
|
||||
|
||||
1. Accumulates the raw Gemma4 argument string during streaming
|
||||
2. Parses it with ``_parse_gemma4_args()`` into a Python dict
|
||||
3. Converts to JSON with ``json.dumps()``
|
||||
4. Diffs against the previously-streamed JSON string
|
||||
5. Emits only the new JSON fragment as the delta
|
||||
|
||||
This follows the same pattern used by FunctionGemma, Hermes, and Llama
|
||||
tool parsers.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
|
||||
super().__init__(tokenizer, tools)
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolParser "
|
||||
"constructor during construction."
|
||||
)
|
||||
|
||||
# Token strings
|
||||
self.tool_call_start_token = TOOL_CALL_START
|
||||
self.tool_call_end_token = TOOL_CALL_END
|
||||
|
||||
# Token IDs
|
||||
self.tool_call_start_token_id = self.vocab.get(TOOL_CALL_START)
|
||||
self.tool_call_end_token_id = self.vocab.get(TOOL_CALL_END)
|
||||
|
||||
if self.tool_call_start_token_id is None:
|
||||
raise RuntimeError(
|
||||
"Gemma4 ToolParser could not locate the tool call start "
|
||||
f"token '{TOOL_CALL_START}' in the tokenizer!"
|
||||
)
|
||||
|
||||
# Regex for non-streaming: extract complete tool calls.
|
||||
# Supports function names with letters, digits, underscores,
|
||||
# hyphens, and dots (e.g. "get-weather", "module.func").
|
||||
self.tool_call_regex = re.compile(
|
||||
r"<\|tool_call>call:([\w\-\.]+)\{(.*?)\}<tool_call\|>",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
# Streaming state — reset per-request via _reset_streaming_state()
|
||||
self._reset_streaming_state()
|
||||
|
||||
# Delta buffer for handling multi-token special sequences
|
||||
self.buffered_delta_text = ""
|
||||
|
||||
def _reset_streaming_state(self) -> None:
|
||||
"""Reset all streaming state for a new request."""
|
||||
self.current_tool_id = -1
|
||||
self.current_tool_name_sent = False
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
self.streamed_args_for_tool: list[str] = []
|
||||
|
||||
def adjust_request(
|
||||
self, request: ChatCompletionRequest | ResponsesRequest
|
||||
) -> ChatCompletionRequest | ResponsesRequest:
|
||||
request = super().adjust_request(request)
|
||||
if (
|
||||
isinstance(request, ChatCompletionRequest)
|
||||
and request.tools
|
||||
and request.tool_choice != "none"
|
||||
):
|
||||
# Don't skip special tokens — <|tool_call> etc. are needed
|
||||
request.skip_special_tokens = False
|
||||
return request
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Delta buffering for multi-token special sequences
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _buffer_delta_text(self, delta_text: str) -> str:
|
||||
"""Buffer incoming delta text to handle multi-token special sequences.
|
||||
|
||||
Accumulates partial tokens that could be the start of
|
||||
``<|tool_call>`` or ``<tool_call|>`` and only flushes them
|
||||
when the complete sequence is recognized or the sequence breaks.
|
||||
|
||||
This prevents partial special tokens (e.g., ``<|tool``) from being
|
||||
emitted prematurely as content text.
|
||||
"""
|
||||
combined = self.buffered_delta_text + delta_text
|
||||
|
||||
# Check if combined ends with a complete special token
|
||||
if combined.endswith(TOOL_CALL_START) or combined.endswith(TOOL_CALL_END):
|
||||
self.buffered_delta_text = ""
|
||||
return combined
|
||||
|
||||
# Check if combined ends with a partial prefix of a special token
|
||||
for tag in [TOOL_CALL_START, TOOL_CALL_END]:
|
||||
for i in range(1, len(tag)):
|
||||
if combined.endswith(tag[:i]):
|
||||
self.buffered_delta_text = combined[-i:]
|
||||
return combined[:-i]
|
||||
|
||||
# No partial match — flush everything
|
||||
self.buffered_delta_text = ""
|
||||
return combined
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Non-streaming extraction
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
if self.tool_call_start_token not in model_output:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
try:
|
||||
matches = self.tool_call_regex.findall(model_output)
|
||||
if not matches:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
tool_calls: list[ToolCall] = []
|
||||
for func_name, args_str in matches:
|
||||
arguments = _parse_gemma4_args(args_str)
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=func_name,
|
||||
arguments=json.dumps(arguments, ensure_ascii=False),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Content = text before first tool call (if any)
|
||||
content_end = model_output.find(self.tool_call_start_token)
|
||||
content = model_output[:content_end].strip() if content_end > 0 else None
|
||||
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content if content else None,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error extracting tool calls from Gemma4 response")
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Streaming extraction — accumulate-then-parse-then-diff
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> DeltaMessage | None:
|
||||
# Buffer delta text to handle multi-token special sequences
|
||||
delta_text = self._buffer_delta_text(delta_text)
|
||||
# Reconstruct current_text after buffering to stay in sync
|
||||
current_text = previous_text + delta_text
|
||||
|
||||
# If no tool call token seen yet, emit as content
|
||||
if self.tool_call_start_token not in current_text:
|
||||
if delta_text:
|
||||
return DeltaMessage(content=delta_text)
|
||||
return None
|
||||
|
||||
try:
|
||||
return self._extract_streaming(
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=delta_text,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error in Gemma4 streaming tool call extraction")
|
||||
return None
|
||||
|
||||
def _extract_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
) -> DeltaMessage | None:
|
||||
"""Tag-counting streaming parser.
|
||||
|
||||
Uses the proven approach from FunctionGemma/Hermes: count start/end
|
||||
tags in previous vs current text to determine phase, then
|
||||
accumulate-parse-diff for arguments.
|
||||
|
||||
Format: ``<|tool_call>call:name{args}<tool_call|>``
|
||||
"""
|
||||
start_count = current_text.count(self.tool_call_start_token)
|
||||
end_count = current_text.count(self.tool_call_end_token)
|
||||
prev_start_count = previous_text.count(self.tool_call_start_token)
|
||||
prev_end_count = previous_text.count(self.tool_call_end_token)
|
||||
|
||||
# Case 1: Not inside any tool call — emit as content
|
||||
if (
|
||||
start_count == end_count
|
||||
and prev_end_count == end_count
|
||||
and self.tool_call_end_token not in delta_text
|
||||
):
|
||||
if delta_text:
|
||||
return DeltaMessage(content=delta_text)
|
||||
return None
|
||||
|
||||
# Case 2: Starting a new tool call
|
||||
if start_count > prev_start_count and start_count > end_count:
|
||||
self.current_tool_id += 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
self.prev_tool_call_arr.append({})
|
||||
logger.debug("Starting new tool call %d", self.current_tool_id)
|
||||
# Don't return yet — fall through to try parsing if there's
|
||||
# content after <|tool_call> in this same delta
|
||||
# (but usually it's just the token itself, so return None)
|
||||
if len(delta_text) <= len(self.tool_call_start_token):
|
||||
return None
|
||||
|
||||
# Case 3: Tool call just ended
|
||||
if end_count > prev_end_count:
|
||||
return self._handle_tool_call_end(current_text)
|
||||
|
||||
# Case 4: In the middle of a tool call — parse partial content
|
||||
if start_count > end_count:
|
||||
return self._handle_tool_call_middle(current_text)
|
||||
|
||||
# Default: generate text outside tool calls
|
||||
if delta_text:
|
||||
text = delta_text.replace(self.tool_call_start_token, "")
|
||||
text = text.replace(self.tool_call_end_token, "")
|
||||
if text:
|
||||
return DeltaMessage(content=text)
|
||||
return None
|
||||
|
||||
def _extract_partial_call(self, current_text: str) -> tuple[str | None, str]:
|
||||
"""Extract function name and raw argument string from partial text.
|
||||
|
||||
Returns (func_name, raw_args_str) or (None, "") if not parseable yet.
|
||||
"""
|
||||
# Get the text after the last <|tool_call> token
|
||||
last_start = current_text.rfind(self.tool_call_start_token)
|
||||
if last_start == -1:
|
||||
return None, ""
|
||||
|
||||
partial_call = current_text[last_start + len(self.tool_call_start_token) :]
|
||||
|
||||
# Strip end token if present
|
||||
if self.tool_call_end_token in partial_call:
|
||||
partial_call = partial_call.split(self.tool_call_end_token)[0]
|
||||
|
||||
# Expect "call:name{args...}" or "call:name{args...}"
|
||||
if not partial_call.startswith("call:"):
|
||||
return None, ""
|
||||
|
||||
func_part = partial_call[5:] # skip "call:"
|
||||
|
||||
if "{" not in func_part:
|
||||
# Still accumulating function name, not ready yet
|
||||
return None, ""
|
||||
|
||||
func_name, _, args_part = func_part.partition("{")
|
||||
func_name = func_name.strip()
|
||||
|
||||
# Strip trailing '}' if present (Gemma4 structural brace)
|
||||
if args_part.endswith("}"):
|
||||
args_part = args_part[:-1]
|
||||
|
||||
return func_name, args_part
|
||||
|
||||
def _handle_tool_call_middle(self, current_text: str) -> DeltaMessage | None:
|
||||
"""Handle streaming when we're inside an active tool call.
|
||||
|
||||
Accumulates the raw Gemma4 arguments, parses them into JSON, and
|
||||
diffs against the previously-streamed JSON to emit only the new
|
||||
fragment.
|
||||
"""
|
||||
func_name, args_part = self._extract_partial_call(current_text)
|
||||
|
||||
if func_name is None:
|
||||
return None
|
||||
|
||||
# Step 1: Send function name (once)
|
||||
if not self.current_tool_name_sent and func_name:
|
||||
self.current_tool_name_sent = True
|
||||
self.prev_tool_call_arr[self.current_tool_id] = {
|
||||
"name": func_name,
|
||||
"arguments": {},
|
||||
}
|
||||
return DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
type="function",
|
||||
id=make_tool_call_id(),
|
||||
function=DeltaFunctionCall(
|
||||
name=func_name,
|
||||
arguments="",
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Step 2: Parse and diff arguments
|
||||
if self.current_tool_name_sent and args_part:
|
||||
return self._emit_argument_diff(args_part)
|
||||
|
||||
return None
|
||||
|
||||
def _handle_tool_call_end(self, current_text: str) -> DeltaMessage | None:
|
||||
"""Handle streaming when a tool call has just completed.
|
||||
|
||||
Performs a final parse of the complete tool call and flushes
|
||||
any remaining un-streamed argument fragments.
|
||||
"""
|
||||
if self.current_tool_id < 0 or self.current_tool_id >= len(
|
||||
self.prev_tool_call_arr
|
||||
):
|
||||
logger.debug(
|
||||
"Tool call end detected but no active tool call (current_tool_id=%d)",
|
||||
self.current_tool_id,
|
||||
)
|
||||
return None
|
||||
|
||||
# Parse the complete tool call using regex for accuracy
|
||||
all_matches = self.tool_call_regex.findall(current_text)
|
||||
if self.current_tool_id < len(all_matches):
|
||||
_, args_str = all_matches[self.current_tool_id]
|
||||
final_args = _parse_gemma4_args(args_str)
|
||||
final_args_json = json.dumps(final_args, ensure_ascii=False)
|
||||
|
||||
prev_streamed = self.streamed_args_for_tool[self.current_tool_id]
|
||||
if len(final_args_json) > len(prev_streamed):
|
||||
diff = final_args_json[len(prev_streamed) :]
|
||||
self.streamed_args_for_tool[self.current_tool_id] = final_args_json
|
||||
self.prev_tool_call_arr[self.current_tool_id]["arguments"] = final_args
|
||||
|
||||
return DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(arguments=diff).model_dump(
|
||||
exclude_none=True
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def _emit_argument_diff(self, raw_args_str: str) -> DeltaMessage | None:
|
||||
"""Parse raw Gemma4 arguments, convert to JSON, diff, and emit.
|
||||
|
||||
This is the core of the accumulate-then-parse-then-diff strategy:
|
||||
1. Parse ``raw_args_str`` with ``_parse_gemma4_args()``
|
||||
2. Convert to JSON string with ``json.dumps()``
|
||||
3. Withhold trailing closing characters (``"}``) that may move
|
||||
as more tokens arrive
|
||||
4. Diff against previously streamed JSON and emit only new chars
|
||||
|
||||
**Why withholding is necessary:**
|
||||
|
||||
Gemma4's custom format produces *structurally incomplete* JSON
|
||||
during streaming. For example, when ``<|"|>Paris`` arrives
|
||||
without a closing delimiter, ``_parse_gemma4_args`` treats it
|
||||
as a complete value and produces ``{"location": "Paris"}``. But
|
||||
when ``, France<|"|>`` arrives next, the JSON becomes
|
||||
``{"location": "Paris, France"}``. If we had sent the closing
|
||||
``"}`` from the first parse, the concatenated client output
|
||||
would be ``{"location": "Paris"}France"}``, which is garbage.
|
||||
|
||||
The solution: **never send trailing closing chars during
|
||||
streaming**. They get flushed by ``_handle_tool_call_end()``
|
||||
when the ``<tool_call|>`` end marker arrives.
|
||||
|
||||
Args:
|
||||
raw_args_str: The raw Gemma4 argument text accumulated so far
|
||||
(without the surrounding ``{`` ``}``).
|
||||
|
||||
Returns:
|
||||
DeltaMessage with the argument diff, or None if no new content.
|
||||
"""
|
||||
try:
|
||||
current_args = _parse_gemma4_args(raw_args_str)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Could not parse partial Gemma4 args yet: %s",
|
||||
raw_args_str[:100],
|
||||
)
|
||||
return None
|
||||
|
||||
if not current_args:
|
||||
return None
|
||||
|
||||
current_args_json = json.dumps(current_args, ensure_ascii=False)
|
||||
|
||||
# Withhold trailing closing characters that may shift as more
|
||||
# tokens arrive. Strip trailing '}', '"', and ']' sequences
|
||||
# to get the "safe prefix".
|
||||
safe_json = current_args_json
|
||||
while safe_json and safe_json[-1] in ("}", '"', "]"):
|
||||
safe_json = safe_json[:-1]
|
||||
|
||||
prev_streamed = self.streamed_args_for_tool[self.current_tool_id]
|
||||
|
||||
if not safe_json or safe_json == prev_streamed:
|
||||
return None
|
||||
|
||||
# Use find_common_prefix to handle cases where the value changed
|
||||
# structurally (e.g., a string grew).
|
||||
if prev_streamed:
|
||||
prefix = find_common_prefix(prev_streamed, safe_json)
|
||||
sent_len = len(prev_streamed)
|
||||
prefix_len = len(prefix)
|
||||
|
||||
if prefix_len < sent_len:
|
||||
# Structure changed — we sent too much. Truncate our
|
||||
# tracking to the common prefix and wait for the final
|
||||
# flush in _handle_tool_call_end.
|
||||
self.streamed_args_for_tool[self.current_tool_id] = prefix
|
||||
return None
|
||||
|
||||
# Stream the new stable portion
|
||||
diff = safe_json[sent_len:]
|
||||
else:
|
||||
# First emission
|
||||
diff = safe_json
|
||||
|
||||
if diff:
|
||||
self.streamed_args_for_tool[self.current_tool_id] = safe_json
|
||||
self.prev_tool_call_arr[self.current_tool_id]["arguments"] = current_args
|
||||
|
||||
return DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(arguments=diff).model_dump(
|
||||
exclude_none=True
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
return None
|
||||
183
vllm/tool_parsers/gemma4_utils.py
Normal file
183
vllm/tool_parsers/gemma4_utils.py
Normal file
@@ -0,0 +1,183 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||
|
||||
"""Gemma4 tool call parsing utilities for offline inference.
|
||||
|
||||
Standalone functions that parse decoded model text to extract tool calls
|
||||
from Gemma4 models. These are pure-Python utilities with zero heavy
|
||||
dependencies — they work on raw decoded strings from any inference
|
||||
backend (vLLM, HuggingFace, TGI, etc.).
|
||||
|
||||
For the OpenAI-compatible API server tool parser (streaming +
|
||||
non-streaming), see ``vllm.tool_parsers.gemma4_tool_parser``.
|
||||
For thinking/reasoning output parsing, see
|
||||
``vllm.reasoning.gemma4_utils``.
|
||||
|
||||
Usage with vLLM offline inference::
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.tool_parsers.gemma4_utils import (
|
||||
parse_tool_calls,
|
||||
has_tool_response_tag,
|
||||
)
|
||||
|
||||
llm = LLM(model="google/gemma-4-it")
|
||||
outputs = llm.generate(prompt, SamplingParams(...))
|
||||
text = tokenizer.decode(outputs[0].outputs[0].token_ids, skip_special_tokens=False)
|
||||
|
||||
# Extract tool calls
|
||||
tool_calls = parse_tool_calls(text)
|
||||
for tc in tool_calls:
|
||||
print(f"{tc['name']}({tc['arguments']})")
|
||||
|
||||
Ported from ``transformers.models.gemma4.utils_gemma4`` so that vLLM users
|
||||
do not need a transformers dependency for output parsing.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import regex as re
|
||||
|
||||
# Tool call delimiter tokens as they appear in decoded text.
|
||||
# Standard format: <|tool_call>call:name{args}<tool_call|>
|
||||
_TOOL_CALL_START_TAG = "<|tool_call>"
|
||||
_TOOL_CALL_END_TAG = "<tool_call|>"
|
||||
_TOOL_RESPONSE_START_TAG = "<|tool_response>"
|
||||
|
||||
# Gemma4 escape token as it appears in decoded text.
|
||||
_ESCAPE_TOKEN = '<|"|>'
|
||||
|
||||
|
||||
def _parse_tool_arguments(args_str: str) -> dict[str, str]:
|
||||
"""Parse tool call arguments from the Gemma4 compact format.
|
||||
|
||||
Handles the ``key:<|"|>value<|"|>`` format used by Gemma4, with fallback
|
||||
to heuristic key-value extraction. Also tolerates the slightly different
|
||||
``key: "value"`` format (space + plain quotes) that some chat templates
|
||||
produce.
|
||||
|
||||
Args:
|
||||
args_str: Raw argument string from inside ``call:name{...}``.
|
||||
|
||||
Returns:
|
||||
Dictionary of argument name → value.
|
||||
"""
|
||||
if not args_str or not args_str.strip():
|
||||
return {}
|
||||
|
||||
# Replace Gemma4 escape tokens with standard quotes.
|
||||
cleaned = args_str.replace(_ESCAPE_TOKEN, '"')
|
||||
|
||||
# Try JSON parsing first (handles nested values, arrays, etc.).
|
||||
try:
|
||||
parsed = json.loads("{" + cleaned + "}")
|
||||
# Ensure all values are strings for consistency.
|
||||
return {k: str(v) if not isinstance(v, str) else v for k, v in parsed.items()}
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
# Fallback: extract key:"value" pairs (allow optional space after colon).
|
||||
arguments = {}
|
||||
for key, value in re.findall(r'(\w+):\s*"([^"]*)"', cleaned):
|
||||
arguments[key] = value
|
||||
|
||||
if not arguments:
|
||||
# Last resort: extract key:value pairs (unquoted).
|
||||
for key, value in re.findall(r"(\w+):\s*([^,}]+)", args_str):
|
||||
arguments[key] = value.strip().strip('"').replace(_ESCAPE_TOKEN, "")
|
||||
|
||||
return arguments
|
||||
|
||||
|
||||
def parse_tool_calls(text: str, *, strict: bool = False) -> list[dict]:
|
||||
"""Parse tool calls from decoded Gemma4 model output.
|
||||
|
||||
Uses a tiered parsing strategy to handle known output variations in
|
||||
Gemma4 models, which may emit
|
||||
non-standard tool call formats.
|
||||
|
||||
Parsing tiers:
|
||||
1. **Standard**: ``<|tool_call>call:name{args}<tool_call|>``
|
||||
(special token IDs 48/49 in decoded text)
|
||||
2. **Fallback** (when ``strict=False``): bare ``call:name{args}``
|
||||
patterns, including ``<call>name{args}`` (fragmented tokens from
|
||||
multimodal inputs)
|
||||
|
||||
Args:
|
||||
text: Decoded model output text (from ``tokenizer.decode(...,
|
||||
skip_special_tokens=False)``).
|
||||
strict: If ``True``, only match the standard ``<|tool_call>`` format.
|
||||
If ``False`` (default), also try fallback patterns for
|
||||
known Gemma4 output variations.
|
||||
|
||||
Returns:
|
||||
A list of dicts, each with keys:
|
||||
- ``"name"``: The tool function name (e.g. ``"get_weather"``).
|
||||
- ``"arguments"``: A dict of argument name → value.
|
||||
|
||||
Example::
|
||||
|
||||
>>> from vllm.tool_parsers.gemma4_utils import parse_tool_calls
|
||||
>>> output = tokenizer.decode(outputs[0], skip_special_tokens=False)
|
||||
>>> tool_calls = parse_tool_calls(output)
|
||||
>>> for tc in tool_calls:
|
||||
... print(f"Call: {tc['name']}({tc['arguments']})")
|
||||
"""
|
||||
results = []
|
||||
|
||||
# Tier 1: Standard format with special tokens.
|
||||
# <|tool_call>call:name{args}<tool_call|>
|
||||
# Note: Some Gemma4 models emit <turn|> instead of <tool_call|>.
|
||||
standard_pattern = r"<\|tool_call\>call:(\w+)\{(.*?)\}(?:<tool_call\|>|<turn\|>)"
|
||||
for match in re.finditer(standard_pattern, text, re.DOTALL):
|
||||
name, args_str = match.group(1), match.group(2)
|
||||
results.append(
|
||||
{
|
||||
"name": name,
|
||||
"arguments": _parse_tool_arguments(args_str),
|
||||
}
|
||||
)
|
||||
|
||||
if results or strict:
|
||||
return results
|
||||
|
||||
# Tier 2: Fallback for known Gemma4 output variations.
|
||||
# Matches: <call>name{args}, call:name{args}, or bare call:name{args}<eos>
|
||||
fallback_pattern = r"(?:<call>|(?:^|\s)call:)(\w+)\{(.*?)\}"
|
||||
for match in re.finditer(fallback_pattern, text, re.DOTALL):
|
||||
name, args_str = match.group(1), match.group(2)
|
||||
results.append(
|
||||
{
|
||||
"name": name,
|
||||
"arguments": _parse_tool_arguments(args_str),
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def has_tool_response_tag(text: str) -> bool:
|
||||
"""Check if model output properly ends with a tool response tag.
|
||||
|
||||
Some Gemma4 models sometimes emit ``<eos>`` instead of
|
||||
``<|tool_response>`` after a tool call. This helper detects
|
||||
whether the model used the proper termination, so callers can
|
||||
decide whether to inject ``<|tool_response>`` into the next prompt.
|
||||
|
||||
Args:
|
||||
text: Decoded model output text.
|
||||
|
||||
Returns:
|
||||
``True`` if the output ends with ``<|tool_response>``
|
||||
(proper behavior), ``False`` otherwise.
|
||||
|
||||
Example::
|
||||
|
||||
>>> from vllm.tool_parsers.gemma4_utils import has_tool_response_tag
|
||||
>>> if not has_tool_response_tag(model_output):
|
||||
... # Model used <eos> instead — inject <|tool_response> manually
|
||||
... next_prompt = "<|tool_response>" + tool_result
|
||||
"""
|
||||
stripped = text.rstrip()
|
||||
return stripped.endswith(_TOOL_RESPONSE_START_TAG)
|
||||
@@ -448,6 +448,16 @@ class LongCatFlashMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
|
||||
return getattr(self.hf_text_config, "num_nextn_predict_layers", 1)
|
||||
|
||||
|
||||
class Gemma4ModelArchConfigConvertor(ModelArchConfigConvertorBase):
|
||||
def get_head_size(self) -> int:
|
||||
# Gemma4 uses dual head dimensions: head_dim (sliding attention)
|
||||
# and global_head_dim (full attention). Return the largest so
|
||||
# that attention backends allocate buffers large enough for both.
|
||||
head_dim = getattr(self.hf_text_config, "head_dim", 0)
|
||||
global_head_dim = getattr(self.hf_text_config, "global_head_dim", 0)
|
||||
return max(head_dim, global_head_dim) or super().get_head_size()
|
||||
|
||||
|
||||
# hf_config.model_type -> convertor class
|
||||
MODEL_ARCH_CONFIG_CONVERTORS = {
|
||||
"cohere_asr": CohereAsrModelArchConfigConvertor,
|
||||
@@ -471,4 +481,6 @@ MODEL_ARCH_CONFIG_CONVERTORS = {
|
||||
"ernie_mtp": ErnieMTPModelArchConfigConvertor,
|
||||
"pangu_ultra_moe_mtp": PanguUltraMoeMTPModelArchConfigConvertor,
|
||||
"longcat_flash_mtp": LongCatFlashMTPModelArchConfigConvertor,
|
||||
"gemma4": Gemma4ModelArchConfigConvertor,
|
||||
"gemma4_text": Gemma4ModelArchConfigConvertor,
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
@@ -22,7 +23,6 @@ from vllm.v1.attention.backend import (
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
split_decodes_and_prefills,
|
||||
split_prefill_chunks,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.cp_utils import get_total_cp_world_size
|
||||
@@ -30,6 +30,55 @@ from vllm.v1.worker.cp_utils import get_total_cp_world_size
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def split_indexer_prefill_chunks(
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
query_lens_cpu: torch.Tensor,
|
||||
workspace_size: int,
|
||||
max_logits_bytes: int,
|
||||
request_offset: int = 0,
|
||||
) -> list[tuple[slice, slice]]:
|
||||
"""
|
||||
Split prefill requests into chunks for the sparse indexer, respecting:
|
||||
- N constraint: total_seq_lens <= workspace_size (existing O(N) workspace)
|
||||
- Logits constraint: M * N * 4 <= max_logits_bytes
|
||||
|
||||
When a single request-level chunk still exceeds the logits budget,
|
||||
sub-chunks on the query dimension (M) to bound peak memory.
|
||||
|
||||
Returns list of (req_slice, query_slice) tuples.
|
||||
"""
|
||||
chunks: list[tuple[slice, slice]] = []
|
||||
n = len(seq_lens_cpu)
|
||||
max_logits_elems = max_logits_bytes // 4
|
||||
end = 0
|
||||
|
||||
while end < n:
|
||||
start, chunk_m, chunk_n = end, 0, 0
|
||||
|
||||
while end < n:
|
||||
q, s = query_lens_cpu[end].item(), seq_lens_cpu[end].item()
|
||||
new_m, new_n = chunk_m + q, chunk_n + s
|
||||
if new_n <= workspace_size and new_m * new_n <= max_logits_elems:
|
||||
chunk_m, chunk_n = new_m, new_n
|
||||
end += 1
|
||||
else:
|
||||
break
|
||||
|
||||
# A single request can exceed the budget, requiring sub-chunking
|
||||
# on the query dimension.
|
||||
if end == start:
|
||||
chunk_m, chunk_n = query_lens_cpu[end].item(), seq_lens_cpu[end].item()
|
||||
end += 1
|
||||
|
||||
req_slice = slice(start + request_offset, end + request_offset)
|
||||
max_q = max(1, max_logits_elems // chunk_n) if chunk_n > 0 else chunk_m
|
||||
for q_off in range(0, chunk_m, max_q):
|
||||
sub_m = min(max_q, chunk_m - q_off)
|
||||
chunks.append((req_slice, slice(q_off, q_off + sub_m)))
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
class DeepseekV32IndexerBackend(AttentionBackend):
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
@@ -81,6 +130,7 @@ class DeepseekV32IndexerPrefillChunkMetadata:
|
||||
token_start: int
|
||||
token_end: int
|
||||
num_reqs: int
|
||||
skip_kv_gather: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -271,43 +321,51 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
||||
)
|
||||
|
||||
def build_one_prefill_chunk(
|
||||
self, reqs_start, reqs_end, query_start_loc_cpu, seq_lens_cpu, block_table
|
||||
):
|
||||
self,
|
||||
req_slice: slice,
|
||||
query_slice: slice,
|
||||
query_start_loc_cpu,
|
||||
seq_lens_cpu,
|
||||
block_table,
|
||||
skip_kv_gather: bool = False,
|
||||
) -> DeepseekV32IndexerPrefillChunkMetadata:
|
||||
prefill_query_start_loc = (
|
||||
query_start_loc_cpu[reqs_start : reqs_end + 1]
|
||||
- query_start_loc_cpu[reqs_start]
|
||||
query_start_loc_cpu[req_slice.start : req_slice.stop + 1]
|
||||
- query_start_loc_cpu[req_slice.start]
|
||||
)
|
||||
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
|
||||
prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end], self.device
|
||||
prefill_query_start_loc, seq_lens_cpu[req_slice], self.device
|
||||
)
|
||||
token_start = query_start_loc_cpu[req_slice.start].item()
|
||||
total_seq_lens = seq_lens_cpu[req_slice].sum()
|
||||
num_reqs = req_slice.stop - req_slice.start
|
||||
seq_idx = torch.arange(0, num_reqs, dtype=torch.int32)
|
||||
token_to_seq = torch.repeat_interleave(seq_idx, seq_lens_cpu[req_slice]).to(
|
||||
self.device
|
||||
)
|
||||
token_start = query_start_loc_cpu[reqs_start].item()
|
||||
token_end = query_start_loc_cpu[reqs_end].item()
|
||||
total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
|
||||
seq_idx = torch.arange(0, reqs_end - reqs_start, dtype=torch.int32)
|
||||
token_to_seq = torch.repeat_interleave(
|
||||
seq_idx, seq_lens_cpu[reqs_start:reqs_end]
|
||||
).to(self.device)
|
||||
assert total_seq_lens <= self.max_prefill_buffer_size
|
||||
cu_seq_lens = (
|
||||
torch.cat(
|
||||
[
|
||||
torch.zeros(1, dtype=torch.int32),
|
||||
seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0),
|
||||
seq_lens_cpu[req_slice].cumsum(dim=0),
|
||||
]
|
||||
)
|
||||
.to(torch.int32)
|
||||
.to(self.device)
|
||||
)
|
||||
|
||||
return DeepseekV32IndexerPrefillChunkMetadata(
|
||||
cu_seqlen_ks=cu_seqlen_ks,
|
||||
cu_seqlen_ke=cu_seqlen_ke,
|
||||
cu_seqlen_ks=cu_seqlen_ks[query_slice],
|
||||
cu_seqlen_ke=cu_seqlen_ke[query_slice],
|
||||
cu_seq_lens=cu_seq_lens,
|
||||
token_to_seq=token_to_seq,
|
||||
total_seq_lens=total_seq_lens,
|
||||
block_table=block_table[reqs_start:reqs_end],
|
||||
token_start=token_start,
|
||||
token_end=token_end,
|
||||
num_reqs=reqs_end - reqs_start,
|
||||
block_table=block_table[req_slice],
|
||||
token_start=token_start + query_slice.start,
|
||||
token_end=token_start + query_slice.stop,
|
||||
num_reqs=num_reqs,
|
||||
skip_kv_gather=skip_kv_gather,
|
||||
)
|
||||
|
||||
def build(
|
||||
@@ -333,20 +391,27 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
||||
|
||||
prefill_metadata = None
|
||||
if num_prefills > 0:
|
||||
chunk_seq_ids = split_prefill_chunks(
|
||||
prefill_query_lens_cpu = torch.diff(
|
||||
query_start_loc_cpu[num_decodes : num_decodes + num_prefills + 1]
|
||||
)
|
||||
max_logits_bytes = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024
|
||||
chunk_specs = split_indexer_prefill_chunks(
|
||||
common_attn_metadata.seq_lens_cpu[num_decodes:],
|
||||
prefill_query_lens_cpu,
|
||||
self.max_prefill_buffer_size,
|
||||
max_logits_bytes,
|
||||
request_offset=num_decodes,
|
||||
)
|
||||
chunks = [
|
||||
self.build_one_prefill_chunk(
|
||||
reqs_start,
|
||||
reqs_end,
|
||||
req_slice,
|
||||
query_slice,
|
||||
query_start_loc_cpu,
|
||||
common_attn_metadata.seq_lens_cpu,
|
||||
common_attn_metadata.block_table_tensor,
|
||||
skip_kv_gather=query_slice.start > 0,
|
||||
)
|
||||
for reqs_start, reqs_end in chunk_seq_ids
|
||||
for req_slice, query_slice in chunk_specs
|
||||
]
|
||||
prefill_metadata = DeepseekV32IndexerPrefillMetadata(
|
||||
chunks=chunks,
|
||||
|
||||
@@ -234,6 +234,13 @@ class Scheduler(SchedulerInterface):
|
||||
hash_block_size=self.block_size,
|
||||
metrics_collector=self.kv_metrics_collector,
|
||||
)
|
||||
# Bind GPU block pool to the KV connector. This must happen after
|
||||
# kv_cache_manager is constructed so block_pool is available.
|
||||
if self.connector is not None and hasattr(
|
||||
self.connector, "bind_gpu_block_pool"
|
||||
):
|
||||
self.connector.bind_gpu_block_pool(self.kv_cache_manager.block_pool)
|
||||
|
||||
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
|
||||
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
|
||||
self.scheduler_reserve_full_isl = (
|
||||
|
||||
@@ -281,8 +281,16 @@ class PromptTokenStats:
|
||||
|
||||
self.computed += prompt_len - num_cached_tokens
|
||||
self.external_kv_transfer += num_external_computed_tokens
|
||||
self.local_cache_hit += (
|
||||
num_cached_tokens + recomputed - num_external_computed_tokens
|
||||
# FIXME(yifan): local_cache_hit can go negative after preemption.
|
||||
# num_cached_tokens is a one-time snapshot from first scheduling and
|
||||
# is never reset on preemption, while num_external_computed_tokens is
|
||||
# overwritten on re-scheduling. If CPU offload finds more tokens on
|
||||
# the second pass than the original total, the subtraction underflows.
|
||||
# A fundamental fix is to track the first-time num_external_computed_tokens
|
||||
# as a separate metric rather than reusing num_external_computed_tokens
|
||||
# for metric directly.
|
||||
self.local_cache_hit += max(
|
||||
0, (num_cached_tokens + recomputed - num_external_computed_tokens)
|
||||
)
|
||||
self.cached_tokens += num_cached_tokens
|
||||
self.recomputed_tokens += recomputed
|
||||
|
||||
@@ -303,10 +303,12 @@ class ThinkingTokenBudgetLogitsProcessor(LogitsProcessor):
|
||||
# Check if thinking is enabled
|
||||
self.is_enabled = reasoning_config is not None
|
||||
|
||||
self.think_start_token_ids = getattr(
|
||||
reasoning_config, "think_start_token_ids", []
|
||||
self.reasoning_start_token_ids = getattr(
|
||||
reasoning_config, "reasoning_start_token_ids", []
|
||||
)
|
||||
self.reasoning_end_token_ids = getattr(
|
||||
reasoning_config, "reasoning_end_token_ids", []
|
||||
)
|
||||
self.think_end_token_ids = getattr(reasoning_config, "think_end_token_ids", [])
|
||||
|
||||
self.pin_memory = is_pin_memory
|
||||
self.device = device
|
||||
@@ -357,15 +359,15 @@ class ThinkingTokenBudgetLogitsProcessor(LogitsProcessor):
|
||||
think_count = 0
|
||||
else:
|
||||
last_start = self._find_last_sequence_index(
|
||||
prompt_tok_ids, self.think_start_token_ids
|
||||
prompt_tok_ids, self.reasoning_start_token_ids
|
||||
)
|
||||
last_end = self._find_last_sequence_index(
|
||||
prompt_tok_ids, self.think_end_token_ids
|
||||
prompt_tok_ids, self.reasoning_end_token_ids
|
||||
)
|
||||
in_think = last_start > last_end
|
||||
if in_think:
|
||||
think_count = len(prompt_tok_ids) - (
|
||||
last_start + len(self.think_start_token_ids)
|
||||
last_start + len(self.reasoning_start_token_ids)
|
||||
)
|
||||
else:
|
||||
think_count = 0
|
||||
@@ -405,8 +407,8 @@ class ThinkingTokenBudgetLogitsProcessor(LogitsProcessor):
|
||||
state["prev_output_length"] = current_length
|
||||
|
||||
# Check if new tokens contain think start or end sequences
|
||||
start_len = len(self.think_start_token_ids)
|
||||
end_len = len(self.think_end_token_ids)
|
||||
start_len = len(self.reasoning_start_token_ids)
|
||||
end_len = len(self.reasoning_end_token_ids)
|
||||
|
||||
# Look for think sequences in recent tokens (including boundary)
|
||||
# Check overlapping regions where sequences might span boundaries
|
||||
@@ -415,10 +417,10 @@ class ThinkingTokenBudgetLogitsProcessor(LogitsProcessor):
|
||||
|
||||
# Find any think start/end sequences in recent tokens
|
||||
recent_start_pos = self._find_last_sequence_index(
|
||||
recent_tokens, self.think_start_token_ids
|
||||
recent_tokens, self.reasoning_start_token_ids
|
||||
)
|
||||
recent_end_pos = self._find_last_sequence_index(
|
||||
recent_tokens, self.think_end_token_ids
|
||||
recent_tokens, self.reasoning_end_token_ids
|
||||
)
|
||||
|
||||
# Update state based on recent sequences
|
||||
@@ -469,7 +471,7 @@ class ThinkingTokenBudgetLogitsProcessor(LogitsProcessor):
|
||||
else:
|
||||
# In end mode
|
||||
state["end_count"] += 1
|
||||
if state["end_count"] >= len(self.think_end_token_ids):
|
||||
if state["end_count"] >= len(self.reasoning_end_token_ids):
|
||||
state.update(
|
||||
{
|
||||
"in_end": False,
|
||||
@@ -530,7 +532,9 @@ class ThinkingTokenBudgetLogitsProcessor(LogitsProcessor):
|
||||
state = self._state.get(i)
|
||||
if state and state["in_end"]:
|
||||
self.mask[i] = True
|
||||
self.force_token_ids[i] = self.think_end_token_ids[state["end_count"]]
|
||||
self.force_token_ids[i] = self.reasoning_end_token_ids[
|
||||
state["end_count"]
|
||||
]
|
||||
|
||||
# Check in CPU first not to sync with GPU
|
||||
has_active_thinking = any(
|
||||
|
||||
0
vllm/v1/simple_kv_offload/__init__.py
Normal file
0
vllm/v1/simple_kv_offload/__init__.py
Normal file
97
vllm/v1/simple_kv_offload/copy_backend.py
Normal file
97
vllm/v1/simple_kv_offload/copy_backend.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""DMA copy backend for GPU<->CPU block transfers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import queue
|
||||
import threading
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.simple_kv_offload.cuda_mem_ops import (
|
||||
BatchMemcpyParams,
|
||||
build_params,
|
||||
copy_blocks,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DmaCopyBackend:
|
||||
"""cuMemcpyBatchAsync copy backend (background thread)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._store_params: BatchMemcpyParams | None = None
|
||||
self._load_params: BatchMemcpyParams | None = None
|
||||
self._load_stream: torch.cuda.Stream | None = None
|
||||
self._store_stream: torch.cuda.Stream | None = None
|
||||
self._queue: queue.SimpleQueue | None = None
|
||||
self._thread: threading.Thread | None = None
|
||||
self._shutdown: bool = False
|
||||
|
||||
def init(
|
||||
self,
|
||||
gpu_caches: dict[str, torch.Tensor],
|
||||
cpu_caches: dict[str, torch.Tensor],
|
||||
device: torch.device,
|
||||
load_stream: torch.cuda.Stream,
|
||||
store_stream: torch.cuda.Stream,
|
||||
) -> None:
|
||||
self._load_stream = load_stream
|
||||
self._store_stream = store_stream
|
||||
|
||||
self._store_params = build_params(gpu_caches, cpu_caches, store_stream)
|
||||
self._load_params = build_params(cpu_caches, gpu_caches, load_stream)
|
||||
|
||||
self._queue = queue.SimpleQueue()
|
||||
self._thread = threading.Thread(
|
||||
target=self._copy_loop,
|
||||
args=(self._queue, device, load_stream, store_stream),
|
||||
daemon=True,
|
||||
)
|
||||
self._thread.start()
|
||||
|
||||
def launch_copy(
|
||||
self,
|
||||
src_blocks: list[int],
|
||||
dst_blocks: list[int],
|
||||
is_store: bool,
|
||||
event_idx: int,
|
||||
events_list: list[tuple[int, torch.Event]],
|
||||
) -> None:
|
||||
params = self._store_params if is_store else self._load_params
|
||||
assert params is not None and self._queue is not None
|
||||
self._queue.put(
|
||||
(src_blocks, dst_blocks, params, is_store, event_idx, events_list)
|
||||
)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
if self._shutdown:
|
||||
return
|
||||
self._shutdown = True
|
||||
if self._queue is not None:
|
||||
self._queue.put(None)
|
||||
if self._thread is not None:
|
||||
self._thread.join(timeout=5.0)
|
||||
|
||||
@staticmethod
|
||||
def _copy_loop(
|
||||
q: queue.SimpleQueue,
|
||||
device: torch.device,
|
||||
load_stream: torch.cuda.Stream,
|
||||
store_stream: torch.cuda.Stream,
|
||||
) -> None:
|
||||
current_platform.set_device(device)
|
||||
while True:
|
||||
item = q.get()
|
||||
if item is None:
|
||||
return
|
||||
src_blocks, dst_blocks, params, is_store, event_idx, events_list = item
|
||||
copy_blocks(src_blocks, dst_blocks, params)
|
||||
stream = store_stream if is_store else load_stream
|
||||
event = torch.Event()
|
||||
event.record(stream)
|
||||
events_list.append((event_idx, event))
|
||||
153
vllm/v1/simple_kv_offload/cuda_mem_ops.py
Normal file
153
vllm/v1/simple_kv_offload/cuda_mem_ops.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Low-level CUDA memory helpers: pinning and batch DMA transfers."""
|
||||
|
||||
import ctypes
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def pin_tensor(tensor: torch.Tensor) -> None:
|
||||
"""Pin a CPU tensor via cudaHostRegister.
|
||||
|
||||
This bypasses PyTorch's CUDACachingHostAllocator which rounds
|
||||
every ``pin_memory=True`` allocation up to the next power of 2
|
||||
(e.g. 100 GB becomes 128 GB).
|
||||
"""
|
||||
err = torch.cuda.cudart().cudaHostRegister(tensor.data_ptr(), tensor.nbytes, 0)
|
||||
if err.value != 0:
|
||||
raise RuntimeError(f"cudaHostRegister failed: {err}")
|
||||
|
||||
|
||||
class _CUmemLocation(ctypes.Structure):
|
||||
_fields_ = [("type", ctypes.c_uint), ("id", ctypes.c_int)]
|
||||
|
||||
|
||||
class _CUmemcpyAttributes(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("srcAccessOrder", ctypes.c_uint),
|
||||
("srcLocHint", _CUmemLocation),
|
||||
("dstLocHint", _CUmemLocation),
|
||||
("flags", ctypes.c_uint),
|
||||
]
|
||||
|
||||
|
||||
_BATCH_MEMCPY_FUNC_TYPE = ctypes.CFUNCTYPE(
|
||||
ctypes.c_uint, # CUresult
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_size_t,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_size_t,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
)
|
||||
|
||||
# Resolved lazily on first use.
|
||||
_batch_memcpy_fn: Any = None
|
||||
|
||||
|
||||
def _resolve_batch_memcpy():
|
||||
"""Resolve cuMemcpyBatchAsync via cuGetProcAddress (one-time)."""
|
||||
from cuda.bindings import driver as drv
|
||||
|
||||
err, ptr, _ = drv.cuGetProcAddress(b"cuMemcpyBatchAsync", 12080, 0)
|
||||
if err != drv.CUresult.CUDA_SUCCESS:
|
||||
raise RuntimeError(f"cuGetProcAddress(cuMemcpyBatchAsync) failed: {err}")
|
||||
return _BATCH_MEMCPY_FUNC_TYPE(ptr)
|
||||
|
||||
|
||||
class BatchMemcpyParams(NamedTuple):
|
||||
src_bases: np.ndarray # [num_layers] uint64 — data_ptr per layer
|
||||
dst_bases: np.ndarray # [num_layers] uint64
|
||||
bpb: np.ndarray # [num_layers] uint64 — bytes per block
|
||||
num_layers: int
|
||||
attrs: _CUmemcpyAttributes
|
||||
attrs_idx: ctypes.c_size_t
|
||||
# NOTE: cuMemcpyBatchAsync_v2() removed fail_idx field, but we use
|
||||
# cuMemcpyBatchAsync() with fail_idx for backward compatibility
|
||||
fail_idx: ctypes.c_size_t
|
||||
stream_handle: int # raw cudaStream_t / CUstream
|
||||
|
||||
|
||||
def build_params(
|
||||
src_caches: dict[str, torch.Tensor],
|
||||
dst_caches: dict[str, torch.Tensor],
|
||||
stream: torch.cuda.Stream,
|
||||
) -> BatchMemcpyParams:
|
||||
global _batch_memcpy_fn
|
||||
if _batch_memcpy_fn is None:
|
||||
_batch_memcpy_fn = _resolve_batch_memcpy()
|
||||
|
||||
assert list(src_caches.keys()) == list(dst_caches.keys())
|
||||
src_tensors = list(src_caches.values())
|
||||
dst_tensors = list(dst_caches.values())
|
||||
|
||||
src_bases, dst_bases, bpb = [], [], []
|
||||
for s, d in zip(src_tensors, dst_tensors):
|
||||
s_bpb = s.stride(0) * s.element_size()
|
||||
assert s_bpb == d.stride(0) * d.element_size()
|
||||
src_bases.append(s.data_ptr())
|
||||
dst_bases.append(d.data_ptr())
|
||||
bpb.append(s_bpb)
|
||||
|
||||
# Refer to https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g6f1ff58e3065df3eb4b573dba77ad31f for details. # noqa: E501
|
||||
attrs = _CUmemcpyAttributes(srcAccessOrder=3) # ANY
|
||||
|
||||
return BatchMemcpyParams(
|
||||
src_bases=np.array(src_bases, dtype=np.uint64),
|
||||
dst_bases=np.array(dst_bases, dtype=np.uint64),
|
||||
bpb=np.array(bpb, dtype=np.uint64),
|
||||
num_layers=len(src_tensors),
|
||||
attrs=attrs,
|
||||
attrs_idx=ctypes.c_size_t(0),
|
||||
fail_idx=ctypes.c_size_t(0),
|
||||
stream_handle=stream.cuda_stream,
|
||||
)
|
||||
|
||||
|
||||
def copy_blocks(
|
||||
src_block_ids: list[int],
|
||||
dst_block_ids: list[int],
|
||||
params: BatchMemcpyParams,
|
||||
) -> None:
|
||||
"""Copy blocks via cuMemcpyBatchAsync."""
|
||||
n = len(src_block_ids)
|
||||
if n == 0:
|
||||
return
|
||||
|
||||
src_ids = np.array(src_block_ids, dtype=np.uint64)
|
||||
dst_ids = np.array(dst_block_ids, dtype=np.uint64)
|
||||
|
||||
src_all = (
|
||||
params.src_bases[:, None] + src_ids[None, :] * params.bpb[:, None]
|
||||
).ravel()
|
||||
dst_all = (
|
||||
params.dst_bases[:, None] + dst_ids[None, :] * params.bpb[:, None]
|
||||
).ravel()
|
||||
sz_all = np.repeat(params.bpb, n)
|
||||
|
||||
total = n * params.num_layers
|
||||
err = _batch_memcpy_fn(
|
||||
dst_all.ctypes.data,
|
||||
src_all.ctypes.data,
|
||||
sz_all.ctypes.data,
|
||||
total,
|
||||
ctypes.addressof(params.attrs),
|
||||
ctypes.byref(params.attrs_idx),
|
||||
1,
|
||||
ctypes.byref(params.fail_idx),
|
||||
params.stream_handle,
|
||||
)
|
||||
if err != 0:
|
||||
raise RuntimeError(
|
||||
f"cuMemcpyBatchAsync failed: err={err} failIdx={params.fail_idx.value}"
|
||||
)
|
||||
739
vllm/v1/simple_kv_offload/manager.py
Normal file
739
vllm/v1/simple_kv_offload/manager.py
Normal file
@@ -0,0 +1,739 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Scheduler-side manager for SimpleCPUOffloadConnector."""
|
||||
|
||||
import contextlib
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import KVCacheEvent
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_coordinator import (
|
||||
KVCacheCoordinator,
|
||||
get_kv_cache_coordinator,
|
||||
)
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
FullAttentionSpec,
|
||||
MambaSpec,
|
||||
SlidingWindowSpec,
|
||||
)
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
from vllm.v1.simple_kv_offload.metadata import (
|
||||
SimpleCPUOffloadMetadata,
|
||||
SimpleCPUOffloadWorkerMetadata,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.kv_cache_utils import KVCacheBlock
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransferMeta:
|
||||
gpu_block_ids: list[int]
|
||||
cpu_block_ids: list[int]
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadRequestState:
|
||||
request: "Request"
|
||||
transfer_meta: TransferMeta
|
||||
load_event: int | None = None
|
||||
finished: bool = False
|
||||
|
||||
|
||||
# NOTE: This per-request state is only used in eager mode.
|
||||
@dataclass
|
||||
class StoreRequestState:
|
||||
request: "Request"
|
||||
# Accumulated block IDs from scheduler_output via yield_req_data.
|
||||
block_ids: tuple[list[int], ...]
|
||||
# Per-group cursors tracking how many blocks have been stored/skipped.
|
||||
num_stored_blocks: list[int]
|
||||
store_events: set[int] = field(default_factory=set)
|
||||
finished: bool = False
|
||||
|
||||
|
||||
class SimpleCPUOffloadScheduler:
|
||||
"""Scheduler-side manager for CPU offloading."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_config: "KVCacheConfig | None",
|
||||
cpu_capacity_bytes: int,
|
||||
lazy_offload: bool = False,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.enable_kv_cache_events = (
|
||||
vllm_config.kv_events_config is not None
|
||||
and vllm_config.kv_events_config.enable_kv_cache_events
|
||||
)
|
||||
# NOTE: We use the same block size for both GPU and CPU.
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
# Derive a CPU KVCacheConfig from the GPU config and build a coordinator
|
||||
assert kv_cache_config is not None
|
||||
self.cpu_kv_cache_config = self._derive_cpu_config(
|
||||
kv_cache_config, cpu_capacity_bytes
|
||||
)
|
||||
self.num_cpu_blocks = self.cpu_kv_cache_config.num_blocks
|
||||
# Find the full attention kv group for prefix cache matching.
|
||||
self.fa_gidx = -1
|
||||
for g_idx, g in enumerate(self.cpu_kv_cache_config.kv_cache_groups):
|
||||
if isinstance(g.kv_cache_spec, FullAttentionSpec):
|
||||
self.fa_gidx = g_idx
|
||||
break
|
||||
assert 0 <= self.fa_gidx < len(self.cpu_kv_cache_config.kv_cache_groups)
|
||||
|
||||
logger.info(
|
||||
"SimpleCPUOffloadScheduler: Allocating %d CPU blocks (%.2f GB, mode=%s)",
|
||||
self.num_cpu_blocks,
|
||||
cpu_capacity_bytes / (1024**3),
|
||||
"lazy" if lazy_offload else "eager",
|
||||
)
|
||||
|
||||
# TODO (yifan): maybe need to enable kv_cache_events and metrics_collector here.
|
||||
dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
|
||||
pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size
|
||||
assert dcp_world_size == 1 and pcp_world_size == 1
|
||||
self.cpu_coordinator: KVCacheCoordinator = get_kv_cache_coordinator(
|
||||
kv_cache_config=self.cpu_kv_cache_config,
|
||||
max_model_len=vllm_config.model_config.max_model_len,
|
||||
use_eagle=False,
|
||||
enable_caching=True,
|
||||
enable_kv_cache_events=self.enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size,
|
||||
pcp_world_size=pcp_world_size,
|
||||
hash_block_size=self.block_size,
|
||||
)
|
||||
self.cpu_block_pool: BlockPool = self.cpu_coordinator.block_pool
|
||||
|
||||
# GPU block pool reference - bound after scheduler builds kv_cache_manager
|
||||
self._gpu_block_pool: BlockPool | None = None
|
||||
|
||||
# Load metadata
|
||||
self._reqs_to_load: dict[str, LoadRequestState] = {}
|
||||
# Inverse map: load_event_idx -> req_ids. Keyed by load_event_idx because
|
||||
# the worker reports completions by event index, not request id.
|
||||
self._load_event_to_reqs: dict[int, list[str]] = {}
|
||||
|
||||
# Store metadata
|
||||
self._lazy_mode = lazy_offload
|
||||
# Lazy mode: use a cursor to track the last scanned block in the GPU free queue.
|
||||
self._cursor: KVCacheBlock | None = None
|
||||
if self._lazy_mode:
|
||||
self._target_free = self._estimate_lazy_target_blocks(
|
||||
kv_cache_config,
|
||||
vllm_config.scheduler_config.max_num_batched_tokens,
|
||||
)
|
||||
else:
|
||||
self._target_free = 0
|
||||
self._store_event_to_blocks: dict[int, TransferMeta] = {}
|
||||
# Eager mode only
|
||||
self._reqs_to_store: dict[str, StoreRequestState] = {}
|
||||
self._store_event_to_reqs: dict[int, list[str]] = {}
|
||||
|
||||
# Event counters
|
||||
self._load_event_counter: int = 0
|
||||
self._store_event_counter: int = 0
|
||||
|
||||
# For TP/PP: track partial store completions across steps.
|
||||
# Events must be reported by all world_size workers before considered complete.
|
||||
self._expected_worker_count = vllm_config.parallel_config.world_size
|
||||
self._store_event_pending_counts: dict[int, int] = {}
|
||||
|
||||
@staticmethod
|
||||
def _derive_cpu_config(
|
||||
gpu_config: "KVCacheConfig", cpu_capacity_bytes: int
|
||||
) -> "KVCacheConfig":
|
||||
"""Derive a CPU KVCacheConfig from the GPU config.
|
||||
Same kv_cache_groups, num_blocks scaled by CPU/GPU memory ratio."""
|
||||
# Import here to avoid potential circular imports
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig as KVCacheConfigCls
|
||||
from vllm.v1.kv_cache_interface import KVCacheTensor
|
||||
|
||||
assert len(gpu_config.kv_cache_tensors) > 0
|
||||
|
||||
gpu_total_bytes = sum(t.size for t in gpu_config.kv_cache_tensors)
|
||||
num_gpu_blocks = gpu_config.num_blocks
|
||||
num_cpu_blocks = max(1, num_gpu_blocks * cpu_capacity_bytes // gpu_total_bytes)
|
||||
# Create CPU kv_cache_tensors mirroring GPU by scaling size proportionally.
|
||||
cpu_tensors = [
|
||||
KVCacheTensor(
|
||||
size=t.size // num_gpu_blocks * num_cpu_blocks,
|
||||
shared_by=list(t.shared_by),
|
||||
)
|
||||
for t in gpu_config.kv_cache_tensors
|
||||
]
|
||||
|
||||
return KVCacheConfigCls(
|
||||
num_blocks=num_cpu_blocks,
|
||||
kv_cache_tensors=cpu_tensors,
|
||||
kv_cache_groups=gpu_config.kv_cache_groups,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _estimate_lazy_target_blocks(
|
||||
kv_cache_config: "KVCacheConfig", max_num_batched_tokens: int
|
||||
) -> int:
|
||||
"""GPU blocks to keep available (free/offloaded) per step in lazy mode."""
|
||||
WATERMARK_RATIO = 1.0 # Reserve larger space to avoid running out of GPU blocks
|
||||
target = 0
|
||||
for g in kv_cache_config.kv_cache_groups:
|
||||
spec = g.kv_cache_spec
|
||||
if isinstance(spec, MambaSpec):
|
||||
target += 2
|
||||
elif isinstance(spec, SlidingWindowSpec):
|
||||
target += cdiv(spec.sliding_window, spec.block_size) + 1
|
||||
else:
|
||||
target += cdiv(max_num_batched_tokens, spec.block_size)
|
||||
return int(target * (1 + WATERMARK_RATIO))
|
||||
|
||||
def bind_gpu_block_pool(self, gpu_block_pool: BlockPool) -> None:
|
||||
"""Bind GPU block pool so that we can touch blocks during stores.
|
||||
Called by Scheduler after kv_cache_manager is ready."""
|
||||
self._gpu_block_pool = gpu_block_pool
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request", num_computed_tokens: int
|
||||
) -> tuple[int | None, bool]:
|
||||
"""Return (num_new_tokens, is_async) from consecutive CPU cache hits."""
|
||||
skipped = num_computed_tokens // self.block_size
|
||||
remaining_hashes = request.block_hashes[skipped:]
|
||||
|
||||
if not remaining_hashes:
|
||||
return 0, False
|
||||
# Must recompute at least the last token, matching the logic in
|
||||
# kv_cache_manager.get_computed_blocks().
|
||||
max_hit_len = request.num_tokens - 1 - num_computed_tokens
|
||||
if max_hit_len <= 0:
|
||||
return 0, False
|
||||
_, hit_length = self.cpu_coordinator.find_longest_cache_hit(
|
||||
remaining_hashes, max_hit_len
|
||||
)
|
||||
|
||||
if hit_length > 0:
|
||||
return hit_length, True
|
||||
return 0, False
|
||||
|
||||
# TODO(yifan): this API now only matches the suffix part of the prefix cache. A more
|
||||
# general API should scan blocks in both GPU and CPU block pool in a single pass.
|
||||
def update_state_after_alloc(
|
||||
self,
|
||||
request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int,
|
||||
) -> None:
|
||||
req_id = request.request_id
|
||||
block_ids_by_group = blocks.get_block_ids()
|
||||
num_groups = len(block_ids_by_group)
|
||||
|
||||
# Store tracking (eager mode only). Register the request;
|
||||
# block IDs are accumulated from scheduler_output in
|
||||
# _prepare_eager_store_specs via yield_req_data.
|
||||
if not self._lazy_mode and req_id not in self._reqs_to_store:
|
||||
self._reqs_to_store[req_id] = StoreRequestState(
|
||||
request=request,
|
||||
block_ids=tuple([] for _ in range(num_groups)),
|
||||
num_stored_blocks=[0] * num_groups,
|
||||
)
|
||||
|
||||
if num_external_tokens == 0:
|
||||
return
|
||||
|
||||
num_blocks_to_load = num_external_tokens // self.block_size
|
||||
assert num_blocks_to_load > 0
|
||||
|
||||
skipped = sum(blk.block_hash is not None for blk in blocks.blocks[self.fa_gidx])
|
||||
num_computed_tokens = skipped * self.block_size
|
||||
hashes_to_load = request.block_hashes[skipped : skipped + num_blocks_to_load]
|
||||
|
||||
# Find CPU cached blocks across all groups.
|
||||
max_hit_len = len(hashes_to_load) * self.block_size
|
||||
cpu_hit_blocks, hit_length = self.cpu_coordinator.find_longest_cache_hit(
|
||||
hashes_to_load, max_hit_len
|
||||
)
|
||||
assert hit_length == num_external_tokens, (
|
||||
f"Expected {num_external_tokens} hit tokens, got {hit_length}"
|
||||
)
|
||||
|
||||
# Build transfer pairs across all groups.
|
||||
total_computed_tokens = num_computed_tokens + num_external_tokens
|
||||
kv_cache_groups = self.cpu_kv_cache_config.kv_cache_groups
|
||||
|
||||
gpu_block_ids: list[int] = []
|
||||
cpu_block_ids: list[int] = []
|
||||
cpu_blocks_to_touch: list[KVCacheBlock] = []
|
||||
|
||||
for g in range(num_groups):
|
||||
cpu_blocks_g = cpu_hit_blocks[g]
|
||||
n_ext_g = len(cpu_blocks_g)
|
||||
if n_ext_g == 0:
|
||||
continue
|
||||
|
||||
# Number of blocks in the computed range for this group.
|
||||
g_block_size = kv_cache_groups[g].kv_cache_spec.block_size
|
||||
n_computed_g = cdiv(total_computed_tokens, g_block_size)
|
||||
|
||||
# Back-trace: ext blocks sit at the tail of the computed range.
|
||||
gpu_ext_start = n_computed_g - n_ext_g
|
||||
group_gpu_ids = block_ids_by_group[g]
|
||||
|
||||
for i, cpu_blk in enumerate(cpu_blocks_g):
|
||||
# Skip null blocks (e.g. sliding window or mamba padding).
|
||||
if cpu_blk.is_null:
|
||||
continue
|
||||
gpu_block_ids.append(group_gpu_ids[gpu_ext_start + i])
|
||||
cpu_block_ids.append(cpu_blk.block_id)
|
||||
cpu_blocks_to_touch.append(cpu_blk)
|
||||
|
||||
# Touch CPU blocks to prevent eviction during async load.
|
||||
self.cpu_block_pool.touch(cpu_blocks_to_touch)
|
||||
|
||||
# Touch GPU blocks to prevent freeing during async load
|
||||
assert self._gpu_block_pool is not None
|
||||
self._gpu_block_pool.touch(
|
||||
[self._gpu_block_pool.blocks[bid] for bid in gpu_block_ids]
|
||||
)
|
||||
|
||||
assert self._reqs_to_load.get(req_id) is None
|
||||
self._reqs_to_load[req_id] = LoadRequestState(
|
||||
request=request, transfer_meta=TransferMeta(gpu_block_ids, cpu_block_ids)
|
||||
)
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> SimpleCPUOffloadMetadata:
|
||||
# --- Stores ---
|
||||
store_event = -1
|
||||
store_gpu, store_cpu, store_req_ids = self.prepare_store_specs(scheduler_output)
|
||||
if store_gpu:
|
||||
store_event = self._store_event_counter
|
||||
self._store_event_counter += 1
|
||||
self._store_event_to_blocks[store_event] = TransferMeta(
|
||||
store_gpu, store_cpu
|
||||
)
|
||||
if store_req_ids: # For eager mode only, track req->blocks mapping
|
||||
self._store_event_to_reqs[store_event] = store_req_ids
|
||||
for req_id in store_req_ids:
|
||||
store_state = self._reqs_to_store.get(req_id)
|
||||
if store_state is not None:
|
||||
store_state.store_events.add(store_event)
|
||||
|
||||
# --- Loads ---
|
||||
load_event = -1
|
||||
load_gpu: list[int] = []
|
||||
load_cpu: list[int] = []
|
||||
load_req_ids: list[str] = []
|
||||
for req_id, load_state in self._reqs_to_load.items():
|
||||
if load_state.load_event is not None:
|
||||
continue
|
||||
assert load_state.transfer_meta is not None
|
||||
load_gpu.extend(load_state.transfer_meta.gpu_block_ids)
|
||||
load_cpu.extend(load_state.transfer_meta.cpu_block_ids)
|
||||
load_req_ids.append(req_id)
|
||||
if load_req_ids:
|
||||
load_event = self._load_event_counter
|
||||
self._load_event_counter += 1
|
||||
for req_id in load_req_ids:
|
||||
self._reqs_to_load[req_id].load_event = load_event
|
||||
self._load_event_to_reqs[load_event] = load_req_ids
|
||||
|
||||
result = SimpleCPUOffloadMetadata(
|
||||
load_event=load_event,
|
||||
load_gpu_blocks=load_gpu,
|
||||
load_cpu_blocks=load_cpu,
|
||||
load_event_to_reqs=self._load_event_to_reqs,
|
||||
store_event=store_event,
|
||||
store_gpu_blocks=store_gpu,
|
||||
store_cpu_blocks=store_cpu,
|
||||
need_flush=bool(scheduler_output.preempted_req_ids),
|
||||
)
|
||||
return result
|
||||
|
||||
def prepare_store_specs(
|
||||
self, scheduler_output: SchedulerOutput
|
||||
) -> tuple[list[int], list[int], list[str]]:
|
||||
"""Prepare store specs for the store event."""
|
||||
if self._lazy_mode:
|
||||
return self._prepare_lazy_store_specs()
|
||||
else:
|
||||
return self._prepare_eager_store_specs(scheduler_output)
|
||||
|
||||
def _prepare_lazy_store_specs(
|
||||
self,
|
||||
) -> tuple[list[int], list[int], list[str]]:
|
||||
"""Single-pass cursor walk: offload cached GPU blocks near eviction.
|
||||
|
||||
Walks the GPU free queue from the cursor, counting blocks that are
|
||||
free-or-offloaded (safe for the allocator to evict). Stops when
|
||||
target_free blocks are covered or CPU capacity is reached.
|
||||
"""
|
||||
gpu_pool = self._gpu_block_pool
|
||||
if gpu_pool is None or self._target_free <= 0:
|
||||
return [], [], []
|
||||
|
||||
free_queue = gpu_pool.free_block_queue
|
||||
cpu_pool = self.cpu_block_pool
|
||||
num_cpu_free = cpu_pool.get_num_free_blocks()
|
||||
|
||||
# Validate cursor: stale if block was removed from free queue.
|
||||
if self._cursor is not None and self._cursor.ref_cnt > 0:
|
||||
self._cursor = None
|
||||
|
||||
# Determine start node.
|
||||
if self._cursor is None:
|
||||
node = free_queue.fake_free_list_head.next_free_block
|
||||
else:
|
||||
node = self._cursor.next_free_block
|
||||
|
||||
tail = free_queue.fake_free_list_tail
|
||||
gpu_ids: list[int] = []
|
||||
block_hashes: list[bytes] = []
|
||||
covered = 0
|
||||
last_visited = self._cursor
|
||||
|
||||
while (
|
||||
node is not None
|
||||
and node is not tail
|
||||
and covered < self._target_free
|
||||
and len(gpu_ids) < num_cpu_free
|
||||
):
|
||||
last_visited = node
|
||||
bhash = node.block_hash
|
||||
|
||||
if (
|
||||
bhash is not None
|
||||
and not node.is_null
|
||||
and cpu_pool.cached_block_hash_to_block.get_one_block(bhash) is None
|
||||
):
|
||||
gpu_ids.append(node.block_id)
|
||||
block_hashes.append(bhash)
|
||||
|
||||
covered += 1
|
||||
node = node.next_free_block
|
||||
|
||||
self._cursor = last_visited
|
||||
|
||||
# Batch-allocate CPU blocks and stamp hashes.
|
||||
if gpu_ids:
|
||||
cpu_blocks = cpu_pool.get_new_blocks(len(gpu_ids))
|
||||
cpu_ids = [blk.block_id for blk in cpu_blocks]
|
||||
for cpu_blk, bhash in zip(cpu_blocks, block_hashes): # type: ignore[assignment]
|
||||
cpu_blk._block_hash = bhash # type: ignore[assignment]
|
||||
# Touch GPU blocks to prevent eviction during async copy.
|
||||
gpu_pool.touch([gpu_pool.blocks[bid] for bid in gpu_ids])
|
||||
else:
|
||||
cpu_ids = []
|
||||
|
||||
return gpu_ids, cpu_ids, []
|
||||
|
||||
def _prepare_eager_store_specs(
|
||||
self, scheduler_output: SchedulerOutput
|
||||
) -> tuple[list[int], list[int], list[str]]:
|
||||
"""Identify newly computed blocks to offload from scheduler requests.
|
||||
|
||||
Only considers blocks whose KV data has been **confirmed computed** by
|
||||
the GPU. This means blocks from the current step are NOT stored until the
|
||||
next step. If a request finishes in the same step as its last full block,
|
||||
that block may be missed. (TODO: flush on finish.)
|
||||
|
||||
Returns:
|
||||
(gpu_block_ids, cpu_block_ids, req_ids) for the store event.
|
||||
"""
|
||||
|
||||
merged_gpu_block_ids: list[int] = []
|
||||
merged_cpu_block_ids: list[int] = []
|
||||
req_ids: list[str] = []
|
||||
|
||||
gpu_block_pool = self._gpu_block_pool
|
||||
if gpu_block_pool is None:
|
||||
return [], [], []
|
||||
cpu_block_pool = self.cpu_block_pool
|
||||
num_free = cpu_block_pool.get_num_free_blocks()
|
||||
kv_cache_groups = self.cpu_kv_cache_config.kv_cache_groups
|
||||
num_groups = len(kv_cache_groups)
|
||||
gpu_blocks_this_step: set[int] = set()
|
||||
|
||||
for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output):
|
||||
state = self._reqs_to_store.get(req_id)
|
||||
if state is None or state.finished:
|
||||
continue
|
||||
|
||||
# Accumulate new block IDs.
|
||||
if preempted:
|
||||
state.block_ids = tuple([] for _ in range(num_groups))
|
||||
state.num_stored_blocks = [0] * num_groups
|
||||
if new_block_id_groups:
|
||||
for g in range(min(num_groups, len(new_block_id_groups))):
|
||||
if new_block_id_groups[g] is not None:
|
||||
state.block_ids[g].extend(new_block_id_groups[g])
|
||||
|
||||
num_new_tokens = scheduler_output.num_scheduled_tokens.get(req_id, 0)
|
||||
if num_new_tokens == 0:
|
||||
continue
|
||||
|
||||
block_ids_by_group = state.block_ids
|
||||
if not block_ids_by_group:
|
||||
continue
|
||||
|
||||
# --- Phase 1: Scan blocks, classify as cached vs to-store ---
|
||||
gpu_block_ids: list[int] = []
|
||||
block_hashes_to_store: list[bytes] = []
|
||||
advanced_per_group: list[int] = [0] * num_groups
|
||||
out_of_space = False
|
||||
# Confirmed tokens: KV data written and visible to all streams.
|
||||
req = state.request
|
||||
confirmed_tokens = req.num_computed_tokens - req.num_output_placeholders
|
||||
|
||||
for g in range(num_groups):
|
||||
# FIXME (yifan): handle CPU cache eviction, where
|
||||
# num_stored_blocks can be stale and omit evicted blocks in
|
||||
# the middle of the request.
|
||||
already_stored_g = state.num_stored_blocks[g]
|
||||
group_gpu_ids = block_ids_by_group[g]
|
||||
|
||||
# Cap to blocks with confirmed KV data.
|
||||
g_block_size = kv_cache_groups[g].kv_cache_spec.block_size
|
||||
ready_blocks_g = confirmed_tokens // g_block_size
|
||||
scannable = group_gpu_ids[already_stored_g:ready_blocks_g]
|
||||
|
||||
for gpu_block_id in scannable:
|
||||
gpu_block = gpu_block_pool.blocks[gpu_block_id]
|
||||
if gpu_block.is_null:
|
||||
advanced_per_group[g] += 1
|
||||
continue
|
||||
|
||||
bhash_with_group = gpu_block.block_hash
|
||||
if bhash_with_group is None:
|
||||
break
|
||||
|
||||
# Check if this group's data is already scheduled for store
|
||||
# in this step or already cached in CPU.
|
||||
if (
|
||||
gpu_block_id in gpu_blocks_this_step
|
||||
or cpu_block_pool.cached_block_hash_to_block.get_one_block(
|
||||
bhash_with_group
|
||||
)
|
||||
is not None
|
||||
):
|
||||
advanced_per_group[g] += 1
|
||||
continue
|
||||
|
||||
if num_free <= 0:
|
||||
out_of_space = True
|
||||
break
|
||||
num_free -= 1
|
||||
|
||||
gpu_block_ids.append(gpu_block_id)
|
||||
block_hashes_to_store.append(bhash_with_group)
|
||||
advanced_per_group[g] += 1
|
||||
|
||||
if out_of_space:
|
||||
break
|
||||
|
||||
# --- Phase 2: Batch allocate CPU blocks and stamp hashes ---
|
||||
n_to_alloc = len(gpu_block_ids)
|
||||
if n_to_alloc > 0:
|
||||
cpu_blocks_alloc = cpu_block_pool.get_new_blocks(n_to_alloc)
|
||||
cpu_block_ids = [blk.block_id for blk in cpu_blocks_alloc]
|
||||
for cpu_blk, bhash in zip(cpu_blocks_alloc, block_hashes_to_store):
|
||||
cpu_blk._block_hash = bhash # type: ignore[assignment]
|
||||
else:
|
||||
cpu_block_ids = []
|
||||
|
||||
if cpu_block_ids:
|
||||
req_ids.append(req_id)
|
||||
merged_gpu_block_ids.extend(gpu_block_ids)
|
||||
merged_cpu_block_ids.extend(cpu_block_ids)
|
||||
gpu_blocks_this_step.update(gpu_block_ids)
|
||||
|
||||
# Touch GPU blocks to prevent freeing during async copy
|
||||
gpu_block_pool.touch(
|
||||
[gpu_block_pool.blocks[bid] for bid in gpu_block_ids]
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Request %s: Scheduling store of %d blocks to CPU (%d groups)",
|
||||
req_id,
|
||||
len(cpu_block_ids),
|
||||
num_groups,
|
||||
)
|
||||
|
||||
# Advance per-group cursors (includes cached hits + newly stored)
|
||||
for g in range(num_groups):
|
||||
state.num_stored_blocks[g] += advanced_per_group[g]
|
||||
|
||||
return merged_gpu_block_ids, merged_cpu_block_ids, req_ids
|
||||
|
||||
def update_connector_output(self, connector_output: KVConnectorOutput) -> None:
|
||||
"""Handle async transfer completions from worker.
|
||||
|
||||
Load completions arrive via finished_recving (real req_ids).
|
||||
Store completions arrive via kv_connector_worker_meta as
|
||||
per-event worker counts. We accumulate across steps and process
|
||||
a store event only when all workers have reported completion.
|
||||
"""
|
||||
# --- Load completions ---
|
||||
for req_id in list(connector_output.finished_recving or []):
|
||||
self._cleanup_load_request(req_id)
|
||||
|
||||
# --- Store completions ---
|
||||
meta = connector_output.kv_connector_worker_meta
|
||||
if not isinstance(meta, SimpleCPUOffloadWorkerMetadata):
|
||||
return
|
||||
for event_idx, count in meta.completed_store_events.items():
|
||||
total = self._store_event_pending_counts.get(event_idx, 0) + count
|
||||
if total >= self._expected_worker_count:
|
||||
self._store_event_pending_counts.pop(event_idx, None)
|
||||
self._process_store_event(event_idx)
|
||||
else:
|
||||
self._store_event_pending_counts[event_idx] = total
|
||||
|
||||
def _process_store_event(self, event_idx: int) -> None:
|
||||
"""Process a fully-completed store event."""
|
||||
transfer = self._store_event_to_blocks.pop(event_idx)
|
||||
self._process_store_completion(transfer.gpu_block_ids, transfer.cpu_block_ids)
|
||||
logger.debug(
|
||||
"Store event %d completed: cached %d blocks to CPU",
|
||||
event_idx,
|
||||
len(transfer.cpu_block_ids),
|
||||
)
|
||||
|
||||
# Eager only: update per-req state
|
||||
if not self._lazy_mode:
|
||||
for req_id in self._store_event_to_reqs.pop(event_idx, []):
|
||||
state = self._reqs_to_store.get(req_id)
|
||||
if state is None:
|
||||
continue
|
||||
state.store_events.discard(event_idx)
|
||||
if state.finished and not state.store_events:
|
||||
self._cleanup_store_request(req_id)
|
||||
|
||||
def _process_store_completion(
|
||||
self, gpu_block_ids: list[int], cpu_block_ids: list[int]
|
||||
) -> None:
|
||||
"""Cache CPU blocks per-group and release GPU refs.
|
||||
|
||||
Block hashes were stamped on CPU blocks at allocation time (in
|
||||
``_prepare_*_store_specs``). Here we just register them in the
|
||||
cache map so they become discoverable by the load path.
|
||||
"""
|
||||
assert len(cpu_block_ids) == len(gpu_block_ids)
|
||||
|
||||
cpu_blocks = [self.cpu_block_pool.blocks[bid] for bid in cpu_block_ids]
|
||||
|
||||
for cpu_block in cpu_blocks:
|
||||
bhash = cpu_block.block_hash
|
||||
assert bhash is not None
|
||||
self.cpu_block_pool.cached_block_hash_to_block.insert(bhash, cpu_block)
|
||||
|
||||
# Free CPU and GPU blocks' ref counts to turn them into prefix cache
|
||||
self.cpu_block_pool.free_blocks(cpu_blocks)
|
||||
assert self._gpu_block_pool is not None
|
||||
self._gpu_block_pool.free_blocks(
|
||||
self._gpu_block_pool.blocks[bid] for bid in gpu_block_ids
|
||||
)
|
||||
|
||||
def has_pending_stores(self) -> bool:
|
||||
"""Return True if there are in-flight store transfers."""
|
||||
return bool(self._store_event_to_blocks)
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
"""Always returns (False, None). GPU blocks are protected by ref_cnt,
|
||||
so the scheduler can free blocks immediately."""
|
||||
req_id = request.request_id
|
||||
|
||||
# Handle load: defer cleanup if load is in-flight
|
||||
load_state = self._reqs_to_load.get(req_id)
|
||||
if load_state is not None:
|
||||
if load_state.load_event is not None:
|
||||
load_state.finished = True # Defer: load in-flight
|
||||
else:
|
||||
self._cleanup_load_request(req_id)
|
||||
|
||||
# Handle store (eager mode only): defer cleanup if stores in-flight
|
||||
if not self._lazy_mode:
|
||||
store_state = self._reqs_to_store.get(req_id)
|
||||
if store_state is not None:
|
||||
if store_state.store_events:
|
||||
store_state.finished = True # Defer: stores in-flight
|
||||
else:
|
||||
self._cleanup_store_request(req_id)
|
||||
|
||||
return False, None
|
||||
|
||||
def request_finished_all_groups(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: tuple[list[int], ...],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
return self.request_finished(request, block_ids=[])
|
||||
|
||||
def _cleanup_load_request(self, req_id: str) -> None:
|
||||
"""Release all load resources for a request.
|
||||
|
||||
Shared between request_finished() and update_connector_output() paths.
|
||||
Removes the request from _reqs_to_load, cleans up event mappings,
|
||||
and frees CPU/GPU touch refs.
|
||||
"""
|
||||
state = self._reqs_to_load.pop(req_id, None)
|
||||
if state is None:
|
||||
return
|
||||
# Remove from load event mapping (only this req, not whole event)
|
||||
if state.load_event is not None:
|
||||
reqs = self._load_event_to_reqs.get(state.load_event)
|
||||
if reqs is not None:
|
||||
with contextlib.suppress(ValueError):
|
||||
reqs.remove(req_id)
|
||||
if not reqs:
|
||||
self._load_event_to_reqs.pop(state.load_event, None)
|
||||
|
||||
if state.transfer_meta is not None:
|
||||
# Free CPU touch refs
|
||||
self.cpu_block_pool.free_blocks(
|
||||
self.cpu_block_pool.blocks[bid]
|
||||
for bid in state.transfer_meta.cpu_block_ids
|
||||
)
|
||||
# Free GPU touch refs
|
||||
assert self._gpu_block_pool is not None
|
||||
self._gpu_block_pool.free_blocks(
|
||||
self._gpu_block_pool.blocks[bid]
|
||||
for bid in state.transfer_meta.gpu_block_ids
|
||||
)
|
||||
|
||||
def _cleanup_store_request(self, req_id: str) -> None:
|
||||
"""Release store metadata for a request.
|
||||
|
||||
Metadata-only cleanup but no block freeing. Job completion handles
|
||||
block caching and GPU ref freeing via _process_store_completion().
|
||||
"""
|
||||
state = self._reqs_to_store.pop(req_id, None)
|
||||
if state is None:
|
||||
return
|
||||
for event_idx in list(state.store_events):
|
||||
if (reqs := self._store_event_to_reqs.get(event_idx)) is not None:
|
||||
with contextlib.suppress(ValueError):
|
||||
reqs.remove(req_id)
|
||||
if not reqs:
|
||||
self._store_event_to_reqs.pop(event_idx, None)
|
||||
state.store_events.clear()
|
||||
|
||||
def take_events(self) -> Iterable[KVCacheEvent]:
|
||||
return self.cpu_block_pool.take_events()
|
||||
60
vllm/v1/simple_kv_offload/metadata.py
Normal file
60
vllm/v1/simple_kv_offload/metadata.py
Normal file
@@ -0,0 +1,60 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Metadata for SimpleCPUOffloadConnector."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorMetadata,
|
||||
KVConnectorWorkerMetadata,
|
||||
)
|
||||
|
||||
INVALID_JOB_ID = -1
|
||||
|
||||
|
||||
@dataclass
|
||||
class SimpleCPUOffloadMetadata(KVConnectorMetadata):
|
||||
"""
|
||||
Metadata passed from scheduler to worker for CPU offload operations.
|
||||
|
||||
The worker receives flat block lists keyed by a monotonic event_idx.
|
||||
Job->req_id translation is handled by the scheduler-side manager
|
||||
(via inverse maps), so the worker never knows about request identities.
|
||||
"""
|
||||
|
||||
# Load event per step. INVALID_JOB_ID means no blocks to load this step.
|
||||
load_event: int = INVALID_JOB_ID
|
||||
load_gpu_blocks: list[int] = field(default_factory=list)
|
||||
load_cpu_blocks: list[int] = field(default_factory=list)
|
||||
# Reverse map: load_event->req_ids, for tracking requests with finished load events
|
||||
load_event_to_reqs: dict[int, list[str]] = field(default_factory=dict)
|
||||
|
||||
# Store event per step. INVALID_JOB_ID means no blocks to store this step.
|
||||
store_event: int = INVALID_JOB_ID
|
||||
store_gpu_blocks: list[int] = field(default_factory=list)
|
||||
store_cpu_blocks: list[int] = field(default_factory=list)
|
||||
|
||||
# Whether any requests were preempted this step and need flush pending transfers.
|
||||
need_flush: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class SimpleCPUOffloadWorkerMetadata(KVConnectorWorkerMetadata):
|
||||
"""Worker -> Scheduler metadata for completed store events.
|
||||
|
||||
Each worker reports {event_idx: 1} for newly completed stores.
|
||||
``aggregate()`` sums counts across workers within a step.
|
||||
The scheduler-side manager accumulates across steps and processes
|
||||
a store completion only when count reaches ``world_size``.
|
||||
"""
|
||||
|
||||
completed_store_events: dict[int, int]
|
||||
|
||||
def aggregate(
|
||||
self, other: "KVConnectorWorkerMetadata"
|
||||
) -> "KVConnectorWorkerMetadata":
|
||||
assert isinstance(other, SimpleCPUOffloadWorkerMetadata)
|
||||
merged = dict(self.completed_store_events)
|
||||
for k, v in other.completed_store_events.items():
|
||||
merged[k] = merged.get(k, 0) + v
|
||||
return SimpleCPUOffloadWorkerMetadata(completed_store_events=merged)
|
||||
305
vllm/v1/simple_kv_offload/worker.py
Normal file
305
vllm/v1/simple_kv_offload/worker.py
Normal file
@@ -0,0 +1,305 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Worker-side handler for SimpleCPUOffloadConnector."""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.v1.simple_kv_offload.copy_backend import DmaCopyBackend
|
||||
from vllm.v1.simple_kv_offload.cuda_mem_ops import pin_tensor
|
||||
from vllm.v1.simple_kv_offload.metadata import (
|
||||
SimpleCPUOffloadMetadata,
|
||||
SimpleCPUOffloadWorkerMetadata,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SimpleCPUOffloadWorker:
|
||||
"""Worker-side handler for CPU offloading transfers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_config: "KVCacheConfig | None",
|
||||
cpu_capacity_bytes: int,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.cpu_capacity_bytes = cpu_capacity_bytes
|
||||
|
||||
self.gpu_kv_caches: dict[str, torch.Tensor] | None = None
|
||||
self.cpu_kv_caches: dict[str, torch.Tensor] | None = None
|
||||
self.device: torch.device | None = None
|
||||
self.num_cpu_blocks: int = 0
|
||||
|
||||
# CUDA streams for the async transfers
|
||||
self.load_stream: torch.cuda.Stream | None = None
|
||||
self.store_stream: torch.cuda.Stream | None = None
|
||||
|
||||
self._backend = DmaCopyBackend()
|
||||
|
||||
# Ordered (event_idx, Event). Events pre-allocated on main thread.
|
||||
self._load_events: list[tuple[int, torch.Event]] = []
|
||||
self._store_events: list[tuple[int, torch.Event]] = []
|
||||
# High-water marks: highest event_idx completed per stream.
|
||||
# When the event list is empty, the hwm covers all prior events.
|
||||
self._load_hwm: int = -1
|
||||
self._store_hwm: int = -1
|
||||
|
||||
# Metadata for the current step
|
||||
self._connector_metadata: SimpleCPUOffloadMetadata | None = None
|
||||
|
||||
# Pending event index sets, populated in bind_connector_metadata
|
||||
self._pending_load_event_indices: set[int] = set()
|
||||
self._pending_store_event_indices: set[int] = set()
|
||||
# Completed store events to report via build_connector_worker_meta
|
||||
self._completed_store_events: dict[int, int] = {}
|
||||
|
||||
def register_kv_caches(
|
||||
self,
|
||||
kv_caches: dict[str, torch.Tensor],
|
||||
) -> None:
|
||||
"""Register GPU KV caches and allocate pinned CPU tensors.
|
||||
The worker will infer the underlying raw storage from the kv_caches.
|
||||
|
||||
Args:
|
||||
kv_caches: Per-layer GPU KV caches. Values are either a single
|
||||
tensor (attention layers) or a list of tensors (Mamba layers
|
||||
in hybrid models). All values are included for offloading
|
||||
by resolving to their underlying raw storage.
|
||||
"""
|
||||
if not kv_caches:
|
||||
logger.warning("No KV caches to offload.")
|
||||
return
|
||||
|
||||
# Resolve each entry to a representative tensor for storage
|
||||
# deduplication. For attention layers the value is already a tensor;
|
||||
# for Mamba layers it is a list of tensors that all share the same
|
||||
# underlying raw storage, so we take the first one.
|
||||
def _repr_tensor(v: torch.Tensor | list[torch.Tensor]) -> torch.Tensor:
|
||||
assert isinstance(v, torch.Tensor | list)
|
||||
return v if isinstance(v, torch.Tensor) else v[0]
|
||||
|
||||
any_tensor = _repr_tensor(next(iter(kv_caches.values())))
|
||||
self.device = any_tensor.device
|
||||
|
||||
assert self.kv_cache_config is not None
|
||||
num_blocks = self.kv_cache_config.num_blocks
|
||||
|
||||
# Deduplicate: multiple layers may share the same backing storage.
|
||||
seen_ptrs: dict[int, tuple[str, torch.Tensor]] = {}
|
||||
for name, value in kv_caches.items():
|
||||
tensor = _repr_tensor(value)
|
||||
ptr = tensor.untyped_storage().data_ptr()
|
||||
if ptr not in seen_ptrs:
|
||||
seen_ptrs[ptr] = (name, tensor)
|
||||
|
||||
# Build [num_blocks, block_bytes] int8 views from each unique
|
||||
# storage so that stride(0) gives block_bytes for the copy op.
|
||||
#
|
||||
# The physical layout varies across attention backends:
|
||||
# FlashAttn/ROCm: (2, num_blocks, ...) -> K/V outermost, 2 segments
|
||||
# FlashInfer/MLA: (num_blocks, ...) -> blocks outermost, 1 segment
|
||||
# We derive page_size_bytes = storage.nbytes() // num_blocks, then
|
||||
# classify dims: any dim whose byte-stride exceeds page_size_bytes
|
||||
# must be an outer segment dim (e.g. the K/V dim of size 2). A less
|
||||
# hacky way is to update the interface with the layout.
|
||||
unique_gpu_caches: dict[str, torch.Tensor] = {}
|
||||
for name, tensor in seen_ptrs.values():
|
||||
storage = tensor.untyped_storage()
|
||||
raw = torch.empty(0, dtype=torch.int8, device=self.device).set_(
|
||||
storage, 0, (storage.nbytes(),)
|
||||
)
|
||||
el = tensor.element_size()
|
||||
page_size_bytes = storage.nbytes() // num_blocks
|
||||
outer_dims = [
|
||||
d for d in range(tensor.ndim) if tensor.stride(d) * el > page_size_bytes
|
||||
]
|
||||
if not outer_dims:
|
||||
unique_gpu_caches[name] = raw.view(num_blocks, -1)
|
||||
else:
|
||||
seg_stride = tensor.stride(outer_dims[0]) * el
|
||||
for idx in range(tensor.shape[outer_dims[0]]):
|
||||
offset = idx * seg_stride
|
||||
chunk = raw[offset : offset + seg_stride]
|
||||
unique_gpu_caches[f"{name}.{idx}"] = chunk.view(num_blocks, -1)
|
||||
|
||||
# Compute per-tensor bytes_per_block. Tensors may have different
|
||||
# page_size_bytes (e.g., UniformTypeKVCacheSpecs with varying head_size).
|
||||
per_tensor_bpb = [
|
||||
t.stride(0) * t.element_size() for t in unique_gpu_caches.values()
|
||||
]
|
||||
total_bytes_per_block = sum(per_tensor_bpb)
|
||||
|
||||
self.num_cpu_blocks = max(1, self.cpu_capacity_bytes // total_bytes_per_block)
|
||||
|
||||
logger.info(
|
||||
"SimpleCPUOffloadWorker: %d unique GPU KV tensors, "
|
||||
"allocating %d CPU blocks (%.2f GB)",
|
||||
len(unique_gpu_caches),
|
||||
self.num_cpu_blocks,
|
||||
(self.num_cpu_blocks * total_bytes_per_block) / (1024**3),
|
||||
)
|
||||
|
||||
pin_memory = is_pin_memory_available()
|
||||
if not pin_memory:
|
||||
logger.warning(
|
||||
"Pinned memory not available. CPU offload performance may be degraded."
|
||||
)
|
||||
|
||||
self.gpu_kv_caches = unique_gpu_caches
|
||||
self.cpu_kv_caches = {}
|
||||
for name, gpu_tensor in unique_gpu_caches.items():
|
||||
cpu_shape = (self.num_cpu_blocks,) + gpu_tensor.shape[1:]
|
||||
# Allocate non-pinned first, then pin via cudaHostRegister to
|
||||
# bypass PyTorch's CUDACachingHostAllocator which rounds up to
|
||||
# the next power of 2 (e.g. 100 GB -> 128 GB).
|
||||
tensor = torch.zeros(cpu_shape, dtype=gpu_tensor.dtype, device="cpu")
|
||||
if pin_memory:
|
||||
pin_tensor(tensor)
|
||||
self.cpu_kv_caches[name] = tensor
|
||||
|
||||
# Use lowest priority so KV cache I/O yields to compute streams.
|
||||
low_pri, _ = torch.cuda.Stream.priority_range()
|
||||
self.load_stream = torch.cuda.Stream(priority=low_pri)
|
||||
self.store_stream = torch.cuda.Stream(priority=low_pri)
|
||||
|
||||
# Initialize copy backend with caches and streams.
|
||||
self._backend.init(
|
||||
self.gpu_kv_caches,
|
||||
self.cpu_kv_caches,
|
||||
self.device,
|
||||
self.load_stream,
|
||||
self.store_stream,
|
||||
)
|
||||
|
||||
def bind_connector_metadata(self, metadata: SimpleCPUOffloadMetadata) -> None:
|
||||
self._connector_metadata = metadata
|
||||
if metadata.load_event >= 0:
|
||||
self._pending_load_event_indices.add(metadata.load_event)
|
||||
if metadata.store_event >= 0:
|
||||
self._pending_store_event_indices.add(metadata.store_event)
|
||||
|
||||
def clear_connector_metadata(self) -> None:
|
||||
self._connector_metadata = None
|
||||
|
||||
def start_load_kv(self) -> None:
|
||||
# NOTE: we defer launching both load and store to get_finished(),
|
||||
# which runs after model execution. This hides the CPU-side
|
||||
# block copy op overhead (~5ms) behind GPU compute.
|
||||
pass
|
||||
|
||||
def wait_for_save(self) -> None:
|
||||
pass
|
||||
|
||||
def get_finished(
|
||||
self,
|
||||
finished_req_ids: set[str],
|
||||
) -> tuple[set[str] | None, set[str] | None]:
|
||||
"""Submit transfers and report completed events to the scheduler.
|
||||
|
||||
Called after model execution. The manager only schedules stores for
|
||||
blocks whose KV data is confirmed computed, so we launch both loads
|
||||
and stores immediately — no deferral or cross-stream sync needed.
|
||||
|
||||
Returns:
|
||||
tuple of (finished_sending, finished_recving).
|
||||
- finished_sending: always None (stores use worker metadata).
|
||||
- finished_recving: req_ids whose loads have completed.
|
||||
"""
|
||||
# (1) Submit transfers
|
||||
metadata = self._connector_metadata
|
||||
if metadata is not None:
|
||||
# Launch loads (CPU->GPU).
|
||||
if metadata.load_cpu_blocks:
|
||||
self._backend.launch_copy(
|
||||
metadata.load_cpu_blocks,
|
||||
metadata.load_gpu_blocks,
|
||||
is_store=False,
|
||||
event_idx=metadata.load_event,
|
||||
events_list=self._load_events,
|
||||
)
|
||||
# Launch stores (GPU->CPU).
|
||||
if metadata.store_gpu_blocks:
|
||||
self._backend.launch_copy(
|
||||
metadata.store_gpu_blocks,
|
||||
metadata.store_cpu_blocks,
|
||||
is_store=True,
|
||||
event_idx=metadata.store_event,
|
||||
events_list=self._store_events,
|
||||
)
|
||||
|
||||
# (2) Track completed transfer events
|
||||
finished_recving: set[str] = set()
|
||||
|
||||
if self._pending_load_event_indices:
|
||||
load_wm = self._poll_stream_events(is_store=False)
|
||||
for j in [j for j in self._pending_load_event_indices if j <= load_wm]:
|
||||
self._pending_load_event_indices.discard(j)
|
||||
req_ids = (
|
||||
metadata.load_event_to_reqs.get(j) if metadata is not None else None
|
||||
)
|
||||
if req_ids:
|
||||
finished_recving.update(req_ids)
|
||||
|
||||
if self._pending_store_event_indices:
|
||||
store_wm = self._poll_stream_events(is_store=True)
|
||||
for j in [j for j in self._pending_store_event_indices if j <= store_wm]:
|
||||
self._pending_store_event_indices.discard(j)
|
||||
self._completed_store_events[j] = 1
|
||||
|
||||
return None, finished_recving or None
|
||||
|
||||
def build_connector_worker_meta(self) -> SimpleCPUOffloadWorkerMetadata | None:
|
||||
"""Return completed store events since the last call."""
|
||||
if not self._completed_store_events:
|
||||
return None
|
||||
meta = SimpleCPUOffloadWorkerMetadata(
|
||||
completed_store_events=self._completed_store_events,
|
||||
)
|
||||
self._completed_store_events = {}
|
||||
return meta
|
||||
|
||||
def handle_preemptions(
|
||||
self, kv_connector_metadata: SimpleCPUOffloadMetadata
|
||||
) -> None:
|
||||
"""Sync all in-flight transfers before preempted blocks are reused."""
|
||||
if not kv_connector_metadata.need_flush:
|
||||
return
|
||||
self._flush_and_sync_all()
|
||||
|
||||
def _flush_and_sync_all(self) -> None:
|
||||
"""Synchronize all in-flight transfer events."""
|
||||
for event_idx, event in self._load_events:
|
||||
event.synchronize()
|
||||
self._load_hwm = event_idx
|
||||
self._load_events.clear()
|
||||
|
||||
for event_idx, event in self._store_events:
|
||||
event.synchronize()
|
||||
self._store_hwm = event_idx
|
||||
self._store_events.clear()
|
||||
|
||||
def _poll_stream_events(self, is_store: bool) -> int:
|
||||
"""Non-blocking poll for completed events and return the high-water mark."""
|
||||
events = self._store_events if is_store else self._load_events
|
||||
hwm = self._store_hwm if is_store else self._load_hwm
|
||||
while events:
|
||||
event_idx, event = events[0]
|
||||
if not event.query():
|
||||
break
|
||||
hwm = event_idx
|
||||
events.pop(0)
|
||||
if is_store:
|
||||
self._store_hwm = hwm
|
||||
else:
|
||||
self._load_hwm = hwm
|
||||
return hwm
|
||||
@@ -818,7 +818,6 @@ class SpecDecodeBaseProposer:
|
||||
|
||||
def prepare_next_token_ids_padded(
|
||||
self,
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
sampled_token_ids: torch.Tensor,
|
||||
requests: dict[str, CachedRequestState],
|
||||
gpu_input_batch: InputBatch,
|
||||
@@ -833,7 +832,7 @@ class SpecDecodeBaseProposer:
|
||||
"""
|
||||
# Precompute get_token_id for when there is no valid next token
|
||||
num_reqs = gpu_input_batch.num_reqs
|
||||
seq_lens_list = seq_lens_cpu[:num_reqs].tolist()
|
||||
seq_lens_list = (gpu_input_batch.num_tokens_no_spec[:num_reqs] - 1).tolist()
|
||||
self.backup_next_token_ids.np[:num_reqs] = np.array(
|
||||
[
|
||||
requests[gpu_input_batch.req_ids[i]].get_token_id(seq_lens_list[i])
|
||||
|
||||
@@ -286,7 +286,6 @@ class ExtractHiddenStatesProposer:
|
||||
|
||||
def prepare_next_token_ids_padded(
|
||||
self,
|
||||
seq_lens: torch.Tensor,
|
||||
sampled_token_ids: torch.Tensor,
|
||||
requests: dict[str, CachedRequestState],
|
||||
gpu_input_batch: InputBatch,
|
||||
@@ -303,7 +302,7 @@ class ExtractHiddenStatesProposer:
|
||||
device = sampled_token_ids.device
|
||||
|
||||
# Compute backup tokens for discarded / invalid requests
|
||||
seq_lens_list = seq_lens[:num_reqs].tolist()
|
||||
seq_lens_list = (gpu_input_batch.num_tokens_no_spec[:num_reqs] - 1).tolist()
|
||||
backup_tokens_gpu = torch.tensor(
|
||||
[
|
||||
requests[gpu_input_batch.req_ids[i]].get_token_id(seq_lens_list[i])
|
||||
|
||||
@@ -108,6 +108,15 @@ class CPUWorker(Worker):
|
||||
if ret:
|
||||
logger.info(ret)
|
||||
|
||||
# After the thread binding, changing thread num is not allowed
|
||||
def skip_set_num_threads(x: int):
|
||||
logger.warning(
|
||||
"CPU backend doesn't allow to use "
|
||||
"`torch.set_num_threads` after the thread binding, skip it."
|
||||
)
|
||||
|
||||
torch.set_num_threads = skip_set_num_threads
|
||||
|
||||
# Note: unique identifier for creating allreduce shared memory
|
||||
os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(":")[-1]
|
||||
# Initialize the distributed environment.
|
||||
|
||||
@@ -208,7 +208,7 @@ from .utils import (
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||
from vllm.v1.worker.encoder_cudagraph import EncoderCudaGraphManager
|
||||
from vllm.v1.worker.gpu.mm.encoder_cudagraph import EncoderCudaGraphManager
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -1933,9 +1933,24 @@ class GPUModelRunner(
|
||||
# _update_states_after_model_execute for hybrid models).
|
||||
if self.num_accepted_tokens_event is not None:
|
||||
self.num_accepted_tokens_event.synchronize()
|
||||
self.num_accepted_tokens.np[:num_reqs] = (
|
||||
self.input_batch.num_accepted_tokens_cpu[:num_reqs]
|
||||
)
|
||||
# Async mode: condense() reordered indices, use prev_positions mapping
|
||||
if self.use_async_scheduling and prev_req_id_to_index:
|
||||
prev_idx = self.prev_positions.np[:num_reqs]
|
||||
new_mask = prev_idx < 0
|
||||
self.num_accepted_tokens.np[:num_reqs] = (
|
||||
self.input_batch.num_accepted_tokens_cpu[
|
||||
np.where(new_mask, 0, prev_idx)
|
||||
]
|
||||
)
|
||||
self.num_accepted_tokens.np[:num_reqs][new_mask] = 1
|
||||
self.input_batch.num_accepted_tokens_cpu[:num_reqs] = (
|
||||
self.num_accepted_tokens.np[:num_reqs]
|
||||
)
|
||||
else:
|
||||
# Non-async mode: use values directly
|
||||
self.num_accepted_tokens.np[:num_reqs] = (
|
||||
self.input_batch.num_accepted_tokens_cpu[:num_reqs]
|
||||
)
|
||||
self.num_accepted_tokens.np[num_reqs:].fill(1)
|
||||
self.num_accepted_tokens.copy_to_gpu()
|
||||
else:
|
||||
@@ -4211,7 +4226,6 @@ class GPUModelRunner(
|
||||
assert spec_decode_common_attn_metadata is not None
|
||||
next_token_ids, valid_sampled_tokens_count = (
|
||||
self.drafter.prepare_next_token_ids_padded(
|
||||
self.optimistic_seq_lens_cpu,
|
||||
sampled_token_ids,
|
||||
self.requests,
|
||||
self.input_batch,
|
||||
@@ -4578,7 +4592,6 @@ class GPUModelRunner(
|
||||
)
|
||||
next_token_ids, valid_sampled_tokens_count = (
|
||||
self.drafter.prepare_next_token_ids_padded(
|
||||
self.optimistic_seq_lens_cpu,
|
||||
sampled_token_ids,
|
||||
self.requests,
|
||||
self.input_batch,
|
||||
@@ -4617,7 +4630,6 @@ class GPUModelRunner(
|
||||
)
|
||||
next_token_ids, valid_sampled_tokens_count = (
|
||||
self.drafter.prepare_next_token_ids_padded(
|
||||
self.optimistic_seq_lens_cpu,
|
||||
sampled_token_ids,
|
||||
self.requests,
|
||||
self.input_batch,
|
||||
@@ -5969,7 +5981,9 @@ class GPUModelRunner(
|
||||
SupportsEncoderCudaGraph,
|
||||
supports_encoder_cudagraph,
|
||||
)
|
||||
from vllm.v1.worker.encoder_cudagraph import EncoderCudaGraphManager
|
||||
from vllm.v1.worker.gpu.mm.encoder_cudagraph import (
|
||||
EncoderCudaGraphManager,
|
||||
)
|
||||
|
||||
raw_model = self.get_model()
|
||||
if supports_encoder_cudagraph(raw_model):
|
||||
|
||||
Reference in New Issue
Block a user