Improve enable chunked_prefill & prefix_caching logic. (#26623)

Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
wang.yuqi
2025-11-28 14:05:48 +08:00
committed by GitHub
parent 37b15e97e8
commit f4b76056ee
11 changed files with 456 additions and 133 deletions

View File

@@ -19,10 +19,14 @@ from vllm.utils.func_utils import supports_kw
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config.model import AttnTypeStr
from vllm.config.pooler import PoolingTypeStr
from vllm.model_executor.layers.pooler import Pooler
else:
VllmConfig = Any
Pooler = Any
PoolingTypeStr = Any
AttnTypeStr = Any
logger = init_logger(__name__)
@@ -165,7 +169,7 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
MRO of your model class.
"""
default_pooling_type: ClassVar[str] = "LAST"
default_pooling_type: ClassVar[PoolingTypeStr] = "LAST"
"""
Indicates the [vllm.config.pooler.PoolerConfig.pooling_type][]
to use by default.
@@ -175,6 +179,17 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
decorator to conveniently set this field.
"""
attn_type: ClassVar[AttnTypeStr] = "decoder"
"""
Indicates the
[vllm.config.model.ModelConfig.attn_type][]
to use by default.
You can use the
[vllm.model_executor.models.interfaces_base.attn_type][]
decorator to conveniently set this field.
"""
pooler: Pooler
"""The pooler is only called on TP rank 0."""
@@ -199,7 +214,7 @@ def is_pooling_model(
_T = TypeVar("_T", bound=type[nn.Module])
def default_pooling_type(pooling_type: str):
def default_pooling_type(pooling_type: PoolingTypeStr):
"""Decorator to set `VllmModelForPooling.default_pooling_type`."""
def func(model: _T) -> _T:
@@ -209,5 +224,19 @@ def default_pooling_type(pooling_type: str):
return func
def get_default_pooling_type(model: type[object] | object) -> str:
def get_default_pooling_type(model: type[object] | object) -> PoolingTypeStr:
return getattr(model, "default_pooling_type", "LAST")
def attn_type(attn_type: AttnTypeStr):
"""Decorator to set `VllmModelForPooling.attn_type`."""
def func(model: _T) -> _T:
model.attn_type = attn_type # type: ignore
return model
return func
def get_attn_type(model: type[object] | object) -> AttnTypeStr:
return getattr(model, "attn_type", "decoder")