Compare commits
16 Commits
v0.17.0rc0
...
v0.17.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
95c0f928cd | ||
|
|
c9b1e977dc | ||
|
|
1ff2393897 | ||
|
|
5bec0b0ba3 | ||
|
|
6da1310f91 | ||
|
|
bc46be5daf | ||
|
|
8e39d39fd4 | ||
|
|
46fa044cc1 | ||
|
|
ab43e37158 | ||
|
|
f45d010120 | ||
|
|
244b922088 | ||
|
|
b31e9326a7 | ||
|
|
e346c08560 | ||
|
|
b7a423cb01 | ||
|
|
fa78ec8a72 | ||
|
|
9a474ce7a4 |
@@ -44,7 +44,7 @@ docker run \
|
|||||||
python3 examples/offline_inference/basic/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2
|
python3 examples/offline_inference/basic/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2
|
||||||
python3 examples/offline_inference/basic/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 --enable-expert-parallel
|
python3 examples/offline_inference/basic/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 --enable-expert-parallel
|
||||||
cd tests
|
cd tests
|
||||||
pytest -v -s v1/core --ignore=v1/core/test_reset_prefix_cache_e2e.py
|
pytest -v -s v1/core --ignore=v1/core/test_reset_prefix_cache_e2e.py --ignore=v1/core/test_scheduler_e2e.py
|
||||||
pytest -v -s v1/engine
|
pytest -v -s v1/engine
|
||||||
pytest -v -s v1/sample --ignore=v1/sample/test_logprobs.py --ignore=v1/sample/test_logprobs_e2e.py
|
pytest -v -s v1/sample --ignore=v1/sample/test_logprobs.py --ignore=v1/sample/test_logprobs_e2e.py
|
||||||
pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py
|
pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py
|
||||||
|
|||||||
@@ -54,10 +54,13 @@ mkdir -p $DIST_DIR
|
|||||||
# include only wheels for the release version, ignore all files with "dev" or "rc" in the name (without excluding 'aarch64')
|
# include only wheels for the release version, ignore all files with "dev" or "rc" in the name (without excluding 'aarch64')
|
||||||
aws s3 cp --recursive --exclude "*" --include "vllm-${PURE_VERSION}*.whl" --exclude "*dev*" --exclude "*rc[0-9]*" "$S3_COMMIT_PREFIX" $DIST_DIR
|
aws s3 cp --recursive --exclude "*" --include "vllm-${PURE_VERSION}*.whl" --exclude "*dev*" --exclude "*rc[0-9]*" "$S3_COMMIT_PREFIX" $DIST_DIR
|
||||||
echo "Wheels copied to local directory"
|
echo "Wheels copied to local directory"
|
||||||
# generate source tarball
|
# generate source distribution using setup.py
|
||||||
git archive --format=tar.gz --output="$DIST_DIR/vllm-${PURE_VERSION}.tar.gz" "$BUILDKITE_COMMIT"
|
python setup.py sdist --dist-dir=$DIST_DIR
|
||||||
ls -la $DIST_DIR
|
ls -la $DIST_DIR
|
||||||
|
|
||||||
|
SDIST_FILE=$(find $DIST_DIR -name "vllm*.tar.gz")
|
||||||
|
echo "Found sdist: $SDIST_FILE"
|
||||||
|
|
||||||
# upload wheels to PyPI (only default variant, i.e. files without '+' in the name)
|
# upload wheels to PyPI (only default variant, i.e. files without '+' in the name)
|
||||||
PYPI_WHEEL_FILES=$(find $DIST_DIR -name "vllm-${PURE_VERSION}*.whl" -not -name "*+*")
|
PYPI_WHEEL_FILES=$(find $DIST_DIR -name "vllm-${PURE_VERSION}*.whl" -not -name "*+*")
|
||||||
if [[ -z "$PYPI_WHEEL_FILES" ]]; then
|
if [[ -z "$PYPI_WHEEL_FILES" ]]; then
|
||||||
@@ -65,6 +68,6 @@ if [[ -z "$PYPI_WHEEL_FILES" ]]; then
|
|||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
python3 -m twine check "$PYPI_WHEEL_FILES"
|
python3 -m twine check "$PYPI_WHEEL_FILES" "$SDIST_FILE"
|
||||||
python3 -m twine upload --non-interactive --verbose "$PYPI_WHEEL_FILES"
|
python3 -m twine upload --non-interactive --verbose "$PYPI_WHEEL_FILES" "$SDIST_FILE"
|
||||||
echo "Wheels uploaded to PyPI"
|
echo "Wheels and source distribution uploaded to PyPI"
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ tokenizers >= 0.21.1 # Required for fast incremental detokenization.
|
|||||||
protobuf >= 5.29.6, !=6.30.*, !=6.31.*, !=6.32.*, !=6.33.0.*, !=6.33.1.*, !=6.33.2.*, !=6.33.3.*, !=6.33.4.* # Required by LlamaTokenizer, gRPC. CVE-2026-0994
|
protobuf >= 5.29.6, !=6.30.*, !=6.31.*, !=6.32.*, !=6.33.0.*, !=6.33.1.*, !=6.33.2.*, !=6.33.3.*, !=6.33.4.* # Required by LlamaTokenizer, gRPC. CVE-2026-0994
|
||||||
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
|
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
|
||||||
aiohttp >= 3.13.3
|
aiohttp >= 3.13.3
|
||||||
openai >= 1.99.1 # For Responses API with reasoning content
|
openai >= 1.99.1, < 2.25.0 # For Responses API with reasoning content
|
||||||
pydantic >= 2.12.0
|
pydantic >= 2.12.0
|
||||||
prometheus_client >= 0.18.0
|
prometheus_client >= 0.18.0
|
||||||
pillow # Required for image processing
|
pillow # Required for image processing
|
||||||
|
|||||||
@@ -11,6 +11,9 @@ torchaudio==2.10.0
|
|||||||
torchvision==0.25.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
|
torchvision==0.25.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
|
||||||
# FlashInfer should be updated together with the Dockerfile
|
# FlashInfer should be updated together with the Dockerfile
|
||||||
flashinfer-python==0.6.4
|
flashinfer-python==0.6.4
|
||||||
|
# Cap nvidia-cudnn-frontend (transitive dep of flashinfer) due to
|
||||||
|
# breaking changes in 1.19.0
|
||||||
|
nvidia-cudnn-frontend>=1.13.0,<1.19.0
|
||||||
|
|
||||||
# QuACK and Cutlass DSL for FA4 (cute-DSL implementation)
|
# QuACK and Cutlass DSL for FA4 (cute-DSL implementation)
|
||||||
nvidia-cutlass-dsl>=4.4.0.dev1
|
nvidia-cutlass-dsl>=4.4.0.dev1
|
||||||
|
|||||||
@@ -15,4 +15,4 @@ torch==2.10.0+xpu
|
|||||||
torchaudio
|
torchaudio
|
||||||
torchvision
|
torchvision
|
||||||
|
|
||||||
vllm_xpu_kernels @ https://github.com/vllm-project/vllm-xpu-kernels/releases/download/v0.1.2/vllm_xpu_kernels-0.1.2-cp312-cp312-linux_x86_64.whl
|
vllm_xpu_kernels @ https://github.com/vllm-project/vllm-xpu-kernels/releases/download/v0.1.3/vllm_xpu_kernels-0.1.3-cp38-abi3-linux_x86_64.whl
|
||||||
|
|||||||
172
tests/reasoning/test_nemotron_v3_reasoning_parser.py
Normal file
172
tests/reasoning/test_nemotron_v3_reasoning_parser.py
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from typing import TypedDict
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import regex as re
|
||||||
|
|
||||||
|
from tests.reasoning.utils import run_reasoning_extraction
|
||||||
|
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
|
||||||
|
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||||
|
|
||||||
|
parser_name = "nemotron_v3"
|
||||||
|
|
||||||
|
|
||||||
|
class ReasoningCase(TypedDict):
|
||||||
|
output: str
|
||||||
|
reasoning: str | None
|
||||||
|
content: str | None
|
||||||
|
|
||||||
|
|
||||||
|
class FakeNemotronTokenizer:
|
||||||
|
def __init__(self):
|
||||||
|
self._vocab = {
|
||||||
|
"<think>": 1,
|
||||||
|
"</think>": 2,
|
||||||
|
}
|
||||||
|
self._pattern = re.compile(r"(<think>|</think>)")
|
||||||
|
|
||||||
|
def get_vocab(self) -> dict[str, int]:
|
||||||
|
return self._vocab
|
||||||
|
|
||||||
|
def tokenize(self, text: str) -> list[str]:
|
||||||
|
tokens: list[str] = []
|
||||||
|
for part in self._pattern.split(text):
|
||||||
|
if part:
|
||||||
|
tokens.append(part)
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def convert_tokens_to_string(self, tokens: list[str]) -> str:
|
||||||
|
return "".join(tokens)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tokenizer():
|
||||||
|
return FakeNemotronTokenizer()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"streaming,param_dict",
|
||||||
|
[
|
||||||
|
pytest.param(
|
||||||
|
False,
|
||||||
|
{
|
||||||
|
"output": "This is a reasoning section</think>This is the rest",
|
||||||
|
"reasoning": "This is a reasoning section",
|
||||||
|
"content": "This is the rest",
|
||||||
|
},
|
||||||
|
id="without_start_token",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
True,
|
||||||
|
{
|
||||||
|
"output": "This is a reasoning section</think>This is the rest",
|
||||||
|
"reasoning": "This is a reasoning section",
|
||||||
|
"content": "This is the rest",
|
||||||
|
},
|
||||||
|
id="without_start_token_streaming",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
False,
|
||||||
|
{
|
||||||
|
"output": "<think>This is a reasoning section</think>This is the rest",
|
||||||
|
"reasoning": "This is a reasoning section",
|
||||||
|
"content": "This is the rest",
|
||||||
|
},
|
||||||
|
id="with_start_token",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
True,
|
||||||
|
{
|
||||||
|
"output": "<think>This is a reasoning section</think>This is the rest",
|
||||||
|
"reasoning": "This is a reasoning section",
|
||||||
|
"content": "This is the rest",
|
||||||
|
},
|
||||||
|
id="with_start_token_streaming",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_nemotron_v3_reasoning(
|
||||||
|
tokenizer: FakeNemotronTokenizer,
|
||||||
|
streaming: bool,
|
||||||
|
param_dict: ReasoningCase,
|
||||||
|
):
|
||||||
|
output = tokenizer.tokenize(param_dict["output"])
|
||||||
|
model_output = [tokenizer.convert_tokens_to_string([token]) for token in output]
|
||||||
|
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)(
|
||||||
|
tokenizer
|
||||||
|
)
|
||||||
|
|
||||||
|
reasoning, content = run_reasoning_extraction(
|
||||||
|
parser, model_output, streaming=streaming
|
||||||
|
)
|
||||||
|
|
||||||
|
assert reasoning == param_dict["reasoning"]
|
||||||
|
assert content == param_dict["content"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_nemotron_v3_without_thinking_returns_content(
|
||||||
|
tokenizer: FakeNemotronTokenizer,
|
||||||
|
):
|
||||||
|
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
|
||||||
|
parser = parser_cls(tokenizer)
|
||||||
|
request = ChatCompletionRequest(
|
||||||
|
model="test-model",
|
||||||
|
messages=[],
|
||||||
|
chat_template_kwargs={"enable_thinking": False},
|
||||||
|
)
|
||||||
|
|
||||||
|
reasoning, content = run_reasoning_extraction(
|
||||||
|
parser,
|
||||||
|
["This is plain content"],
|
||||||
|
request=request,
|
||||||
|
streaming=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert reasoning is None
|
||||||
|
assert content == "This is plain content"
|
||||||
|
|
||||||
|
|
||||||
|
def test_nemotron_v3_force_nonempty_content_returns_content(
|
||||||
|
tokenizer: FakeNemotronTokenizer,
|
||||||
|
):
|
||||||
|
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
|
||||||
|
parser = parser_cls(tokenizer)
|
||||||
|
request = ChatCompletionRequest(
|
||||||
|
model="test-model",
|
||||||
|
messages=[],
|
||||||
|
chat_template_kwargs={"force_nonempty_content": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
reasoning, content = run_reasoning_extraction(
|
||||||
|
parser,
|
||||||
|
["<think>This is plain content"],
|
||||||
|
request=request,
|
||||||
|
streaming=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert reasoning is None
|
||||||
|
assert content == "This is plain content"
|
||||||
|
|
||||||
|
|
||||||
|
def test_nemotron_v3_with_thinking_keeps_truncated_reasoning(
|
||||||
|
tokenizer: FakeNemotronTokenizer,
|
||||||
|
):
|
||||||
|
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
|
||||||
|
parser = parser_cls(tokenizer)
|
||||||
|
request = ChatCompletionRequest(
|
||||||
|
model="test-model",
|
||||||
|
messages=[],
|
||||||
|
chat_template_kwargs={"enable_thinking": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
reasoning, content = run_reasoning_extraction(
|
||||||
|
parser,
|
||||||
|
["This is truncated reasoning"],
|
||||||
|
request=request,
|
||||||
|
streaming=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert reasoning == "This is truncated reasoning"
|
||||||
|
assert content is None
|
||||||
@@ -29,7 +29,8 @@ def test_tokenizer_like_protocol():
|
|||||||
_assert_tokenizer_like(tokenizer)
|
_assert_tokenizer_like(tokenizer)
|
||||||
|
|
||||||
tokenizer = get_tokenizer(
|
tokenizer = get_tokenizer(
|
||||||
"mistralai/Mistral-7B-Instruct-v0.3", tokenizer_mode="mistral"
|
"mistralai/Mistral-7B-Instruct-v0.3",
|
||||||
|
tokenizer_mode="mistral",
|
||||||
)
|
)
|
||||||
assert isinstance(tokenizer, MistralTokenizer)
|
assert isinstance(tokenizer, MistralTokenizer)
|
||||||
_assert_tokenizer_like(tokenizer)
|
_assert_tokenizer_like(tokenizer)
|
||||||
@@ -40,11 +41,20 @@ def test_tokenizer_like_protocol():
|
|||||||
|
|
||||||
tokenizer = get_tokenizer("deepseek-ai/DeepSeek-V3", tokenizer_mode="deepseek_v32")
|
tokenizer = get_tokenizer("deepseek-ai/DeepSeek-V3", tokenizer_mode="deepseek_v32")
|
||||||
assert isinstance(tokenizer, HfTokenizer)
|
assert isinstance(tokenizer, HfTokenizer)
|
||||||
|
|
||||||
# Verify it's a fast tokenizer (required for FastIncrementalDetokenizer)
|
# Verify it's a fast tokenizer (required for FastIncrementalDetokenizer)
|
||||||
assert isinstance(tokenizer, PreTrainedTokenizerFast)
|
assert isinstance(tokenizer, PreTrainedTokenizerFast)
|
||||||
assert "DSV32" in tokenizer.__class__.__name__
|
assert "DSV32" in tokenizer.__class__.__name__
|
||||||
_assert_tokenizer_like(tokenizer)
|
_assert_tokenizer_like(tokenizer)
|
||||||
|
|
||||||
|
tokenizer = get_tokenizer(
|
||||||
|
"Qwen/Qwen-VL",
|
||||||
|
tokenizer_mode="qwen_vl",
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
assert isinstance(tokenizer, HfTokenizer)
|
||||||
|
assert "WithoutImagePad" in tokenizer.__class__.__name__
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("tokenizer_name", ["facebook/opt-125m", "gpt2"])
|
@pytest.mark.parametrize("tokenizer_name", ["facebook/opt-125m", "gpt2"])
|
||||||
def test_tokenizer_revision(tokenizer_name: str):
|
def test_tokenizer_revision(tokenizer_name: str):
|
||||||
|
|||||||
@@ -1321,6 +1321,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
|||||||
- "slow" will always use the slow tokenizer.\n
|
- "slow" will always use the slow tokenizer.\n
|
||||||
- "mistral" will always use the tokenizer from `mistral_common`.\n
|
- "mistral" will always use the tokenizer from `mistral_common`.\n
|
||||||
- "deepseek_v32" will always use the tokenizer from `deepseek_v32`.\n
|
- "deepseek_v32" will always use the tokenizer from `deepseek_v32`.\n
|
||||||
|
- "qwen_vl" will always use the tokenizer from `qwen_vl`.\n
|
||||||
- Other custom values can be supported via plugins.""",
|
- Other custom values can be supported via plugins.""",
|
||||||
)
|
)
|
||||||
parser.add_argument("--use-beam-search", action="store_true")
|
parser.add_argument("--use-beam-search", action="store_true")
|
||||||
|
|||||||
@@ -126,6 +126,7 @@ class ModelConfig:
|
|||||||
- "slow" will always use the slow tokenizer.\n
|
- "slow" will always use the slow tokenizer.\n
|
||||||
- "mistral" will always use the tokenizer from `mistral_common`.\n
|
- "mistral" will always use the tokenizer from `mistral_common`.\n
|
||||||
- "deepseek_v32" will always use the tokenizer from `deepseek_v32`.\n
|
- "deepseek_v32" will always use the tokenizer from `deepseek_v32`.\n
|
||||||
|
- "qwen_vl" will always use the tokenizer from `qwen_vl`.\n
|
||||||
- Other custom values can be supported via plugins."""
|
- Other custom values can be supported via plugins."""
|
||||||
trust_remote_code: bool = False
|
trust_remote_code: bool = False
|
||||||
"""Trust remote code (e.g., from HuggingFace) when downloading the model
|
"""Trust remote code (e.g., from HuggingFace) when downloading the model
|
||||||
|
|||||||
@@ -35,12 +35,6 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
|
|||||||
):
|
):
|
||||||
super().__init__(moe_config, quant_config)
|
super().__init__(moe_config, quant_config)
|
||||||
|
|
||||||
if moe_config.moe_parallel_config.use_ep and quant_config.is_per_tensor:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"EP parallelism is not supported with TRTLLM"
|
|
||||||
"per-tensor FP8 quantization."
|
|
||||||
)
|
|
||||||
|
|
||||||
self.routing_method_type = moe_config.routing_method
|
self.routing_method_type = moe_config.routing_method
|
||||||
self.topk = moe_config.experts_per_token
|
self.topk = moe_config.experts_per_token
|
||||||
self.intermediate_size_per_partition = (
|
self.intermediate_size_per_partition = (
|
||||||
@@ -182,9 +176,6 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
|
|||||||
assert not apply_router_weight_on_input
|
assert not apply_router_weight_on_input
|
||||||
assert activation == MoEActivation.SILU
|
assert activation == MoEActivation.SILU
|
||||||
|
|
||||||
if e_score_correction_bias is not None:
|
|
||||||
e_score_correction_bias = e_score_correction_bias.to(hidden_states.dtype)
|
|
||||||
|
|
||||||
if self.routing_method_type == RoutingMethodType.DeepSeekV3:
|
if self.routing_method_type == RoutingMethodType.DeepSeekV3:
|
||||||
router_logits = router_logits.to(torch.float32)
|
router_logits = router_logits.to(torch.float32)
|
||||||
|
|
||||||
@@ -240,12 +231,11 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# Delay import for non-CUDA.
|
# Delay import for non-CUDA.
|
||||||
import flashinfer
|
import flashinfer
|
||||||
from flashinfer.fused_moe.core import ActivationType
|
|
||||||
|
|
||||||
# Confirm supported activation function.
|
# Confirm supported activation function.
|
||||||
assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||||
|
|
||||||
activation_type = ActivationType(activation_to_flashinfer_int(activation))
|
activation_type = activation_to_flashinfer_int(activation)
|
||||||
|
|
||||||
# Confirm Llama-4 routing is proper.
|
# Confirm Llama-4 routing is proper.
|
||||||
if self.routing_method_type == RoutingMethodType.Llama4:
|
if self.routing_method_type == RoutingMethodType.Llama4:
|
||||||
|
|||||||
@@ -323,4 +323,5 @@ class TrtLlmNvFp4ExpertsMonolithic(
|
|||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
routing_method_type=self.routing_method_type,
|
routing_method_type=self.routing_method_type,
|
||||||
do_finalize=True,
|
do_finalize=True,
|
||||||
|
activation_type=activation_to_flashinfer_int(activation),
|
||||||
)[0]
|
)[0]
|
||||||
|
|||||||
@@ -912,7 +912,7 @@ class BatchedTritonExperts(mk.FusedMoEExpertsModular):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _supports_no_act_and_mul() -> bool:
|
def _supports_no_act_and_mul() -> bool:
|
||||||
return False
|
return True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _supports_quant_scheme(
|
def _supports_quant_scheme(
|
||||||
|
|||||||
@@ -1944,7 +1944,7 @@ class TritonExperts(mk.FusedMoEExpertsModular):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _supports_no_act_and_mul() -> bool:
|
def _supports_no_act_and_mul() -> bool:
|
||||||
return False
|
return True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _supports_quant_scheme(
|
def _supports_quant_scheme(
|
||||||
@@ -1983,6 +1983,9 @@ class TritonExperts(mk.FusedMoEExpertsModular):
|
|||||||
MoEActivation.GELU,
|
MoEActivation.GELU,
|
||||||
MoEActivation.SWIGLUOAI,
|
MoEActivation.SWIGLUOAI,
|
||||||
MoEActivation.SWIGLUSTEP,
|
MoEActivation.SWIGLUSTEP,
|
||||||
|
MoEActivation.SILU_NO_MUL,
|
||||||
|
MoEActivation.GELU_NO_MUL,
|
||||||
|
MoEActivation.RELU2_NO_MUL,
|
||||||
]
|
]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -353,6 +353,39 @@ class Qwen2_5OmniThinkerProcessingInfo(
|
|||||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||||
return {"audio": None, "image": None, "video": None}
|
return {"audio": None, "image": None, "video": None}
|
||||||
|
|
||||||
|
def get_mm_max_tokens_per_item(
|
||||||
|
self,
|
||||||
|
seq_len: int,
|
||||||
|
mm_counts: Mapping[str, int] | None = None,
|
||||||
|
) -> Mapping[str, int] | None:
|
||||||
|
mm_counts = mm_counts or {}
|
||||||
|
requested_modalities = {m for m, c in mm_counts.items() if c > 0}
|
||||||
|
mm_max_tokens: dict[str, int] = {}
|
||||||
|
|
||||||
|
if requested_modalities & {"image", "video"}:
|
||||||
|
vl_tokens = Qwen2_5_VLProcessingInfo.get_mm_max_tokens_per_item(
|
||||||
|
self,
|
||||||
|
seq_len=seq_len,
|
||||||
|
mm_counts=mm_counts,
|
||||||
|
)
|
||||||
|
mm_max_tokens.update(
|
||||||
|
{
|
||||||
|
m: vl_tokens[m]
|
||||||
|
for m in ["image", "video"]
|
||||||
|
if m in requested_modalities
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if "audio" in requested_modalities:
|
||||||
|
audio_tokens = Qwen2AudioProcessingInfo.get_mm_max_tokens_per_item(
|
||||||
|
self,
|
||||||
|
seq_len=seq_len,
|
||||||
|
mm_counts=mm_counts,
|
||||||
|
)
|
||||||
|
mm_max_tokens["audio"] = audio_tokens["audio"]
|
||||||
|
|
||||||
|
return mm_max_tokens
|
||||||
|
|
||||||
|
|
||||||
class Qwen2_5OmniThinkerDummyInputsBuilder(
|
class Qwen2_5OmniThinkerDummyInputsBuilder(
|
||||||
BaseDummyInputsBuilder[Qwen2_5OmniThinkerProcessingInfo]
|
BaseDummyInputsBuilder[Qwen2_5OmniThinkerProcessingInfo]
|
||||||
|
|||||||
@@ -179,6 +179,26 @@ class Qwen2AudioProcessingInfo(BaseProcessingInfo):
|
|||||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||||
return {"audio": None}
|
return {"audio": None}
|
||||||
|
|
||||||
|
def get_mm_max_tokens_per_item(
|
||||||
|
self,
|
||||||
|
seq_len: int,
|
||||||
|
mm_counts: Mapping[str, int] | None = None,
|
||||||
|
) -> Mapping[str, int]:
|
||||||
|
mm_counts = mm_counts or {}
|
||||||
|
if mm_counts.get("audio", 0) <= 0:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
feature_extractor = self.get_feature_extractor()
|
||||||
|
chunk_length = min(feature_extractor.chunk_length, 30)
|
||||||
|
audio_len = int(chunk_length * feature_extractor.sampling_rate)
|
||||||
|
hop_length = feature_extractor.hop_length
|
||||||
|
max_mel_seq_len = audio_len // hop_length
|
||||||
|
|
||||||
|
input_lengths = torch.tensor([max_mel_seq_len], dtype=torch.long)
|
||||||
|
_, output_lengths = _get_feat_extract_output_lengths(input_lengths)
|
||||||
|
|
||||||
|
return {"audio": int(output_lengths.item())}
|
||||||
|
|
||||||
|
|
||||||
class Qwen2AudioDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]):
|
class Qwen2AudioDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]):
|
||||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||||
|
|||||||
@@ -1163,6 +1163,39 @@ class Qwen3OmniMoeThinkerProcessingInfo(
|
|||||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||||
return {"audio": None, "image": None, "video": None}
|
return {"audio": None, "image": None, "video": None}
|
||||||
|
|
||||||
|
def get_mm_max_tokens_per_item(
|
||||||
|
self,
|
||||||
|
seq_len: int,
|
||||||
|
mm_counts: Mapping[str, int] | None = None,
|
||||||
|
) -> Mapping[str, int] | None:
|
||||||
|
mm_counts = mm_counts or {}
|
||||||
|
requested_modalities = {m for m, c in mm_counts.items() if c > 0}
|
||||||
|
mm_max_tokens: dict[str, int] = {}
|
||||||
|
|
||||||
|
if requested_modalities & {"image", "video"}:
|
||||||
|
vl_tokens = Qwen2_5_VLProcessingInfo.get_mm_max_tokens_per_item(
|
||||||
|
self,
|
||||||
|
seq_len=seq_len,
|
||||||
|
mm_counts=mm_counts,
|
||||||
|
)
|
||||||
|
mm_max_tokens.update(
|
||||||
|
{
|
||||||
|
m: vl_tokens[m]
|
||||||
|
for m in ["image", "video"]
|
||||||
|
if m in requested_modalities
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if "audio" in requested_modalities:
|
||||||
|
audio_tokens = Qwen2AudioProcessingInfo.get_mm_max_tokens_per_item(
|
||||||
|
self,
|
||||||
|
seq_len=seq_len,
|
||||||
|
mm_counts=mm_counts,
|
||||||
|
)
|
||||||
|
mm_max_tokens["audio"] = audio_tokens["audio"]
|
||||||
|
|
||||||
|
return mm_max_tokens
|
||||||
|
|
||||||
|
|
||||||
Qwen3OmniMoeThinkerDummyInputsBuilder = Qwen2_5OmniThinkerDummyInputsBuilder
|
Qwen3OmniMoeThinkerDummyInputsBuilder = Qwen2_5OmniThinkerDummyInputsBuilder
|
||||||
|
|
||||||
|
|||||||
@@ -6,11 +6,9 @@
|
|||||||
# Copyright (c) Alibaba Cloud.
|
# Copyright (c) Alibaba Cloud.
|
||||||
"""Inference-only Qwen-VL model compatible with HuggingFace weights."""
|
"""Inference-only Qwen-VL model compatible with HuggingFace weights."""
|
||||||
|
|
||||||
import copy
|
|
||||||
import math
|
import math
|
||||||
import unicodedata
|
from collections.abc import Callable, Mapping, Sequence
|
||||||
from collections.abc import Callable, Collection, Mapping, Sequence, Set
|
from functools import partial
|
||||||
from functools import lru_cache, partial
|
|
||||||
from typing import Annotated, Literal, TypeAlias
|
from typing import Annotated, Literal, TypeAlias
|
||||||
|
|
||||||
import regex as re
|
import regex as re
|
||||||
@@ -436,60 +434,6 @@ class QwenVLModel(QWenModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
|
||||||
def _get_tokenizer_without_image_pad(
|
|
||||||
tokenizer: PreTrainedTokenizer,
|
|
||||||
) -> PreTrainedTokenizer:
|
|
||||||
"""
|
|
||||||
The logic of adding image pad tokens should only be applied in
|
|
||||||
[`QwenVLProcessor`][vllm.model_executor.models.qwen_vl.QwenVLProcessor],
|
|
||||||
so they are patched out here.
|
|
||||||
|
|
||||||
The definition of the wrapped tokenizer can be found here:
|
|
||||||
https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py
|
|
||||||
"""
|
|
||||||
new_tokenizer = copy.deepcopy(tokenizer)
|
|
||||||
|
|
||||||
class TokenizerWithoutImagePad(tokenizer.__class__): # type: ignore
|
|
||||||
def tokenize(
|
|
||||||
self,
|
|
||||||
text: str,
|
|
||||||
allowed_special: Set[str] | str = "all",
|
|
||||||
disallowed_special: Collection[str] | str = (),
|
|
||||||
**kwargs,
|
|
||||||
) -> list[bytes | str]:
|
|
||||||
text = unicodedata.normalize("NFC", text)
|
|
||||||
|
|
||||||
return [
|
|
||||||
self.decoder[t]
|
|
||||||
for t in self.tokenizer.encode(
|
|
||||||
text,
|
|
||||||
allowed_special=allowed_special,
|
|
||||||
disallowed_special=disallowed_special,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
def _decode(
|
|
||||||
self,
|
|
||||||
token_ids: int | list[int],
|
|
||||||
skip_special_tokens: bool = False,
|
|
||||||
errors: str | None = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> str:
|
|
||||||
if isinstance(token_ids, int):
|
|
||||||
token_ids = [token_ids]
|
|
||||||
|
|
||||||
return self.tokenizer.decode(
|
|
||||||
token_ids,
|
|
||||||
errors=errors or self.errors,
|
|
||||||
)
|
|
||||||
|
|
||||||
TokenizerWithoutImagePad.__name__ = f"{tokenizer.__class__.__name__}WithoutImagePad"
|
|
||||||
|
|
||||||
new_tokenizer.__class__ = TokenizerWithoutImagePad
|
|
||||||
return new_tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
class QwenVLProcessor:
|
class QwenVLProcessor:
|
||||||
"""
|
"""
|
||||||
This model doesn't define its own HF processor,
|
This model doesn't define its own HF processor,
|
||||||
@@ -574,12 +518,6 @@ class QwenVLProcessor:
|
|||||||
|
|
||||||
|
|
||||||
class QwenVLProcessingInfo(BaseProcessingInfo):
|
class QwenVLProcessingInfo(BaseProcessingInfo):
|
||||||
def get_tokenizer(self) -> PreTrainedTokenizer:
|
|
||||||
tokenizer = self.ctx.get_tokenizer()
|
|
||||||
assert isinstance(tokenizer, PreTrainedTokenizer)
|
|
||||||
|
|
||||||
return _get_tokenizer_without_image_pad(tokenizer)
|
|
||||||
|
|
||||||
def get_hf_processor(self, **kwargs: object) -> QwenVLProcessor:
|
def get_hf_processor(self, **kwargs: object) -> QwenVLProcessor:
|
||||||
return self.ctx.init_processor(
|
return self.ctx.init_processor(
|
||||||
QwenVLProcessor,
|
QwenVLProcessor,
|
||||||
|
|||||||
@@ -68,6 +68,10 @@ _REASONING_PARSERS_TO_REGISTER = {
|
|||||||
"mistral_reasoning_parser",
|
"mistral_reasoning_parser",
|
||||||
"MistralReasoningParser",
|
"MistralReasoningParser",
|
||||||
),
|
),
|
||||||
|
"nemotron_v3": (
|
||||||
|
"nemotron_v3_reasoning_parser",
|
||||||
|
"NemotronV3ReasoningParser",
|
||||||
|
),
|
||||||
"olmo3": (
|
"olmo3": (
|
||||||
"olmo3_reasoning_parser",
|
"olmo3_reasoning_parser",
|
||||||
"Olmo3ReasoningParser",
|
"Olmo3ReasoningParser",
|
||||||
|
|||||||
35
vllm/reasoning/nemotron_v3_reasoning_parser.py
Normal file
35
vllm/reasoning/nemotron_v3_reasoning_parser.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||||
|
ChatCompletionRequest,
|
||||||
|
)
|
||||||
|
from vllm.entrypoints.openai.responses.protocol import (
|
||||||
|
ResponsesRequest,
|
||||||
|
)
|
||||||
|
from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
|
||||||
|
|
||||||
|
|
||||||
|
class NemotronV3ReasoningParser(DeepSeekR1ReasoningParser):
|
||||||
|
"""
|
||||||
|
Reasoning parser for Nemotron V3 models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def extract_reasoning(
|
||||||
|
self, model_output: str, request: ChatCompletionRequest | ResponsesRequest
|
||||||
|
) -> tuple[str | None, str | None]:
|
||||||
|
reasoning_content, final_content = super().extract_reasoning(
|
||||||
|
model_output, request
|
||||||
|
)
|
||||||
|
chat_template_kwargs = getattr(request, "chat_template_kwargs", None)
|
||||||
|
|
||||||
|
if (
|
||||||
|
chat_template_kwargs
|
||||||
|
and (
|
||||||
|
chat_template_kwargs.get("enable_thinking") is False
|
||||||
|
or chat_template_kwargs.get("force_nonempty_content") is True
|
||||||
|
)
|
||||||
|
and final_content is None
|
||||||
|
):
|
||||||
|
reasoning_content, final_content = final_content, reasoning_content
|
||||||
|
|
||||||
|
return reasoning_content, final_content
|
||||||
29
vllm/renderers/qwen_vl.py
Normal file
29
vllm/renderers/qwen_vl.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.tokenizers import cached_get_tokenizer
|
||||||
|
from vllm.tokenizers.qwen_vl import QwenVLTokenizer
|
||||||
|
|
||||||
|
from .base import BaseRenderer
|
||||||
|
from .hf import HfRenderer
|
||||||
|
|
||||||
|
|
||||||
|
class QwenVLRenderer(BaseRenderer[QwenVLTokenizer]):
|
||||||
|
@classmethod
|
||||||
|
def from_config( # type: ignore[override]
|
||||||
|
cls,
|
||||||
|
config: VllmConfig,
|
||||||
|
tokenizer_kwargs: dict[str, Any],
|
||||||
|
) -> "HfRenderer":
|
||||||
|
model_config = config.model_config
|
||||||
|
if model_config.skip_tokenizer_init:
|
||||||
|
tokenizer = None
|
||||||
|
else:
|
||||||
|
tokenizer = cached_get_tokenizer(
|
||||||
|
tokenizer_cls=QwenVLTokenizer,
|
||||||
|
**tokenizer_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return HfRenderer(config, tokenizer)
|
||||||
@@ -20,6 +20,7 @@ _VLLM_RENDERERS = {
|
|||||||
"hf": ("hf", "HfRenderer"),
|
"hf": ("hf", "HfRenderer"),
|
||||||
"grok2": ("grok2", "Grok2Renderer"),
|
"grok2": ("grok2", "Grok2Renderer"),
|
||||||
"mistral": ("mistral", "MistralRenderer"),
|
"mistral": ("mistral", "MistralRenderer"),
|
||||||
|
"qwen_vl": ("qwen_vl", "QwenVLRenderer"),
|
||||||
"terratorch": ("terratorch", "TerratorchRenderer"),
|
"terratorch": ("terratorch", "TerratorchRenderer"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ from transformers import AutoTokenizer
|
|||||||
|
|
||||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||||
|
|
||||||
from . import TokenizerLike
|
|
||||||
from .deepseek_v32_encoding import encode_messages
|
from .deepseek_v32_encoding import encode_messages
|
||||||
from .hf import HfTokenizer, get_cached_tokenizer
|
from .hf import HfTokenizer, get_cached_tokenizer
|
||||||
|
from .protocol import TokenizerLike
|
||||||
|
|
||||||
|
|
||||||
def get_deepseek_v32_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer:
|
def get_deepseek_v32_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer:
|
||||||
|
|||||||
67
vllm/tokenizers/qwen_vl.py
Normal file
67
vllm/tokenizers/qwen_vl.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import copy
|
||||||
|
import unicodedata
|
||||||
|
from collections.abc import Collection, Set
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from .hf import HfTokenizer, get_cached_tokenizer
|
||||||
|
from .protocol import TokenizerLike
|
||||||
|
|
||||||
|
|
||||||
|
def get_qwen_vl_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer:
|
||||||
|
"""
|
||||||
|
The logic of adding image pad tokens should only be applied in
|
||||||
|
`QwenVLProcessor`, so they are patched out here.
|
||||||
|
|
||||||
|
The definition of the wrapped tokenizer can be found here:
|
||||||
|
https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py
|
||||||
|
"""
|
||||||
|
new_tokenizer = copy.copy(tokenizer)
|
||||||
|
|
||||||
|
class TokenizerWithoutImagePad(tokenizer.__class__): # type: ignore
|
||||||
|
def tokenize(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
allowed_special: Set[str] | str = "all",
|
||||||
|
disallowed_special: Collection[str] | str = (),
|
||||||
|
**kwargs,
|
||||||
|
) -> list[bytes | str]:
|
||||||
|
text = unicodedata.normalize("NFC", text)
|
||||||
|
|
||||||
|
return [
|
||||||
|
self.decoder[t]
|
||||||
|
for t in self.tokenizer.encode(
|
||||||
|
text,
|
||||||
|
allowed_special=allowed_special,
|
||||||
|
disallowed_special=disallowed_special,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
def _decode(
|
||||||
|
self,
|
||||||
|
token_ids: int | list[int],
|
||||||
|
skip_special_tokens: bool = False,
|
||||||
|
errors: str | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> str:
|
||||||
|
if isinstance(token_ids, int):
|
||||||
|
token_ids = [token_ids]
|
||||||
|
|
||||||
|
return self.tokenizer.decode(
|
||||||
|
token_ids,
|
||||||
|
errors=errors or self.errors,
|
||||||
|
)
|
||||||
|
|
||||||
|
TokenizerWithoutImagePad.__name__ = f"{tokenizer.__class__.__name__}WithoutImagePad"
|
||||||
|
|
||||||
|
new_tokenizer.__class__ = TokenizerWithoutImagePad
|
||||||
|
return new_tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class QwenVLTokenizer(TokenizerLike):
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs) -> HfTokenizer:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(*args, **kwargs)
|
||||||
|
return get_cached_tokenizer(get_qwen_vl_tokenizer(tokenizer))
|
||||||
@@ -36,6 +36,7 @@ _VLLM_TOKENIZERS = {
|
|||||||
"grok2": ("grok2", "Grok2Tokenizer"),
|
"grok2": ("grok2", "Grok2Tokenizer"),
|
||||||
"hf": ("hf", "CachedHfTokenizer"),
|
"hf": ("hf", "CachedHfTokenizer"),
|
||||||
"mistral": ("mistral", "MistralTokenizer"),
|
"mistral": ("mistral", "MistralTokenizer"),
|
||||||
|
"qwen_vl": ("qwen_vl", "QwenVLTokenizer"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -165,6 +166,10 @@ def resolve_tokenizer_args(
|
|||||||
):
|
):
|
||||||
tokenizer_mode = "grok2"
|
tokenizer_mode = "grok2"
|
||||||
|
|
||||||
|
# Model-specific tokenizers
|
||||||
|
if tokenizer_mode == "auto" and "/Qwen-VL" in str(tokenizer_name):
|
||||||
|
tokenizer_mode = "qwen_vl"
|
||||||
|
|
||||||
# Fallback to HF tokenizer
|
# Fallback to HF tokenizer
|
||||||
if tokenizer_mode == "auto":
|
if tokenizer_mode == "auto":
|
||||||
tokenizer_mode = "hf"
|
tokenizer_mode = "hf"
|
||||||
|
|||||||
@@ -30,3 +30,8 @@ def round_up(x: int, y: int) -> int:
|
|||||||
def round_down(x: int, y: int) -> int:
|
def round_down(x: int, y: int) -> int:
|
||||||
"""Round down x to the nearest multiple of y."""
|
"""Round down x to the nearest multiple of y."""
|
||||||
return (x // y) * y
|
return (x // y) * y
|
||||||
|
|
||||||
|
|
||||||
|
def largest_power_of_2_divisor(n: int) -> int:
|
||||||
|
"""Return the largest power-of-2 that divides *n* (isolate lowest set bit)."""
|
||||||
|
return n & (-n)
|
||||||
|
|||||||
@@ -86,6 +86,26 @@ class AttentionBackend(ABC):
|
|||||||
) -> tuple[int, ...]:
|
) -> tuple[int, ...]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_kv_cache_block_dim(
|
||||||
|
cls,
|
||||||
|
block_size: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
cache_dtype_str: str = "auto",
|
||||||
|
) -> int:
|
||||||
|
"""Discover which tensor dim is the block index, since different
|
||||||
|
backends lay out dims differently."""
|
||||||
|
_S = 1234567
|
||||||
|
shape = cls.get_kv_cache_shape(
|
||||||
|
_S,
|
||||||
|
block_size,
|
||||||
|
num_kv_heads,
|
||||||
|
head_size,
|
||||||
|
cache_dtype_str=cache_dtype_str,
|
||||||
|
)
|
||||||
|
return shape.index(_S)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_kv_cache_stride_order(
|
def get_kv_cache_stride_order(
|
||||||
include_num_layers_dimension: bool = False,
|
include_num_layers_dimension: bool = False,
|
||||||
|
|||||||
@@ -372,12 +372,14 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
|||||||
|
|
||||||
# [7, 6, 8, 0] -> [7, 7, 7, 6, 8, 8, 8, 8]
|
# [7, 6, 8, 0] -> [7, 7, 7, 6, 8, 8, 8, 8]
|
||||||
expanded_base = torch.repeat_interleave(
|
expanded_base = torch.repeat_interleave(
|
||||||
seq_lens - decode_lens, decode_lens
|
seq_lens - decode_lens, decode_lens, output_size=actual_expanded
|
||||||
)
|
)
|
||||||
|
|
||||||
# [0, 3, 4, 8] -> [0, 0, 0, 3, 4, 4, 4, 4]
|
# [0, 3, 4, 8] -> [0, 0, 0, 3, 4, 4, 4, 4]
|
||||||
expanded_starts = torch.repeat_interleave(
|
expanded_starts = torch.repeat_interleave(
|
||||||
common_attn_metadata.query_start_loc[:num_decodes], decode_lens
|
common_attn_metadata.query_start_loc[:num_decodes],
|
||||||
|
decode_lens,
|
||||||
|
output_size=actual_expanded,
|
||||||
)
|
)
|
||||||
|
|
||||||
# [0, 1, 2, 0, 0, 1, 2, 3]
|
# [0, 1, 2, 0, 0, 1, 2, 3]
|
||||||
@@ -395,7 +397,9 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
|||||||
# Give each of the flattened entries the same block table row as the
|
# Give each of the flattened entries the same block table row as the
|
||||||
# original request.
|
# original request.
|
||||||
self.expanded_block_table_buffer[:actual_expanded] = (
|
self.expanded_block_table_buffer[:actual_expanded] = (
|
||||||
torch.repeat_interleave(block_table, decode_lens, dim=0)
|
torch.repeat_interleave(
|
||||||
|
block_table, decode_lens, dim=0, output_size=actual_expanded
|
||||||
|
)
|
||||||
)
|
)
|
||||||
if actual_expanded < num_decode_tokens:
|
if actual_expanded < num_decode_tokens:
|
||||||
self.expanded_block_table_buffer[
|
self.expanded_block_table_buffer[
|
||||||
|
|||||||
@@ -489,6 +489,13 @@ class KVCacheManager:
|
|||||||
# Only create new KVCacheBlocks for non-empty blocks
|
# Only create new KVCacheBlocks for non-empty blocks
|
||||||
return KVCacheBlocks(blocks) if any(blocks) else self.empty_kv_cache_blocks
|
return KVCacheBlocks(blocks) if any(blocks) else self.empty_kv_cache_blocks
|
||||||
|
|
||||||
|
def take_new_block_ids(self) -> list[int]:
|
||||||
|
"""Drain and return new attention block IDs for zeroing."""
|
||||||
|
ids: list[int] = []
|
||||||
|
for mgr in self.coordinator.single_type_managers:
|
||||||
|
ids.extend(mgr.take_new_block_ids())
|
||||||
|
return ids
|
||||||
|
|
||||||
def new_step_starts(self) -> None:
|
def new_step_starts(self) -> None:
|
||||||
"""Called when a new step is started."""
|
"""Called when a new step is started."""
|
||||||
self.coordinator.new_step_starts()
|
self.coordinator.new_step_starts()
|
||||||
|
|||||||
@@ -233,6 +233,11 @@ class SchedulerOutput:
|
|||||||
# EC Cache Connector metadata
|
# EC Cache Connector metadata
|
||||||
ec_connector_metadata: ECConnectorMetadata | None = None
|
ec_connector_metadata: ECConnectorMetadata | None = None
|
||||||
|
|
||||||
|
# Block IDs freshly allocated from the pool during this scheduling step.
|
||||||
|
# The worker zeros the corresponding GPU memory before the blocks are used,
|
||||||
|
# preventing stale NaN/data from corrupting attention or SSM computation.
|
||||||
|
new_block_ids_to_zero: list[int] | None = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def make_empty(cls) -> "SchedulerOutput":
|
def make_empty(cls) -> "SchedulerOutput":
|
||||||
return cls(
|
return cls(
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ from vllm.v1.core.sched.output import (
|
|||||||
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
|
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
|
||||||
from vllm.v1.core.sched.utils import check_stop, remove_all
|
from vllm.v1.core.sched.utils import check_stop, remove_all
|
||||||
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
|
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
from vllm.v1.metrics.perf import ModelMetrics, PerfStats
|
from vllm.v1.metrics.perf import ModelMetrics, PerfStats
|
||||||
from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
|
from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
|
||||||
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
|
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
|
||||||
@@ -233,13 +233,8 @@ class Scheduler(SchedulerInterface):
|
|||||||
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
|
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
|
||||||
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
|
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
|
||||||
|
|
||||||
def has_mamba_layers(kv_cache_config: KVCacheConfig) -> bool:
|
self.has_mamba_layers = kv_cache_config.has_mamba_layers
|
||||||
return any(
|
self.needs_kv_cache_zeroing = kv_cache_config.needs_kv_cache_zeroing
|
||||||
isinstance(group_spec.kv_cache_spec, MambaSpec)
|
|
||||||
for group_spec in kv_cache_config.kv_cache_groups
|
|
||||||
)
|
|
||||||
|
|
||||||
self.has_mamba_layers = has_mamba_layers(kv_cache_config)
|
|
||||||
self.need_mamba_block_aligned_split = (
|
self.need_mamba_block_aligned_split = (
|
||||||
self.has_mamba_layers and self.cache_config.mamba_cache_mode == "align"
|
self.has_mamba_layers and self.cache_config.mamba_cache_mode == "align"
|
||||||
)
|
)
|
||||||
@@ -871,6 +866,12 @@ class Scheduler(SchedulerInterface):
|
|||||||
self.prev_step_scheduled_req_ids.clear()
|
self.prev_step_scheduled_req_ids.clear()
|
||||||
self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys())
|
self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys())
|
||||||
|
|
||||||
|
new_block_ids_to_zero = (
|
||||||
|
(self.kv_cache_manager.take_new_block_ids() or None)
|
||||||
|
if self.needs_kv_cache_zeroing
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
scheduler_output = SchedulerOutput(
|
scheduler_output = SchedulerOutput(
|
||||||
scheduled_new_reqs=new_reqs_data,
|
scheduled_new_reqs=new_reqs_data,
|
||||||
scheduled_cached_reqs=cached_reqs_data,
|
scheduled_cached_reqs=cached_reqs_data,
|
||||||
@@ -886,6 +887,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
# the previous and the current steps.
|
# the previous and the current steps.
|
||||||
finished_req_ids=self.finished_req_ids,
|
finished_req_ids=self.finished_req_ids,
|
||||||
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
|
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
|
||||||
|
new_block_ids_to_zero=new_block_ids_to_zero,
|
||||||
)
|
)
|
||||||
|
|
||||||
# NOTE(Kuntai): this function is designed for multiple purposes:
|
# NOTE(Kuntai): this function is designed for multiple purposes:
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ class SingleTypeKVCacheManager(ABC):
|
|||||||
self.kv_cache_spec = kv_cache_spec
|
self.kv_cache_spec = kv_cache_spec
|
||||||
self.block_pool = block_pool
|
self.block_pool = block_pool
|
||||||
self.enable_caching = enable_caching
|
self.enable_caching = enable_caching
|
||||||
|
self.new_block_ids: list[int] = []
|
||||||
|
|
||||||
# Mapping from request ID to blocks to track the blocks allocated
|
# Mapping from request ID to blocks to track the blocks allocated
|
||||||
# for each request, so that we can free the blocks when the request
|
# for each request, so that we can free the blocks when the request
|
||||||
@@ -208,6 +209,8 @@ class SingleTypeKVCacheManager(ABC):
|
|||||||
cdiv(num_total_computed_tokens, self.block_size) - len(req_blocks)
|
cdiv(num_total_computed_tokens, self.block_size) - len(req_blocks)
|
||||||
)
|
)
|
||||||
req_blocks.extend(allocated_blocks)
|
req_blocks.extend(allocated_blocks)
|
||||||
|
if type(self.kv_cache_spec) is FullAttentionSpec:
|
||||||
|
self.new_block_ids.extend(b.block_id for b in allocated_blocks)
|
||||||
|
|
||||||
def allocate_new_blocks(
|
def allocate_new_blocks(
|
||||||
self, request_id: str, num_tokens: int, num_tokens_main_model: int
|
self, request_id: str, num_tokens: int, num_tokens_main_model: int
|
||||||
@@ -234,8 +237,16 @@ class SingleTypeKVCacheManager(ABC):
|
|||||||
else:
|
else:
|
||||||
new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
|
new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
|
||||||
req_blocks.extend(new_blocks)
|
req_blocks.extend(new_blocks)
|
||||||
|
if type(self.kv_cache_spec) is FullAttentionSpec:
|
||||||
|
self.new_block_ids.extend(b.block_id for b in new_blocks)
|
||||||
return new_blocks
|
return new_blocks
|
||||||
|
|
||||||
|
def take_new_block_ids(self) -> list[int]:
|
||||||
|
"""Drain and return block IDs allocated since the last call."""
|
||||||
|
ids = self.new_block_ids
|
||||||
|
self.new_block_ids = []
|
||||||
|
return ids
|
||||||
|
|
||||||
def cache_blocks(self, request: Request, num_tokens: int) -> None:
|
def cache_blocks(self, request: Request, num_tokens: int) -> None:
|
||||||
"""
|
"""
|
||||||
Cache the blocks for the request.
|
Cache the blocks for the request.
|
||||||
|
|||||||
@@ -489,3 +489,11 @@ class KVCacheConfig:
|
|||||||
For models with multiple types of attention, there will be multiple groups,
|
For models with multiple types of attention, there will be multiple groups,
|
||||||
see `_get_kv_cache_config_uniform_page_size` for more details.
|
see `_get_kv_cache_config_uniform_page_size` for more details.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_mamba_layers(self) -> bool:
|
||||||
|
return any(isinstance(g.kv_cache_spec, MambaSpec) for g in self.kv_cache_groups)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def needs_kv_cache_zeroing(self) -> bool:
|
||||||
|
return self.has_mamba_layers
|
||||||
|
|||||||
@@ -187,6 +187,7 @@ from vllm.v1.worker.workspace import lock_workspace
|
|||||||
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
AttentionGroup,
|
AttentionGroup,
|
||||||
|
KVBlockZeroer,
|
||||||
add_kv_sharing_layers_to_kv_cache_groups,
|
add_kv_sharing_layers_to_kv_cache_groups,
|
||||||
bind_kv_cache,
|
bind_kv_cache,
|
||||||
prepare_kernel_block_sizes,
|
prepare_kernel_block_sizes,
|
||||||
@@ -918,6 +919,26 @@ class GPUModelRunner(
|
|||||||
decode_threshold=self.reorder_batch_threshold,
|
decode_threshold=self.reorder_batch_threshold,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _init_kv_zero_meta(self) -> None:
|
||||||
|
"""One-time precomputation for _zero_block_ids.
|
||||||
|
|
||||||
|
Delegates to KVBlockZeroer.init_meta with the runner's state.
|
||||||
|
Called from gpu_worker.py outside the CuMem pool context.
|
||||||
|
"""
|
||||||
|
self._kv_block_zeroer = KVBlockZeroer(self.device, self.pin_memory)
|
||||||
|
self._kv_block_zeroer.init_meta(
|
||||||
|
attn_groups_iter=self._kv_cache_spec_attn_group_iterator(),
|
||||||
|
kernel_block_sizes=self._kernel_block_sizes,
|
||||||
|
cache_dtype=self.cache_config.cache_dtype,
|
||||||
|
runner_only_attn_layers=self.runner_only_attn_layers,
|
||||||
|
static_forward_context=(self.compilation_config.static_forward_context),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _zero_block_ids(self, block_ids: list[int]) -> None:
|
||||||
|
"""Zero the KV cache memory for the given block IDs."""
|
||||||
|
if hasattr(self, "_kv_block_zeroer"):
|
||||||
|
self._kv_block_zeroer.zero_block_ids(block_ids)
|
||||||
|
|
||||||
# Note: used for model runner override.
|
# Note: used for model runner override.
|
||||||
def _init_device_properties(self) -> None:
|
def _init_device_properties(self) -> None:
|
||||||
"""Initialize attributes from torch.cuda.get_device_properties"""
|
"""Initialize attributes from torch.cuda.get_device_properties"""
|
||||||
@@ -951,6 +972,11 @@ class GPUModelRunner(
|
|||||||
for req_id in scheduler_output.finished_req_ids:
|
for req_id in scheduler_output.finished_req_ids:
|
||||||
self.input_batch.remove_request(req_id)
|
self.input_batch.remove_request(req_id)
|
||||||
|
|
||||||
|
# Zero GPU memory for freshly allocated cache blocks to prevent
|
||||||
|
# stale NaN/data from corrupting attention or SSM computation.
|
||||||
|
if scheduler_output.new_block_ids_to_zero:
|
||||||
|
self._zero_block_ids(scheduler_output.new_block_ids_to_zero)
|
||||||
|
|
||||||
# Free the cached encoder outputs.
|
# Free the cached encoder outputs.
|
||||||
for mm_hash in scheduler_output.free_encoder_mm_hashes:
|
for mm_hash in scheduler_output.free_encoder_mm_hashes:
|
||||||
self.encoder_cache.pop(mm_hash, None)
|
self.encoder_cache.pop(mm_hash, None)
|
||||||
@@ -6066,6 +6092,7 @@ class GPUModelRunner(
|
|||||||
kernel_block_sizes = prepare_kernel_block_sizes(
|
kernel_block_sizes = prepare_kernel_block_sizes(
|
||||||
kv_cache_config, self.attn_groups
|
kv_cache_config, self.attn_groups
|
||||||
)
|
)
|
||||||
|
self._kernel_block_sizes = kernel_block_sizes
|
||||||
|
|
||||||
# create metadata builders
|
# create metadata builders
|
||||||
self.initialize_metadata_builders(kv_cache_config, kernel_block_sizes)
|
self.initialize_metadata_builders(kv_cache_config, kernel_block_sizes)
|
||||||
|
|||||||
@@ -480,6 +480,14 @@ class Worker(WorkerBase):
|
|||||||
else:
|
else:
|
||||||
self.model_runner.initialize_kv_cache(kv_cache_config)
|
self.model_runner.initialize_kv_cache(kv_cache_config)
|
||||||
|
|
||||||
|
# Build KV-zero metadata outside the CuMem pool so the bookkeeping
|
||||||
|
# GPU tensors (seg_addrs, block-id buffers) use the standard PyTorch
|
||||||
|
# allocator and are not discarded during sleep/wake cycles.
|
||||||
|
if kv_cache_config.needs_kv_cache_zeroing and hasattr(
|
||||||
|
self.model_runner, "_init_kv_zero_meta"
|
||||||
|
):
|
||||||
|
self.model_runner._init_kv_zero_meta()
|
||||||
|
|
||||||
@instrument(span_name="Warmup (GPU)")
|
@instrument(span_name="Warmup (GPU)")
|
||||||
def compile_or_warm_up_model(self) -> float:
|
def compile_or_warm_up_model(self) -> float:
|
||||||
warmup_sizes = []
|
warmup_sizes = []
|
||||||
|
|||||||
@@ -2,7 +2,10 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import math
|
import math
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from collections.abc import Iterable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from itertools import product as iprod
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -12,6 +15,8 @@ from vllm.model_executor.layers.attention import Attention
|
|||||||
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
||||||
from vllm.model_executor.models.utils import extract_layer_index
|
from vllm.model_executor.models.utils import extract_layer_index
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
from vllm.utils.math_utils import largest_power_of_2_divisor
|
||||||
from vllm.utils.mem_utils import MemorySnapshot, format_gib
|
from vllm.utils.mem_utils import MemorySnapshot, format_gib
|
||||||
from vllm.v1.attention.backend import (
|
from vllm.v1.attention.backend import (
|
||||||
AttentionBackend,
|
AttentionBackend,
|
||||||
@@ -21,6 +26,7 @@ from vllm.v1.attention.backend import (
|
|||||||
from vllm.v1.kv_cache_interface import (
|
from vllm.v1.kv_cache_interface import (
|
||||||
AttentionSpec,
|
AttentionSpec,
|
||||||
EncoderOnlyAttentionSpec,
|
EncoderOnlyAttentionSpec,
|
||||||
|
FullAttentionSpec,
|
||||||
KVCacheConfig,
|
KVCacheConfig,
|
||||||
KVCacheGroupSpec,
|
KVCacheGroupSpec,
|
||||||
KVCacheSpec,
|
KVCacheSpec,
|
||||||
@@ -31,6 +37,186 @@ from vllm.v1.kv_cache_interface import (
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _zero_kv_blocks_kernel(
|
||||||
|
seg_addrs_ptr,
|
||||||
|
block_ids_ptr,
|
||||||
|
n_blocks,
|
||||||
|
N_SEGS: tl.constexpr,
|
||||||
|
PAGE_SIZE_EL: tl.constexpr,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""Zero KV cache blocks across all segments in a single launch.
|
||||||
|
|
||||||
|
Each segment is a contiguous region of one block's data. For backends
|
||||||
|
where blocks are outermost (block_dim=0) there is one segment per
|
||||||
|
buffer. For backends where K/V is outermost (block_dim=1) there are
|
||||||
|
two segments per buffer (one for K, one for V).
|
||||||
|
|
||||||
|
seg_addrs_ptr holds absolute byte addresses (int64) for each segment,
|
||||||
|
allowing segments to live in different CUDA allocations.
|
||||||
|
|
||||||
|
Programs are mapped as (block_index, seg_index, chunk_index).
|
||||||
|
"""
|
||||||
|
pid = tl.program_id(0)
|
||||||
|
chunks = PAGE_SIZE_EL // BLOCK_SIZE
|
||||||
|
work_per_block = N_SEGS * chunks
|
||||||
|
block_index = pid // work_per_block
|
||||||
|
if block_index >= n_blocks:
|
||||||
|
return
|
||||||
|
remainder = pid % work_per_block
|
||||||
|
seg_index = remainder // chunks
|
||||||
|
chunk_index = remainder % chunks
|
||||||
|
block_id = tl.load(block_ids_ptr + block_index)
|
||||||
|
seg_addr = tl.load(seg_addrs_ptr + seg_index)
|
||||||
|
ptr = tl.cast(seg_addr, tl.pointer_type(tl.int32))
|
||||||
|
offset = (
|
||||||
|
block_id.to(tl.int64) * PAGE_SIZE_EL + chunk_index.to(tl.int64) * BLOCK_SIZE
|
||||||
|
)
|
||||||
|
cols = tl.arange(0, BLOCK_SIZE).to(tl.int64)
|
||||||
|
tl.store(ptr + offset + cols, tl.zeros([BLOCK_SIZE], dtype=tl.int32))
|
||||||
|
|
||||||
|
|
||||||
|
class KVBlockZeroer:
|
||||||
|
"""Manages efficient zeroing of KV cache blocks via a Triton kernel.
|
||||||
|
|
||||||
|
Call :meth:`init_meta` once after KV caches are allocated to precompute
|
||||||
|
segment addresses, then call :meth:`zero_block_ids` each step to zero
|
||||||
|
newly-allocated blocks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, device: torch.device, pin_memory: bool):
|
||||||
|
self.device = device
|
||||||
|
self.pin_memory = pin_memory
|
||||||
|
self._meta: tuple[torch.Tensor, int, int, int] | None = None
|
||||||
|
self._id_cap: int = 0
|
||||||
|
self._ids_pinned: torch.Tensor | None = None
|
||||||
|
self._ids_gpu: torch.Tensor | None = None
|
||||||
|
|
||||||
|
def init_meta(
|
||||||
|
self,
|
||||||
|
attn_groups_iter: Iterable["AttentionGroup"],
|
||||||
|
kernel_block_sizes: list[int],
|
||||||
|
cache_dtype: str,
|
||||||
|
runner_only_attn_layers: set[str],
|
||||||
|
static_forward_context: dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
"""One-time precomputation for zero_block_ids.
|
||||||
|
|
||||||
|
Builds absolute-address table for the Triton zeroing kernel.
|
||||||
|
Each entry is the absolute byte address of a segment start on the
|
||||||
|
GPU, so segments in different CUDA allocations work correctly.
|
||||||
|
|
||||||
|
Block IDs from the scheduler reference logical blocks whose size
|
||||||
|
may differ from the kernel block size (virtual block splitting).
|
||||||
|
PAGE_SIZE_EL accounts for this ratio so that
|
||||||
|
``block_id * PAGE_SIZE_EL`` lands at the correct offset.
|
||||||
|
|
||||||
|
Only AttentionSpec layers are processed; Mamba layers are skipped.
|
||||||
|
"""
|
||||||
|
seen_ptrs: set[int] = set()
|
||||||
|
seg_addrs: list[int] = []
|
||||||
|
page_size_el: int | None = None
|
||||||
|
|
||||||
|
for group in attn_groups_iter:
|
||||||
|
spec = group.kv_cache_spec
|
||||||
|
if type(spec) is not FullAttentionSpec:
|
||||||
|
continue
|
||||||
|
if group.kv_cache_group_id >= len(kernel_block_sizes):
|
||||||
|
continue
|
||||||
|
kernel_bs = kernel_block_sizes[group.kv_cache_group_id]
|
||||||
|
ratio = spec.block_size // kernel_bs
|
||||||
|
block_dim = group.backend.get_kv_cache_block_dim(
|
||||||
|
kernel_bs,
|
||||||
|
spec.num_kv_heads,
|
||||||
|
spec.head_size,
|
||||||
|
cache_dtype_str=cache_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
for layer_name in group.layer_names:
|
||||||
|
if layer_name in runner_only_attn_layers:
|
||||||
|
continue
|
||||||
|
kv = static_forward_context[layer_name].kv_cache[0]
|
||||||
|
if isinstance(kv, list):
|
||||||
|
continue
|
||||||
|
dp = kv.data_ptr()
|
||||||
|
if dp in seen_ptrs:
|
||||||
|
continue
|
||||||
|
seen_ptrs.add(dp)
|
||||||
|
|
||||||
|
el = kv.element_size()
|
||||||
|
cur_bytes = kv.stride(block_dim) * el
|
||||||
|
assert cur_bytes % 4 == 0
|
||||||
|
kernel_block_el = cur_bytes // 4
|
||||||
|
cur_page_el = kernel_block_el * ratio
|
||||||
|
if page_size_el is None:
|
||||||
|
page_size_el = cur_page_el
|
||||||
|
else:
|
||||||
|
assert page_size_el == cur_page_el, (
|
||||||
|
f"Non-uniform page sizes: {page_size_el} vs {cur_page_el}"
|
||||||
|
)
|
||||||
|
|
||||||
|
block_stride_bytes = cur_bytes
|
||||||
|
outer_dims = [
|
||||||
|
d
|
||||||
|
for d in range(block_dim)
|
||||||
|
if kv.stride(d) * el > block_stride_bytes
|
||||||
|
]
|
||||||
|
outer_strides = [kv.stride(d) * el for d in outer_dims]
|
||||||
|
for outer in iprod(*(range(kv.shape[d]) for d in outer_dims)):
|
||||||
|
off_bytes = sum(i * s for i, s in zip(outer, outer_strides))
|
||||||
|
seg_addrs.append(dp + off_bytes)
|
||||||
|
|
||||||
|
if not seg_addrs or page_size_el is None:
|
||||||
|
self._meta = None
|
||||||
|
return
|
||||||
|
|
||||||
|
blk_size = min(largest_power_of_2_divisor(page_size_el), 1024)
|
||||||
|
self._id_cap = 8192
|
||||||
|
self._ids_pinned = torch.empty(
|
||||||
|
self._id_cap,
|
||||||
|
dtype=torch.int64,
|
||||||
|
pin_memory=self.pin_memory,
|
||||||
|
)
|
||||||
|
self._ids_gpu = torch.empty(self._id_cap, dtype=torch.int64, device=self.device)
|
||||||
|
self._meta = (
|
||||||
|
torch.tensor(seg_addrs, dtype=torch.int64, device=self.device),
|
||||||
|
page_size_el,
|
||||||
|
blk_size,
|
||||||
|
len(seg_addrs),
|
||||||
|
)
|
||||||
|
|
||||||
|
def zero_block_ids(self, block_ids: list[int]) -> None:
|
||||||
|
"""Zero the KV cache memory for the given block IDs."""
|
||||||
|
if not block_ids or self._meta is None:
|
||||||
|
return
|
||||||
|
seg_addrs, page_size_el, blk_size, n_segs = self._meta
|
||||||
|
n_blocks = len(block_ids)
|
||||||
|
if n_blocks > self._id_cap:
|
||||||
|
self._id_cap = n_blocks * 2
|
||||||
|
self._ids_pinned = torch.empty(
|
||||||
|
self._id_cap,
|
||||||
|
dtype=torch.int64,
|
||||||
|
pin_memory=self.pin_memory,
|
||||||
|
)
|
||||||
|
self._ids_gpu = torch.empty(
|
||||||
|
self._id_cap, dtype=torch.int64, device=self.device
|
||||||
|
)
|
||||||
|
assert self._ids_pinned is not None and self._ids_gpu is not None
|
||||||
|
self._ids_pinned[:n_blocks].numpy()[:] = block_ids
|
||||||
|
idx = self._ids_gpu[:n_blocks]
|
||||||
|
idx.copy_(self._ids_pinned[:n_blocks], non_blocking=True)
|
||||||
|
grid = (n_blocks * n_segs * (page_size_el // blk_size),)
|
||||||
|
_zero_kv_blocks_kernel[grid](
|
||||||
|
seg_addrs,
|
||||||
|
idx,
|
||||||
|
n_blocks,
|
||||||
|
N_SEGS=n_segs,
|
||||||
|
PAGE_SIZE_EL=page_size_el,
|
||||||
|
BLOCK_SIZE=blk_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AttentionGroup:
|
class AttentionGroup:
|
||||||
backend: type[AttentionBackend]
|
backend: type[AttentionBackend]
|
||||||
|
|||||||
Reference in New Issue
Block a user