Improve enable chunked_prefill & prefix_caching logic. (#26623)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io> Signed-off-by: wang.yuqi <noooop@126.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -105,8 +105,6 @@ def test_embed_models(
|
|||||||
def test_non_causal_models(
|
def test_non_causal_models(
|
||||||
hf_runner, vllm_runner, example_prompts, model: str, dtype: str
|
hf_runner, vllm_runner, example_prompts, model: str, dtype: str
|
||||||
) -> None:
|
) -> None:
|
||||||
with vllm_runner(
|
with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model:
|
||||||
model, max_model_len=512, dtype=dtype, enable_prefix_caching=True
|
|
||||||
) as vllm_model:
|
|
||||||
cache_config = vllm_model.llm.llm_engine.cache_config
|
cache_config = vllm_model.llm.llm_engine.cache_config
|
||||||
assert not cache_config.enable_prefix_caching
|
assert not cache_config.enable_prefix_caching
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from dataclasses import MISSING, Field, asdict, dataclass, field
|
from dataclasses import MISSING, Field, asdict, dataclass, field
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
@@ -602,6 +602,244 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files):
|
|||||||
assert os.path.exists(config2.tokenizer) and os.path.isdir(config2.tokenizer)
|
assert os.path.exists(config2.tokenizer) and os.path.isdir(config2.tokenizer)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("model_id", "expected_attn_type", "expected_result", "reason"),
|
||||||
|
[
|
||||||
|
# pooling models
|
||||||
|
(
|
||||||
|
"jason9693/Qwen2.5-1.5B-apeach",
|
||||||
|
"decoder",
|
||||||
|
True,
|
||||||
|
"Pooling models with causal attn and last pooling support chunked prefill.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"Qwen/Qwen3-Embedding-0.6B",
|
||||||
|
"decoder",
|
||||||
|
True,
|
||||||
|
"Pooling models with causal attn and last pooling support chunked prefill.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"Qwen/Qwen2.5-Math-PRM-7B",
|
||||||
|
"decoder",
|
||||||
|
False,
|
||||||
|
"Pooling models with step pooling does not support chunked prefill.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"internlm/internlm2-1_8b-reward",
|
||||||
|
"decoder",
|
||||||
|
False,
|
||||||
|
"Pooling models with all pooling does not support chunked prefill.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"BAAI/bge-base-en",
|
||||||
|
"encoder_only",
|
||||||
|
False,
|
||||||
|
"Pooling models with bidirectional attn does not support chunked prefill.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"boltuix/NeuroBERT-NER",
|
||||||
|
"encoder_only",
|
||||||
|
False,
|
||||||
|
"Pooling models with bidirectional attn does not support chunked prefill.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"papluca/xlm-roberta-base-language-detection",
|
||||||
|
"encoder_only",
|
||||||
|
False,
|
||||||
|
"Pooling models with bidirectional attn does not support chunked prefill.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"Alibaba-NLP/gte-Qwen2-1.5B-instruct",
|
||||||
|
"encoder_only",
|
||||||
|
False,
|
||||||
|
"Pooling models with bidirectional attn does not support chunked prefill.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"intfloat/e5-small",
|
||||||
|
"encoder_only",
|
||||||
|
False,
|
||||||
|
"Pooling models with bidirectional attn does not support chunked prefill.",
|
||||||
|
),
|
||||||
|
# multimodal models
|
||||||
|
(
|
||||||
|
"openai/clip-vit-base-patch32",
|
||||||
|
"decoder",
|
||||||
|
True,
|
||||||
|
"Pooling models with causal attn and last pooling support chunked prefill.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"google/siglip-base-patch16-224",
|
||||||
|
"encoder_only",
|
||||||
|
False,
|
||||||
|
"Pooling models with bidirectional attn does not support chunked prefill.",
|
||||||
|
),
|
||||||
|
# generate models
|
||||||
|
(
|
||||||
|
"Qwen/Qwen3-0.6B",
|
||||||
|
"decoder",
|
||||||
|
True,
|
||||||
|
"Generative models support chunked prefill.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"Qwen/Qwen3-Next-80B-A3B-Instruct",
|
||||||
|
"hybrid",
|
||||||
|
True,
|
||||||
|
"Generative models support chunked prefill.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"ibm-granite/granite-4.0-h-small",
|
||||||
|
"hybrid",
|
||||||
|
True,
|
||||||
|
"Generative models support chunked prefill.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"state-spaces/mamba-130m-hf",
|
||||||
|
"attention_free",
|
||||||
|
True,
|
||||||
|
"Generative models support chunked prefill.",
|
||||||
|
),
|
||||||
|
# encoder_decoder models
|
||||||
|
(
|
||||||
|
"openai/whisper-small",
|
||||||
|
"encoder_decoder",
|
||||||
|
False,
|
||||||
|
"Encoder decoder models does not support chunked prefill.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_is_chunked_prefill_supported(
|
||||||
|
model_id: str,
|
||||||
|
expected_attn_type: str,
|
||||||
|
expected_result: bool,
|
||||||
|
reason: str,
|
||||||
|
caplog_vllm,
|
||||||
|
):
|
||||||
|
model_config = ModelConfig(model_id, trust_remote_code=True)
|
||||||
|
assert model_config.attn_type == expected_attn_type
|
||||||
|
with caplog_vllm.at_level(level=logging.DEBUG):
|
||||||
|
assert model_config.is_chunked_prefill_supported == expected_result
|
||||||
|
assert reason in caplog_vllm.text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("model_id", "expected_attn_type", "expected_result", "reason"),
|
||||||
|
[
|
||||||
|
# pooling models
|
||||||
|
(
|
||||||
|
"jason9693/Qwen2.5-1.5B-apeach",
|
||||||
|
"decoder",
|
||||||
|
True,
|
||||||
|
"Pooling models with causal attn and last pooling support prefix caching.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"Qwen/Qwen3-Embedding-0.6B",
|
||||||
|
"decoder",
|
||||||
|
True,
|
||||||
|
"Pooling models with causal attn and last pooling support prefix caching.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"Qwen/Qwen2.5-Math-PRM-7B",
|
||||||
|
"decoder",
|
||||||
|
False,
|
||||||
|
"Pooling models with step pooling does not support prefix caching.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"internlm/internlm2-1_8b-reward",
|
||||||
|
"decoder",
|
||||||
|
False,
|
||||||
|
"Pooling models with all pooling does not support prefix caching.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"BAAI/bge-base-en",
|
||||||
|
"encoder_only",
|
||||||
|
False,
|
||||||
|
"Pooling models with bidirectional attn does not support prefix caching.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"boltuix/NeuroBERT-NER",
|
||||||
|
"encoder_only",
|
||||||
|
False,
|
||||||
|
"Pooling models with bidirectional attn does not support prefix caching.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"papluca/xlm-roberta-base-language-detection",
|
||||||
|
"encoder_only",
|
||||||
|
False,
|
||||||
|
"Pooling models with bidirectional attn does not support prefix caching.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"Alibaba-NLP/gte-Qwen2-1.5B-instruct",
|
||||||
|
"encoder_only",
|
||||||
|
False,
|
||||||
|
"Pooling models with bidirectional attn does not support prefix caching.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"intfloat/e5-small",
|
||||||
|
"encoder_only",
|
||||||
|
False,
|
||||||
|
"Pooling models with bidirectional attn does not support prefix caching.",
|
||||||
|
),
|
||||||
|
# multimodal models
|
||||||
|
(
|
||||||
|
"openai/clip-vit-base-patch32",
|
||||||
|
"decoder",
|
||||||
|
True,
|
||||||
|
"Pooling models with causal attn and last pooling support prefix caching.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"google/siglip-base-patch16-224",
|
||||||
|
"encoder_only",
|
||||||
|
False,
|
||||||
|
"Pooling models with bidirectional attn does not support prefix caching.",
|
||||||
|
),
|
||||||
|
# generate models
|
||||||
|
(
|
||||||
|
"Qwen/Qwen3-0.6B",
|
||||||
|
"decoder",
|
||||||
|
True,
|
||||||
|
"Generative models support prefix caching.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"Qwen/Qwen3-Next-80B-A3B-Instruct",
|
||||||
|
"hybrid",
|
||||||
|
False,
|
||||||
|
"Hybrid models does not support prefix caching since the feature is still experimental.", # noqa: E501
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"ibm-granite/granite-4.0-h-small",
|
||||||
|
"hybrid",
|
||||||
|
False,
|
||||||
|
"Hybrid models does not support prefix caching since the feature is still experimental.", # noqa: E501
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"state-spaces/mamba-130m-hf",
|
||||||
|
"attention_free",
|
||||||
|
False,
|
||||||
|
"Attention free models does not support prefix caching since the feature is still experimental.", # noqa: E501
|
||||||
|
),
|
||||||
|
# encoder_decoder models
|
||||||
|
(
|
||||||
|
"openai/whisper-small",
|
||||||
|
"encoder_decoder",
|
||||||
|
False,
|
||||||
|
"Encoder decoder models does not support prefix caching.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_is_prefix_caching_supported(
|
||||||
|
model_id: str,
|
||||||
|
expected_attn_type: str,
|
||||||
|
expected_result: bool,
|
||||||
|
reason: str,
|
||||||
|
caplog_vllm,
|
||||||
|
):
|
||||||
|
model_config = ModelConfig(model_id, trust_remote_code=True)
|
||||||
|
assert model_config.attn_type == expected_attn_type
|
||||||
|
with caplog_vllm.at_level(level=logging.DEBUG):
|
||||||
|
assert model_config.is_prefix_caching_supported == expected_result
|
||||||
|
assert reason in caplog_vllm.text
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("backend", "custom_ops", "expected"),
|
("backend", "custom_ops", "expected"),
|
||||||
[
|
[
|
||||||
|
|||||||
@@ -107,6 +107,10 @@ _RUNNER_CONVERTS: dict[RunnerType, list[ConvertType]] = {
|
|||||||
"draft": [],
|
"draft": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AttnTypeStr = Literal[
|
||||||
|
"decoder", "encoder", "encoder_only", "encoder_decoder", "attention_free", "hybrid"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@config
|
@config
|
||||||
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
|
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
|
||||||
@@ -1752,6 +1756,111 @@ class ModelConfig:
|
|||||||
logger.info("Using max model len %s", max_model_len)
|
logger.info("Using max model len %s", max_model_len)
|
||||||
return max_model_len
|
return max_model_len
|
||||||
|
|
||||||
|
@property
|
||||||
|
def attn_type(self) -> AttnTypeStr:
|
||||||
|
if self.pooler_config is not None:
|
||||||
|
pooling_type = self._model_info.default_pooling_type.lower()
|
||||||
|
if pooling_type == "cls":
|
||||||
|
return "encoder_only"
|
||||||
|
else:
|
||||||
|
is_causal = getattr(self.hf_config, "is_causal", True)
|
||||||
|
return "encoder_only" if not is_causal else self._model_info.attn_type
|
||||||
|
elif self.is_hybrid:
|
||||||
|
return "hybrid"
|
||||||
|
elif self.is_attention_free:
|
||||||
|
return "attention_free"
|
||||||
|
elif self.is_encoder_decoder:
|
||||||
|
return "encoder_decoder"
|
||||||
|
else:
|
||||||
|
return "decoder"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_chunked_prefill_supported(self) -> bool:
|
||||||
|
attn_type = self.attn_type
|
||||||
|
if self.pooler_config is not None:
|
||||||
|
# for pooling models
|
||||||
|
if attn_type == "encoder_only":
|
||||||
|
logger.debug(
|
||||||
|
"Pooling models with bidirectional attn does not support "
|
||||||
|
"chunked prefill."
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
elif attn_type == "decoder":
|
||||||
|
pooling_type = self.pooler_config.pooling_type.lower()
|
||||||
|
if pooling_type in ["all", "mean", "step", "cls"]:
|
||||||
|
logger.debug(
|
||||||
|
"Pooling models with %s pooling does not "
|
||||||
|
"support chunked prefill.",
|
||||||
|
pooling_type,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
# pooling_type == "last"
|
||||||
|
logger.debug(
|
||||||
|
"Pooling models with causal attn and last pooling support "
|
||||||
|
"chunked prefill."
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
# vllm currently does not have pooling models using hybrid,
|
||||||
|
# attention_free or encoder_decoder attn types.
|
||||||
|
return attn_type != "encoder_decoder"
|
||||||
|
else:
|
||||||
|
if attn_type == "encoder_decoder":
|
||||||
|
logger.debug("Encoder decoder models does not support chunked prefill.")
|
||||||
|
return False
|
||||||
|
logger.debug("Generative models support chunked prefill.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_prefix_caching_supported(self) -> bool:
|
||||||
|
attn_type = self.attn_type
|
||||||
|
if self.pooler_config is not None:
|
||||||
|
# for pooling models
|
||||||
|
if attn_type == "encoder_only":
|
||||||
|
logger.debug(
|
||||||
|
"Pooling models with bidirectional attn does not "
|
||||||
|
"support prefix caching."
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
elif attn_type == "decoder":
|
||||||
|
pooling_type = self.pooler_config.pooling_type.lower()
|
||||||
|
if pooling_type in ["all", "mean", "step", "cls"]:
|
||||||
|
logger.debug(
|
||||||
|
"Pooling models with %s pooling does not "
|
||||||
|
"support prefix caching.",
|
||||||
|
pooling_type,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
# pooling_type == "last"
|
||||||
|
logger.debug(
|
||||||
|
"Pooling models with causal attn and last pooling support "
|
||||||
|
"prefix caching."
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
# vllm currently does not have pooling models using hybrid,
|
||||||
|
# attention_free or encoder_decoder attn types.
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
if attn_type == "hybrid":
|
||||||
|
logger.debug(
|
||||||
|
"Hybrid models does not support prefix caching since the feature "
|
||||||
|
"is still experimental."
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
elif attn_type == "attention_free":
|
||||||
|
logger.debug(
|
||||||
|
"Attention free models does not support prefix caching since the "
|
||||||
|
"feature is still experimental."
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
elif attn_type == "encoder_decoder":
|
||||||
|
logger.debug("Encoder decoder models does not support prefix caching.")
|
||||||
|
return False
|
||||||
|
else: # attn_type == "decoder"
|
||||||
|
logger.debug("Generative models support prefix caching.")
|
||||||
|
return True
|
||||||
|
|
||||||
def is_model_moe(
|
def is_model_moe(
|
||||||
self,
|
self,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any, Literal
|
||||||
|
|
||||||
from pydantic.dataclasses import dataclass
|
from pydantic.dataclasses import dataclass
|
||||||
|
|
||||||
@@ -11,13 +11,15 @@ from vllm.utils.hashing import safe_hash
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
PoolingTypeStr = Literal["LAST", "ALL", "CLS", "STEP", "MEAN"]
|
||||||
|
|
||||||
|
|
||||||
@config
|
@config
|
||||||
@dataclass
|
@dataclass
|
||||||
class PoolerConfig:
|
class PoolerConfig:
|
||||||
"""Controls the behavior of output pooling in pooling models."""
|
"""Controls the behavior of output pooling in pooling models."""
|
||||||
|
|
||||||
pooling_type: str | None = None
|
pooling_type: PoolingTypeStr | None = None
|
||||||
"""
|
"""
|
||||||
The pooling method of the pooling model. This should be a key in
|
The pooling method of the pooling model. This should be a key in
|
||||||
[`vllm.model_executor.layers.pooler.PoolingType`][].
|
[`vllm.model_executor.layers.pooler.PoolingType`][].
|
||||||
|
|||||||
@@ -721,65 +721,27 @@ class VllmConfig:
|
|||||||
"correctness and to realize prefill savings. "
|
"correctness and to realize prefill savings. "
|
||||||
)
|
)
|
||||||
|
|
||||||
disable_chunked_prefill_reasons: list[str] = []
|
if self.model_config and self.model_config.is_encoder_decoder:
|
||||||
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
|
|
||||||
if self.model_config:
|
self.scheduler_config.max_num_encoder_input_tokens = (
|
||||||
if self.model_config.pooler_config:
|
MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
|
||||||
pooling_type = self.model_config.pooler_config.pooling_type
|
|
||||||
if pooling_type is None or pooling_type.lower() != "last":
|
|
||||||
disable_chunked_prefill_reasons.append(
|
|
||||||
'Only "last" pooling supports chunked '
|
|
||||||
"prefill and prefix caching; disabling both."
|
|
||||||
)
|
|
||||||
if not getattr(self.model_config.hf_config, "is_causal", True):
|
|
||||||
disable_chunked_prefill_reasons.append(
|
|
||||||
"Only models using causal attention support chunked "
|
|
||||||
"prefill and prefix caching; disabling both."
|
|
||||||
)
|
|
||||||
elif self.model_config.is_encoder_decoder:
|
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
|
||||||
|
|
||||||
self.scheduler_config.max_num_encoder_input_tokens = (
|
|
||||||
MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
"Encoder-decoder model detected: setting "
|
|
||||||
"`max_num_encoder_input_tokens` to encoder length (%s)",
|
|
||||||
self.scheduler_config.max_num_encoder_input_tokens,
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
self.model_config.architecture == "WhisperForConditionalGeneration"
|
|
||||||
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
|
|
||||||
):
|
|
||||||
logger.warning(
|
|
||||||
"Whisper is known to have issues with "
|
|
||||||
"forked workers. If startup is hanging, "
|
|
||||||
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
|
|
||||||
"to 'spawn'."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Final off-switch for CP/APC:
|
|
||||||
# Disable for (a) collected blockers, (b) encoder–decoder, or
|
|
||||||
# (c) explicit CP=False when APC wasn't requested.
|
|
||||||
# Do NOT disable merely because the resolved CP flag is False.
|
|
||||||
apc_requested = (
|
|
||||||
self.cache_config is not None and self.cache_config.enable_prefix_caching
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
disable_chunked_prefill_reasons
|
|
||||||
or (self.model_config is not None and self.model_config.is_encoder_decoder)
|
|
||||||
or (
|
|
||||||
self.scheduler_config.enable_chunked_prefill is False
|
|
||||||
and not apc_requested
|
|
||||||
)
|
)
|
||||||
):
|
logger.debug(
|
||||||
for reason in disable_chunked_prefill_reasons:
|
"Encoder-decoder model detected: setting "
|
||||||
logger.info(reason)
|
"`max_num_encoder_input_tokens` to encoder length (%s)",
|
||||||
self.scheduler_config.enable_chunked_prefill = False
|
self.scheduler_config.max_num_encoder_input_tokens,
|
||||||
self.scheduler_config.long_prefill_token_threshold = 0
|
)
|
||||||
|
if (
|
||||||
if self.cache_config is not None:
|
self.model_config.architecture == "WhisperForConditionalGeneration"
|
||||||
self.cache_config.enable_prefix_caching = False
|
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
"Whisper is known to have issues with "
|
||||||
|
"forked workers. If startup is hanging, "
|
||||||
|
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
|
||||||
|
"to 'spawn'."
|
||||||
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.kv_events_config is not None
|
self.kv_events_config is not None
|
||||||
|
|||||||
@@ -1349,30 +1349,10 @@ class EngineArgs:
|
|||||||
self.tokenizer = model_config.tokenizer
|
self.tokenizer = model_config.tokenizer
|
||||||
|
|
||||||
self._check_feature_supported(model_config)
|
self._check_feature_supported(model_config)
|
||||||
|
self._set_default_chunked_prefill_and_prefix_caching_args(model_config)
|
||||||
# Set default arguments for V1 Engine.
|
self._set_default_max_num_seqs_and_batched_tokens_args(
|
||||||
self._set_default_args(usage_context, model_config)
|
usage_context, model_config
|
||||||
# Disable chunked prefill and prefix caching for:
|
)
|
||||||
# POWER (ppc64le)/s390x/RISCV CPUs in V1
|
|
||||||
if current_platform.is_cpu() and current_platform.get_cpu_architecture() in (
|
|
||||||
CpuArchEnum.POWERPC,
|
|
||||||
CpuArchEnum.S390X,
|
|
||||||
CpuArchEnum.RISCV,
|
|
||||||
):
|
|
||||||
logger.info(
|
|
||||||
"Chunked prefill is not supported for ARM and POWER, "
|
|
||||||
"S390X and RISC-V CPUs; "
|
|
||||||
"disabling it for V1 backend."
|
|
||||||
)
|
|
||||||
self.enable_chunked_prefill = False
|
|
||||||
logger.info(
|
|
||||||
"Prefix caching is not supported for ARM and POWER, "
|
|
||||||
"S390X and RISC-V CPUs; "
|
|
||||||
"disabling it for V1 backend."
|
|
||||||
)
|
|
||||||
self.enable_prefix_caching = False
|
|
||||||
|
|
||||||
assert self.enable_chunked_prefill is not None
|
|
||||||
|
|
||||||
sliding_window: int | None = None
|
sliding_window: int | None = None
|
||||||
if not is_interleaved(model_config.hf_text_config):
|
if not is_interleaved(model_config.hf_text_config):
|
||||||
@@ -1805,34 +1785,6 @@ class EngineArgs:
|
|||||||
)
|
)
|
||||||
_raise_unsupported_error(feature_name=name)
|
_raise_unsupported_error(feature_name=name)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_chunked_prefill_prefix_caching_defaults(
|
|
||||||
cls,
|
|
||||||
model_config: ModelConfig,
|
|
||||||
) -> tuple[bool, bool]:
|
|
||||||
if model_config.runner_type != "pooling":
|
|
||||||
default_chunked_prefill = True
|
|
||||||
|
|
||||||
# Disable prefix caching default for hybrid models and mamba-only
|
|
||||||
# models since the feature is still experimental.
|
|
||||||
default_prefix_caching = not (
|
|
||||||
model_config.is_hybrid or model_config.is_attention_free
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
assert model_config.pooler_config is not None
|
|
||||||
|
|
||||||
pooling_type = model_config.pooler_config.pooling_type
|
|
||||||
incremental_prefill_supported = (
|
|
||||||
pooling_type is not None
|
|
||||||
and pooling_type.lower() == "last"
|
|
||||||
and getattr(model_config.hf_config, "is_causal", True)
|
|
||||||
)
|
|
||||||
|
|
||||||
default_chunked_prefill = incremental_prefill_supported
|
|
||||||
default_prefix_caching = incremental_prefill_supported
|
|
||||||
|
|
||||||
return default_chunked_prefill, default_prefix_caching
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_batch_defaults(
|
def get_batch_defaults(
|
||||||
cls,
|
cls,
|
||||||
@@ -1916,14 +1868,11 @@ class EngineArgs:
|
|||||||
|
|
||||||
return default_max_num_batched_tokens, default_max_num_seqs
|
return default_max_num_batched_tokens, default_max_num_seqs
|
||||||
|
|
||||||
def _set_default_args(
|
def _set_default_chunked_prefill_and_prefix_caching_args(
|
||||||
self, usage_context: UsageContext, model_config: ModelConfig
|
self, model_config: ModelConfig
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set Default Arguments for V1 Engine."""
|
default_chunked_prefill = model_config.is_chunked_prefill_supported
|
||||||
(
|
default_prefix_caching = model_config.is_prefix_caching_supported
|
||||||
default_chunked_prefill,
|
|
||||||
default_prefix_caching,
|
|
||||||
) = self.get_chunked_prefill_prefix_caching_defaults(model_config)
|
|
||||||
|
|
||||||
if self.prefill_context_parallel_size > 1:
|
if self.prefill_context_parallel_size > 1:
|
||||||
default_chunked_prefill = False
|
default_chunked_prefill = False
|
||||||
@@ -1984,6 +1933,29 @@ class EngineArgs:
|
|||||||
scope="local",
|
scope="local",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Disable chunked prefill and prefix caching for:
|
||||||
|
# POWER (ppc64le)/s390x/RISCV CPUs in V1
|
||||||
|
if current_platform.is_cpu() and current_platform.get_cpu_architecture() in (
|
||||||
|
CpuArchEnum.POWERPC,
|
||||||
|
CpuArchEnum.S390X,
|
||||||
|
CpuArchEnum.RISCV,
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
"Chunked prefill is not supported for ARM and POWER, "
|
||||||
|
"S390X and RISC-V CPUs; "
|
||||||
|
"disabling it for V1 backend."
|
||||||
|
)
|
||||||
|
self.enable_chunked_prefill = False
|
||||||
|
logger.info(
|
||||||
|
"Prefix caching is not supported for ARM and POWER, "
|
||||||
|
"S390X and RISC-V CPUs; "
|
||||||
|
"disabling it for V1 backend."
|
||||||
|
)
|
||||||
|
self.enable_prefix_caching = False
|
||||||
|
|
||||||
|
def _set_default_max_num_seqs_and_batched_tokens_args(
|
||||||
|
self, usage_context: UsageContext, model_config: ModelConfig
|
||||||
|
):
|
||||||
world_size = self.pipeline_parallel_size * self.tensor_parallel_size
|
world_size = self.pipeline_parallel_size * self.tensor_parallel_size
|
||||||
(
|
(
|
||||||
default_max_num_batched_tokens,
|
default_max_num_batched_tokens,
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ from vllm.tasks import PoolingTask
|
|||||||
from vllm.v1.pool.metadata import PoolingMetadata
|
from vllm.v1.pool.metadata import PoolingMetadata
|
||||||
|
|
||||||
from .interfaces import SupportsCrossEncoding, SupportsQuant
|
from .interfaces import SupportsCrossEncoding, SupportsQuant
|
||||||
from .interfaces_base import default_pooling_type
|
from .interfaces_base import attn_type, default_pooling_type
|
||||||
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
|
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
|
||||||
|
|
||||||
|
|
||||||
@@ -432,7 +432,6 @@ class BertModel(nn.Module, SupportsQuant):
|
|||||||
return loaded_params
|
return loaded_params
|
||||||
|
|
||||||
|
|
||||||
@default_pooling_type("ALL")
|
|
||||||
class BertPoolingModel(BertModel):
|
class BertPoolingModel(BertModel):
|
||||||
is_pooling_model = True
|
is_pooling_model = True
|
||||||
|
|
||||||
@@ -864,6 +863,7 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@attn_type("encoder_only")
|
||||||
@default_pooling_type("ALL")
|
@default_pooling_type("ALL")
|
||||||
class BertForTokenClassification(nn.Module):
|
class BertForTokenClassification(nn.Module):
|
||||||
is_pooling_model = True
|
is_pooling_model = True
|
||||||
|
|||||||
@@ -19,10 +19,14 @@ from vllm.utils.func_utils import supports_kw
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.config.model import AttnTypeStr
|
||||||
|
from vllm.config.pooler import PoolingTypeStr
|
||||||
from vllm.model_executor.layers.pooler import Pooler
|
from vllm.model_executor.layers.pooler import Pooler
|
||||||
else:
|
else:
|
||||||
VllmConfig = Any
|
VllmConfig = Any
|
||||||
Pooler = Any
|
Pooler = Any
|
||||||
|
PoolingTypeStr = Any
|
||||||
|
AttnTypeStr = Any
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@@ -165,7 +169,7 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
|
|||||||
MRO of your model class.
|
MRO of your model class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
default_pooling_type: ClassVar[str] = "LAST"
|
default_pooling_type: ClassVar[PoolingTypeStr] = "LAST"
|
||||||
"""
|
"""
|
||||||
Indicates the [vllm.config.pooler.PoolerConfig.pooling_type][]
|
Indicates the [vllm.config.pooler.PoolerConfig.pooling_type][]
|
||||||
to use by default.
|
to use by default.
|
||||||
@@ -175,6 +179,17 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
|
|||||||
decorator to conveniently set this field.
|
decorator to conveniently set this field.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
attn_type: ClassVar[AttnTypeStr] = "decoder"
|
||||||
|
"""
|
||||||
|
Indicates the
|
||||||
|
[vllm.config.model.ModelConfig.attn_type][]
|
||||||
|
to use by default.
|
||||||
|
|
||||||
|
You can use the
|
||||||
|
[vllm.model_executor.models.interfaces_base.attn_type][]
|
||||||
|
decorator to conveniently set this field.
|
||||||
|
"""
|
||||||
|
|
||||||
pooler: Pooler
|
pooler: Pooler
|
||||||
"""The pooler is only called on TP rank 0."""
|
"""The pooler is only called on TP rank 0."""
|
||||||
|
|
||||||
@@ -199,7 +214,7 @@ def is_pooling_model(
|
|||||||
_T = TypeVar("_T", bound=type[nn.Module])
|
_T = TypeVar("_T", bound=type[nn.Module])
|
||||||
|
|
||||||
|
|
||||||
def default_pooling_type(pooling_type: str):
|
def default_pooling_type(pooling_type: PoolingTypeStr):
|
||||||
"""Decorator to set `VllmModelForPooling.default_pooling_type`."""
|
"""Decorator to set `VllmModelForPooling.default_pooling_type`."""
|
||||||
|
|
||||||
def func(model: _T) -> _T:
|
def func(model: _T) -> _T:
|
||||||
@@ -209,5 +224,19 @@ def default_pooling_type(pooling_type: str):
|
|||||||
return func
|
return func
|
||||||
|
|
||||||
|
|
||||||
def get_default_pooling_type(model: type[object] | object) -> str:
|
def get_default_pooling_type(model: type[object] | object) -> PoolingTypeStr:
|
||||||
return getattr(model, "default_pooling_type", "LAST")
|
return getattr(model, "default_pooling_type", "LAST")
|
||||||
|
|
||||||
|
|
||||||
|
def attn_type(attn_type: AttnTypeStr):
|
||||||
|
"""Decorator to set `VllmModelForPooling.attn_type`."""
|
||||||
|
|
||||||
|
def func(model: _T) -> _T:
|
||||||
|
model.attn_type = attn_type # type: ignore
|
||||||
|
return model
|
||||||
|
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
def get_attn_type(model: type[object] | object) -> AttnTypeStr:
|
||||||
|
return getattr(model, "attn_type", "decoder")
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ from vllm.tasks import PoolingTask
|
|||||||
from vllm.v1.pool.metadata import PoolingMetadata
|
from vllm.v1.pool.metadata import PoolingMetadata
|
||||||
|
|
||||||
from .interfaces import SupportsCrossEncoding
|
from .interfaces import SupportsCrossEncoding
|
||||||
from .interfaces_base import default_pooling_type
|
from .interfaces_base import attn_type, default_pooling_type
|
||||||
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
|
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
|
||||||
|
|
||||||
|
|
||||||
@@ -396,6 +396,7 @@ class ModernBertPredictionHead(nn.Module):
|
|||||||
return self.norm(self.act(self.dense(hidden_states)))
|
return self.norm(self.act(self.dense(hidden_states)))
|
||||||
|
|
||||||
|
|
||||||
|
@attn_type("encoder_only")
|
||||||
@default_pooling_type("ALL")
|
@default_pooling_type("ALL")
|
||||||
class ModernBertForTokenClassification(nn.Module):
|
class ModernBertForTokenClassification(nn.Module):
|
||||||
is_pooling_model = True
|
is_pooling_model = True
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from collections.abc import Callable, Set
|
|||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TypeVar
|
from typing import TYPE_CHECKING, Any, TypeVar
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import transformers
|
import transformers
|
||||||
@@ -33,6 +33,14 @@ from vllm.logging_utils import logtime
|
|||||||
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
|
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
|
||||||
from vllm.utils.hashing import safe_hash
|
from vllm.utils.hashing import safe_hash
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.config.model import AttnTypeStr
|
||||||
|
from vllm.config.pooler import PoolingTypeStr
|
||||||
|
else:
|
||||||
|
AttnTypeStr = Any
|
||||||
|
PoolingTypeStr = Any
|
||||||
|
|
||||||
|
|
||||||
from .interfaces import (
|
from .interfaces import (
|
||||||
has_inner_state,
|
has_inner_state,
|
||||||
has_noops,
|
has_noops,
|
||||||
@@ -47,6 +55,7 @@ from .interfaces import (
|
|||||||
supports_transcription,
|
supports_transcription,
|
||||||
)
|
)
|
||||||
from .interfaces_base import (
|
from .interfaces_base import (
|
||||||
|
get_attn_type,
|
||||||
get_default_pooling_type,
|
get_default_pooling_type,
|
||||||
is_pooling_model,
|
is_pooling_model,
|
||||||
is_text_generation_model,
|
is_text_generation_model,
|
||||||
@@ -509,7 +518,8 @@ class _ModelInfo:
|
|||||||
architecture: str
|
architecture: str
|
||||||
is_text_generation_model: bool
|
is_text_generation_model: bool
|
||||||
is_pooling_model: bool
|
is_pooling_model: bool
|
||||||
default_pooling_type: str
|
attn_type: AttnTypeStr
|
||||||
|
default_pooling_type: PoolingTypeStr
|
||||||
supports_cross_encoding: bool
|
supports_cross_encoding: bool
|
||||||
supports_multimodal: bool
|
supports_multimodal: bool
|
||||||
supports_multimodal_raw_input_only: bool
|
supports_multimodal_raw_input_only: bool
|
||||||
@@ -530,6 +540,7 @@ class _ModelInfo:
|
|||||||
is_text_generation_model=is_text_generation_model(model),
|
is_text_generation_model=is_text_generation_model(model),
|
||||||
is_pooling_model=is_pooling_model(model),
|
is_pooling_model=is_pooling_model(model),
|
||||||
default_pooling_type=get_default_pooling_type(model),
|
default_pooling_type=get_default_pooling_type(model),
|
||||||
|
attn_type=get_attn_type(model),
|
||||||
supports_cross_encoding=supports_cross_encoding(model),
|
supports_cross_encoding=supports_cross_encoding(model),
|
||||||
supports_multimodal=supports_multimodal(model),
|
supports_multimodal=supports_multimodal(model),
|
||||||
supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
|
supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
|
||||||
|
|||||||
@@ -119,11 +119,12 @@ class EngineCore:
|
|||||||
# Setup scheduler.
|
# Setup scheduler.
|
||||||
Scheduler = vllm_config.scheduler_config.get_scheduler_cls()
|
Scheduler = vllm_config.scheduler_config.get_scheduler_cls()
|
||||||
|
|
||||||
if len(kv_cache_config.kv_cache_groups) == 0:
|
if len(kv_cache_config.kv_cache_groups) == 0: # noqa: SIM102
|
||||||
# Encoder models without KV cache don't support
|
# Encoder models without KV cache don't support
|
||||||
# chunked prefill. But do SSM models?
|
# chunked prefill. But do SSM models?
|
||||||
logger.info("Disabling chunked prefill for model without KVCache")
|
if vllm_config.scheduler_config.enable_chunked_prefill:
|
||||||
vllm_config.scheduler_config.enable_chunked_prefill = False
|
logger.warning("Disabling chunked prefill for model without KVCache")
|
||||||
|
vllm_config.scheduler_config.enable_chunked_prefill = False
|
||||||
|
|
||||||
scheduler_block_size = (
|
scheduler_block_size = (
|
||||||
vllm_config.cache_config.block_size
|
vllm_config.cache_config.block_size
|
||||||
|
|||||||
Reference in New Issue
Block a user