[Model] Pooling models default to using chunked prefill & prefix caching if supported. (#20930)

Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
wang.yuqi
2025-08-12 00:41:37 +08:00
committed by GitHub
parent 16fb668b61
commit 84cf78acee
31 changed files with 452 additions and 261 deletions

View File

@@ -28,7 +28,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask
from .interfaces import SupportsCrossEncoding, SupportsQuant
from .interfaces import (SupportsCrossEncoding, SupportsQuant,
default_pooling_type)
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
@@ -327,6 +328,7 @@ class BertOutput(nn.Module):
@support_torch_compile
@default_pooling_type("CLS")
class BertModel(nn.Module, SupportsQuant):
is_pooling_model = True
@@ -401,6 +403,7 @@ class BertModel(nn.Module, SupportsQuant):
return loaded_params
@default_pooling_type("ALL")
class BertPoolingModel(BertModel):
is_pooling_model = True
@@ -431,6 +434,7 @@ class BertPoolingModel(BertModel):
return loaded_params
@default_pooling_type("CLS")
class BertEmbeddingModel(nn.Module, SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.
@@ -486,13 +490,8 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
return DispatchPooler({
"encode":
Pooler.for_encode(pooler_config),
"embed":
Pooler.for_embed(
pooler_config,
default_pooling_type=PoolingType.CLS,
),
"encode": Pooler.for_encode(pooler_config),
"embed": Pooler.for_embed(pooler_config),
})
@@ -541,6 +540,7 @@ def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
return token_type_ids
@default_pooling_type("CLS")
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.