Improve enable chunked_prefill & prefix_caching logic. (#26623)

Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Authored by wang.yuqi on 2025-11-28 14:05:48 +08:00, committed by GitHub
parent 37b15e97e8
commit f4b76056ee
11 changed files with 456 additions and 133 deletions


@@ -105,8 +105,6 @@ def test_embed_models(
 def test_non_causal_models(
     hf_runner, vllm_runner, example_prompts, model: str, dtype: str
 ) -> None:
-    with vllm_runner(
-        model, max_model_len=512, dtype=dtype, enable_prefix_caching=True
-    ) as vllm_model:
+    with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model:
         cache_config = vllm_model.llm.llm_engine.cache_config
         assert not cache_config.enable_prefix_caching


@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import os import os
from dataclasses import MISSING, Field, asdict, dataclass, field from dataclasses import MISSING, Field, asdict, dataclass, field
from unittest.mock import patch from unittest.mock import patch
@@ -602,6 +602,244 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files):
assert os.path.exists(config2.tokenizer) and os.path.isdir(config2.tokenizer) assert os.path.exists(config2.tokenizer) and os.path.isdir(config2.tokenizer)
@pytest.mark.parametrize(
("model_id", "expected_attn_type", "expected_result", "reason"),
[
# pooling models
(
"jason9693/Qwen2.5-1.5B-apeach",
"decoder",
True,
"Pooling models with causal attn and last pooling support chunked prefill.",
),
(
"Qwen/Qwen3-Embedding-0.6B",
"decoder",
True,
"Pooling models with causal attn and last pooling support chunked prefill.",
),
(
"Qwen/Qwen2.5-Math-PRM-7B",
"decoder",
False,
"Pooling models with step pooling does not support chunked prefill.",
),
(
"internlm/internlm2-1_8b-reward",
"decoder",
False,
"Pooling models with all pooling does not support chunked prefill.",
),
(
"BAAI/bge-base-en",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
),
(
"boltuix/NeuroBERT-NER",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
),
(
"papluca/xlm-roberta-base-language-detection",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
),
(
"Alibaba-NLP/gte-Qwen2-1.5B-instruct",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
),
(
"intfloat/e5-small",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
),
# multimodal models
(
"openai/clip-vit-base-patch32",
"decoder",
True,
"Pooling models with causal attn and last pooling support chunked prefill.",
),
(
"google/siglip-base-patch16-224",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
),
# generate models
(
"Qwen/Qwen3-0.6B",
"decoder",
True,
"Generative models support chunked prefill.",
),
(
"Qwen/Qwen3-Next-80B-A3B-Instruct",
"hybrid",
True,
"Generative models support chunked prefill.",
),
(
"ibm-granite/granite-4.0-h-small",
"hybrid",
True,
"Generative models support chunked prefill.",
),
(
"state-spaces/mamba-130m-hf",
"attention_free",
True,
"Generative models support chunked prefill.",
),
# encoder_decoder models
(
"openai/whisper-small",
"encoder_decoder",
False,
"Encoder decoder models does not support chunked prefill.",
),
],
)
def test_is_chunked_prefill_supported(
model_id: str,
expected_attn_type: str,
expected_result: bool,
reason: str,
caplog_vllm,
):
model_config = ModelConfig(model_id, trust_remote_code=True)
assert model_config.attn_type == expected_attn_type
with caplog_vllm.at_level(level=logging.DEBUG):
assert model_config.is_chunked_prefill_supported == expected_result
assert reason in caplog_vllm.text
@pytest.mark.parametrize(
("model_id", "expected_attn_type", "expected_result", "reason"),
[
# pooling models
(
"jason9693/Qwen2.5-1.5B-apeach",
"decoder",
True,
"Pooling models with causal attn and last pooling support prefix caching.",
),
(
"Qwen/Qwen3-Embedding-0.6B",
"decoder",
True,
"Pooling models with causal attn and last pooling support prefix caching.",
),
(
"Qwen/Qwen2.5-Math-PRM-7B",
"decoder",
False,
"Pooling models with step pooling does not support prefix caching.",
),
(
"internlm/internlm2-1_8b-reward",
"decoder",
False,
"Pooling models with all pooling does not support prefix caching.",
),
(
"BAAI/bge-base-en",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
),
(
"boltuix/NeuroBERT-NER",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
),
(
"papluca/xlm-roberta-base-language-detection",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
),
(
"Alibaba-NLP/gte-Qwen2-1.5B-instruct",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
),
(
"intfloat/e5-small",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
),
# multimodal models
(
"openai/clip-vit-base-patch32",
"decoder",
True,
"Pooling models with causal attn and last pooling support prefix caching.",
),
(
"google/siglip-base-patch16-224",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
),
# generate models
(
"Qwen/Qwen3-0.6B",
"decoder",
True,
"Generative models support prefix caching.",
),
(
"Qwen/Qwen3-Next-80B-A3B-Instruct",
"hybrid",
False,
"Hybrid models does not support prefix caching since the feature is still experimental.", # noqa: E501
),
(
"ibm-granite/granite-4.0-h-small",
"hybrid",
False,
"Hybrid models does not support prefix caching since the feature is still experimental.", # noqa: E501
),
(
"state-spaces/mamba-130m-hf",
"attention_free",
False,
"Attention free models does not support prefix caching since the feature is still experimental.", # noqa: E501
),
# encoder_decoder models
(
"openai/whisper-small",
"encoder_decoder",
False,
"Encoder decoder models does not support prefix caching.",
),
],
)
def test_is_prefix_caching_supported(
model_id: str,
expected_attn_type: str,
expected_result: bool,
reason: str,
caplog_vllm,
):
model_config = ModelConfig(model_id, trust_remote_code=True)
assert model_config.attn_type == expected_attn_type
with caplog_vllm.at_level(level=logging.DEBUG):
assert model_config.is_prefix_caching_supported == expected_result
assert reason in caplog_vllm.text
@pytest.mark.parametrize(
    ("backend", "custom_ops", "expected"),
    [

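As a quick illustration of what these new tests exercise, the two properties can be queried directly on a ModelConfig. A minimal sketch (assumes the model configs can be fetched from the Hugging Face Hub; expected values are taken from the parametrizations above):

import logging

from vllm.config import ModelConfig

logging.basicConfig(level=logging.DEBUG)

# Causal decoder with "last" pooling: both features are supported.
cfg = ModelConfig("Qwen/Qwen3-Embedding-0.6B")
assert cfg.attn_type == "decoder"
assert cfg.is_chunked_prefill_supported
assert cfg.is_prefix_caching_supported

# Bidirectional encoder: neither feature is supported.
cfg = ModelConfig("BAAI/bge-base-en")
assert cfg.attn_type == "encoder_only"
assert not cfg.is_chunked_prefill_supported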

@@ -107,6 +107,10 @@ _RUNNER_CONVERTS: dict[RunnerType, list[ConvertType]] = {
     "draft": [],
 }

+AttnTypeStr = Literal[
+    "decoder", "encoder", "encoder_only", "encoder_decoder", "attention_free", "hybrid"
+]

 @config
 @dataclass(config=ConfigDict(arbitrary_types_allowed=True))
@@ -1752,6 +1756,111 @@ class ModelConfig:
         logger.info("Using max model len %s", max_model_len)
         return max_model_len
@property
def attn_type(self) -> AttnTypeStr:
if self.pooler_config is not None:
pooling_type = self._model_info.default_pooling_type.lower()
if pooling_type == "cls":
return "encoder_only"
else:
is_causal = getattr(self.hf_config, "is_causal", True)
return "encoder_only" if not is_causal else self._model_info.attn_type
elif self.is_hybrid:
return "hybrid"
elif self.is_attention_free:
return "attention_free"
elif self.is_encoder_decoder:
return "encoder_decoder"
else:
return "decoder"
@property
def is_chunked_prefill_supported(self) -> bool:
attn_type = self.attn_type
if self.pooler_config is not None:
# for pooling models
if attn_type == "encoder_only":
logger.debug(
"Pooling models with bidirectional attn does not support "
"chunked prefill."
)
return False
elif attn_type == "decoder":
pooling_type = self.pooler_config.pooling_type.lower()
if pooling_type in ["all", "mean", "step", "cls"]:
logger.debug(
"Pooling models with %s pooling does not "
"support chunked prefill.",
pooling_type,
)
return False
else:
# pooling_type == "last"
logger.debug(
"Pooling models with causal attn and last pooling support "
"chunked prefill."
)
return True
# vllm currently does not have pooling models using hybrid,
# attention_free or encoder_decoder attn types.
return attn_type != "encoder_decoder"
else:
if attn_type == "encoder_decoder":
logger.debug("Encoder decoder models does not support chunked prefill.")
return False
logger.debug("Generative models support chunked prefill.")
return True
@property
def is_prefix_caching_supported(self) -> bool:
attn_type = self.attn_type
if self.pooler_config is not None:
# for pooling models
if attn_type == "encoder_only":
logger.debug(
"Pooling models with bidirectional attn does not "
"support prefix caching."
)
return False
elif attn_type == "decoder":
pooling_type = self.pooler_config.pooling_type.lower()
if pooling_type in ["all", "mean", "step", "cls"]:
logger.debug(
"Pooling models with %s pooling does not "
"support prefix caching.",
pooling_type,
)
return False
else:
# pooling_type == "last"
logger.debug(
"Pooling models with causal attn and last pooling support "
"prefix caching."
)
return True
# vllm currently does not have pooling models using hybrid,
# attention_free or encoder_decoder attn types.
return False
else:
if attn_type == "hybrid":
logger.debug(
"Hybrid models does not support prefix caching since the feature "
"is still experimental."
)
return False
elif attn_type == "attention_free":
logger.debug(
"Attention free models does not support prefix caching since the "
"feature is still experimental."
)
return False
elif attn_type == "encoder_decoder":
logger.debug("Encoder decoder models does not support prefix caching.")
return False
else: # attn_type == "decoder"
logger.debug("Generative models support prefix caching.")
return True
    def is_model_moe(
        self,
    ) -> bool:

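For pooling models, the outcome of both properties hinges on the combination of attention type and pooling type. A condensed sketch of that decision logic (the helper name is hypothetical; the real code lives in the two properties above and also emits the debug logs the tests assert on):

def supports_incremental_prefill(attn_type: str, pooling_type: str | None) -> bool:
    # Mirrors is_chunked_prefill_supported for pooling models: only causal
    # attention combined with "last" pooling can be prefilled in chunks.
    if attn_type == "encoder_only":
        return False
    if attn_type == "decoder":
        return pooling_type is not None and pooling_type.lower() == "last"
    # No pooling models currently use hybrid/attention_free/encoder_decoder.
    return attn_type != "encoder_decoder"

assert supports_incremental_prefill("decoder", "LAST") is True
assert supports_incremental_prefill("decoder", "MEAN") is False
assert supports_incremental_prefill("encoder_only", "CLS") is False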

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Any
+from typing import Any, Literal

 from pydantic.dataclasses import dataclass
@@ -11,13 +11,15 @@ from vllm.utils.hashing import safe_hash
 logger = init_logger(__name__)

+PoolingTypeStr = Literal["LAST", "ALL", "CLS", "STEP", "MEAN"]

 @config
 @dataclass
 class PoolerConfig:
     """Controls the behavior of output pooling in pooling models."""

-    pooling_type: str | None = None
+    pooling_type: PoolingTypeStr | None = None
     """
     The pooling method of the pooling model. This should be a key in
     [`vllm.model_executor.layers.pooler.PoolingType`][].


@@ -721,65 +721,27 @@ class VllmConfig:
                 "correctness and to realize prefill savings. "
             )

-        disable_chunked_prefill_reasons: list[str] = []
-
-        if self.model_config:
-            if self.model_config.pooler_config:
-                pooling_type = self.model_config.pooler_config.pooling_type
-                if pooling_type is None or pooling_type.lower() != "last":
-                    disable_chunked_prefill_reasons.append(
-                        'Only "last" pooling supports chunked '
-                        "prefill and prefix caching; disabling both."
-                    )
-                if not getattr(self.model_config.hf_config, "is_causal", True):
-                    disable_chunked_prefill_reasons.append(
-                        "Only models using causal attention support chunked "
-                        "prefill and prefix caching; disabling both."
-                    )
-            elif self.model_config.is_encoder_decoder:
-                from vllm.multimodal import MULTIMODAL_REGISTRY
-
-                self.scheduler_config.max_num_encoder_input_tokens = (
-                    MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
-                )
-                logger.debug(
-                    "Encoder-decoder model detected: setting "
-                    "`max_num_encoder_input_tokens` to encoder length (%s)",
-                    self.scheduler_config.max_num_encoder_input_tokens,
-                )
-                if (
-                    self.model_config.architecture == "WhisperForConditionalGeneration"
-                    and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
-                ):
-                    logger.warning(
-                        "Whisper is known to have issues with "
-                        "forked workers. If startup is hanging, "
-                        "try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
-                        "to 'spawn'."
-                    )
-
-        # Final off-switch for CP/APC:
-        # Disable for (a) collected blockers, (b) encoder-decoder, or
-        # (c) explicit CP=False when APC wasn't requested.
-        # Do NOT disable merely because the resolved CP flag is False.
-        apc_requested = (
-            self.cache_config is not None and self.cache_config.enable_prefix_caching
-        )
-        if (
-            disable_chunked_prefill_reasons
-            or (self.model_config is not None and self.model_config.is_encoder_decoder)
-            or (
-                self.scheduler_config.enable_chunked_prefill is False
-                and not apc_requested
-            )
-        ):
-            for reason in disable_chunked_prefill_reasons:
-                logger.info(reason)
-            self.scheduler_config.enable_chunked_prefill = False
-            self.scheduler_config.long_prefill_token_threshold = 0
-
-            if self.cache_config is not None:
-                self.cache_config.enable_prefix_caching = False
+        if self.model_config and self.model_config.is_encoder_decoder:
+            from vllm.multimodal import MULTIMODAL_REGISTRY
+
+            self.scheduler_config.max_num_encoder_input_tokens = (
+                MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
+            )
+            logger.debug(
+                "Encoder-decoder model detected: setting "
+                "`max_num_encoder_input_tokens` to encoder length (%s)",
+                self.scheduler_config.max_num_encoder_input_tokens,
+            )
+            if (
+                self.model_config.architecture == "WhisperForConditionalGeneration"
+                and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
+            ):
+                logger.warning(
+                    "Whisper is known to have issues with "
+                    "forked workers. If startup is hanging, "
+                    "try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
+                    "to 'spawn'."
+                )

         if (
             self.kv_events_config is not None
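The Whisper warning retained above points at a worker-process workaround. One way to apply it, as a hedged sketch (the environment variable must be set before the engine spawns workers; whether the model loads also depends on your local setup):

import os

# Must be set before engine/worker processes are created.
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

from vllm import LLM

llm = LLM(model="openai/whisper-small")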

@@ -1349,30 +1349,10 @@ class EngineArgs:
         self.tokenizer = model_config.tokenizer

         self._check_feature_supported(model_config)
-
-        # Set default arguments for V1 Engine.
-        self._set_default_args(usage_context, model_config)
-        # Disable chunked prefill and prefix caching for:
-        # POWER (ppc64le)/s390x/RISCV CPUs in V1
-        if current_platform.is_cpu() and current_platform.get_cpu_architecture() in (
-            CpuArchEnum.POWERPC,
-            CpuArchEnum.S390X,
-            CpuArchEnum.RISCV,
-        ):
-            logger.info(
-                "Chunked prefill is not supported for ARM and POWER, "
-                "S390X and RISC-V CPUs; "
-                "disabling it for V1 backend."
-            )
-            self.enable_chunked_prefill = False
-            logger.info(
-                "Prefix caching is not supported for ARM and POWER, "
-                "S390X and RISC-V CPUs; "
-                "disabling it for V1 backend."
-            )
-            self.enable_prefix_caching = False
-        assert self.enable_chunked_prefill is not None
+        self._set_default_chunked_prefill_and_prefix_caching_args(model_config)
+        self._set_default_max_num_seqs_and_batched_tokens_args(
+            usage_context, model_config
+        )

         sliding_window: int | None = None
         if not is_interleaved(model_config.hf_text_config):
@@ -1805,34 +1785,6 @@
         )
         _raise_unsupported_error(feature_name=name)

-    @classmethod
-    def get_chunked_prefill_prefix_caching_defaults(
-        cls,
-        model_config: ModelConfig,
-    ) -> tuple[bool, bool]:
-        if model_config.runner_type != "pooling":
-            default_chunked_prefill = True
-
-            # Disable prefix caching default for hybrid models and mamba-only
-            # models since the feature is still experimental.
-            default_prefix_caching = not (
-                model_config.is_hybrid or model_config.is_attention_free
-            )
-        else:
-            assert model_config.pooler_config is not None
-            pooling_type = model_config.pooler_config.pooling_type
-
-            incremental_prefill_supported = (
-                pooling_type is not None
-                and pooling_type.lower() == "last"
-                and getattr(model_config.hf_config, "is_causal", True)
-            )
-
-            default_chunked_prefill = incremental_prefill_supported
-            default_prefix_caching = incremental_prefill_supported
-
-        return default_chunked_prefill, default_prefix_caching
-
     @classmethod
     def get_batch_defaults(
         cls,
@@ -1916,14 +1868,11 @@
         return default_max_num_batched_tokens, default_max_num_seqs

-    def _set_default_args(
-        self, usage_context: UsageContext, model_config: ModelConfig
-    ) -> None:
-        """Set Default Arguments for V1 Engine."""
-        (
-            default_chunked_prefill,
-            default_prefix_caching,
-        ) = self.get_chunked_prefill_prefix_caching_defaults(model_config)
+    def _set_default_chunked_prefill_and_prefix_caching_args(
+        self, model_config: ModelConfig
+    ) -> None:
+        default_chunked_prefill = model_config.is_chunked_prefill_supported
+        default_prefix_caching = model_config.is_prefix_caching_supported

         if self.prefill_context_parallel_size > 1:
             default_chunked_prefill = False
@@ -1984,6 +1933,29 @@
                scope="local",
            )
# Disable chunked prefill and prefix caching for:
# POWER (ppc64le)/s390x/RISCV CPUs in V1
if current_platform.is_cpu() and current_platform.get_cpu_architecture() in (
CpuArchEnum.POWERPC,
CpuArchEnum.S390X,
CpuArchEnum.RISCV,
):
logger.info(
"Chunked prefill is not supported for ARM and POWER, "
"S390X and RISC-V CPUs; "
"disabling it for V1 backend."
)
self.enable_chunked_prefill = False
logger.info(
"Prefix caching is not supported for ARM and POWER, "
"S390X and RISC-V CPUs; "
"disabling it for V1 backend."
)
self.enable_prefix_caching = False
def _set_default_max_num_seqs_and_batched_tokens_args(
self, usage_context: UsageContext, model_config: ModelConfig
):
        world_size = self.pipeline_parallel_size * self.tensor_parallel_size
        (
            default_max_num_batched_tokens,

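These model-derived values are only defaults; explicit user settings still take precedence via the usual EngineArgs flag handling. A brief sketch of both modes (argument names are the real EngineArgs fields):

from vllm.engine.arg_utils import EngineArgs

# Leave both flags unset: defaults come from
# ModelConfig.is_chunked_prefill_supported / is_prefix_caching_supported.
args = EngineArgs(model="Qwen/Qwen3-0.6B")

# Or pin a value explicitly, overriding the model-derived default.
args = EngineArgs(model="Qwen/Qwen3-0.6B", enable_prefix_caching=False)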

@@ -32,7 +32,7 @@ from vllm.tasks import PoolingTask
 from vllm.v1.pool.metadata import PoolingMetadata

 from .interfaces import SupportsCrossEncoding, SupportsQuant
-from .interfaces_base import default_pooling_type
+from .interfaces_base import attn_type, default_pooling_type
 from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
@@ -432,7 +432,6 @@ class BertModel(nn.Module, SupportsQuant):
         return loaded_params

-@default_pooling_type("ALL")
 class BertPoolingModel(BertModel):
     is_pooling_model = True
@@ -864,6 +863,7 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
     )

+@attn_type("encoder_only")
 @default_pooling_type("ALL")
 class BertForTokenClassification(nn.Module):
     is_pooling_model = True


@@ -19,10 +19,14 @@ from vllm.utils.func_utils import supports_kw
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
+    from vllm.config.model import AttnTypeStr
+    from vllm.config.pooler import PoolingTypeStr
     from vllm.model_executor.layers.pooler import Pooler
 else:
     VllmConfig = Any
     Pooler = Any
+    PoolingTypeStr = Any
+    AttnTypeStr = Any

 logger = init_logger(__name__)
@@ -165,7 +169,7 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
     MRO of your model class.
     """

-    default_pooling_type: ClassVar[str] = "LAST"
+    default_pooling_type: ClassVar[PoolingTypeStr] = "LAST"
     """
     Indicates the [vllm.config.pooler.PoolerConfig.pooling_type][]
     to use by default.
@@ -175,6 +179,17 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
     decorator to conveniently set this field.
     """

+    attn_type: ClassVar[AttnTypeStr] = "decoder"
+    """
+    Indicates the
+    [vllm.config.model.ModelConfig.attn_type][]
+    to use by default.
+
+    You can use the
+    [vllm.model_executor.models.interfaces_base.attn_type][]
+    decorator to conveniently set this field.
+    """
+
     pooler: Pooler
     """The pooler is only called on TP rank 0."""
@@ -199,7 +214,7 @@ def is_pooling_model(
 _T = TypeVar("_T", bound=type[nn.Module])

-def default_pooling_type(pooling_type: str):
+def default_pooling_type(pooling_type: PoolingTypeStr):
     """Decorator to set `VllmModelForPooling.default_pooling_type`."""

     def func(model: _T) -> _T:
@@ -209,5 +224,19 @@ def default_pooling_type(pooling_type: str):
     return func

-def get_default_pooling_type(model: type[object] | object) -> str:
+def get_default_pooling_type(model: type[object] | object) -> PoolingTypeStr:
     return getattr(model, "default_pooling_type", "LAST")

+def attn_type(attn_type: AttnTypeStr):
+    """Decorator to set `VllmModelForPooling.attn_type`."""
+
+    def func(model: _T) -> _T:
+        model.attn_type = attn_type  # type: ignore
+        return model
+
+    return func
+
+def get_attn_type(model: type[object] | object) -> AttnTypeStr:
+    return getattr(model, "attn_type", "decoder")

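Model implementations opt in declaratively with these decorators. A small usage sketch (the class name is hypothetical; it mirrors the Bert/ModernBert changes in this commit):

import torch.nn as nn

from vllm.model_executor.models.interfaces_base import attn_type, default_pooling_type

@attn_type("encoder_only")
@default_pooling_type("ALL")
class MyTokenClassifier(nn.Module):  # hypothetical example class
    is_pooling_model = True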

@@ -28,7 +28,7 @@ from vllm.tasks import PoolingTask
 from vllm.v1.pool.metadata import PoolingMetadata

 from .interfaces import SupportsCrossEncoding
-from .interfaces_base import default_pooling_type
+from .interfaces_base import attn_type, default_pooling_type
 from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
@@ -396,6 +396,7 @@ class ModernBertPredictionHead(nn.Module):
         return self.norm(self.act(self.dense(hidden_states)))

+@attn_type("encoder_only")
 @default_pooling_type("ALL")
 class ModernBertForTokenClassification(nn.Module):
     is_pooling_model = True


@@ -17,7 +17,7 @@ from collections.abc import Callable, Set
 from dataclasses import asdict, dataclass, field
 from functools import lru_cache
 from pathlib import Path
-from typing import TypeVar
+from typing import TYPE_CHECKING, Any, TypeVar

 import torch.nn as nn
 import transformers
@@ -33,6 +33,14 @@ from vllm.logging_utils import logtime
 from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
 from vllm.utils.hashing import safe_hash

+if TYPE_CHECKING:
+    from vllm.config.model import AttnTypeStr
+    from vllm.config.pooler import PoolingTypeStr
+else:
+    AttnTypeStr = Any
+    PoolingTypeStr = Any

 from .interfaces import (
     has_inner_state,
     has_noops,
@@ -47,6 +55,7 @@ from .interfaces import (
     supports_transcription,
 )
 from .interfaces_base import (
+    get_attn_type,
     get_default_pooling_type,
     is_pooling_model,
     is_text_generation_model,
@@ -509,7 +518,8 @@ class _ModelInfo:
     architecture: str
     is_text_generation_model: bool
     is_pooling_model: bool
-    default_pooling_type: str
+    attn_type: AttnTypeStr
+    default_pooling_type: PoolingTypeStr
     supports_cross_encoding: bool
     supports_multimodal: bool
     supports_multimodal_raw_input_only: bool
@@ -530,6 +540,7 @@ class _ModelInfo:
             is_text_generation_model=is_text_generation_model(model),
             is_pooling_model=is_pooling_model(model),
             default_pooling_type=get_default_pooling_type(model),
+            attn_type=get_attn_type(model),
             supports_cross_encoding=supports_cross_encoding(model),
             supports_multimodal=supports_multimodal(model),
             supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(

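At inspection time the registry snapshots these class attributes into _ModelInfo, so ModelConfig can read attn_type without instantiating the model. A rough illustration of the fallback behavior, grounded in the get_attn_type shown above (the dummy class is hypothetical):

from vllm.model_executor.models.interfaces_base import get_attn_type

class _Dummy:  # hypothetical stand-in for a registered model class
    attn_type = "encoder_only"

assert get_attn_type(_Dummy) == "encoder_only"
assert get_attn_type(object) == "decoder"  # falls back to the default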

@@ -119,11 +119,12 @@ class EngineCore:
         # Setup scheduler.
         Scheduler = vllm_config.scheduler_config.get_scheduler_cls()
-        if len(kv_cache_config.kv_cache_groups) == 0:
+        if len(kv_cache_config.kv_cache_groups) == 0:  # noqa: SIM102
             # Encoder models without KV cache don't support
             # chunked prefill. But do SSM models?
-            logger.info("Disabling chunked prefill for model without KVCache")
-            vllm_config.scheduler_config.enable_chunked_prefill = False
+            if vllm_config.scheduler_config.enable_chunked_prefill:
+                logger.warning("Disabling chunked prefill for model without KVCache")
+                vllm_config.scheduler_config.enable_chunked_prefill = False

         scheduler_block_size = (
             vllm_config.cache_config.block_size