[Model] Pooling models default to using chunked prefill & prefix caching if supported. (#20930)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
@@ -28,7 +28,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tasks import PoolingTask
|
||||
|
||||
from .interfaces import SupportsCrossEncoding, SupportsQuant
|
||||
from .interfaces import (SupportsCrossEncoding, SupportsQuant,
|
||||
default_pooling_type)
|
||||
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
|
||||
|
||||
|
||||
@@ -327,6 +328,7 @@ class BertOutput(nn.Module):
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
@default_pooling_type("CLS")
|
||||
class BertModel(nn.Module, SupportsQuant):
|
||||
|
||||
is_pooling_model = True
|
||||
@@ -401,6 +403,7 @@ class BertModel(nn.Module, SupportsQuant):
|
||||
return loaded_params
|
||||
|
||||
|
||||
@default_pooling_type("ALL")
|
||||
class BertPoolingModel(BertModel):
|
||||
|
||||
is_pooling_model = True
|
||||
@@ -431,6 +434,7 @@ class BertPoolingModel(BertModel):
|
||||
return loaded_params
|
||||
|
||||
|
||||
@default_pooling_type("CLS")
|
||||
class BertEmbeddingModel(nn.Module, SupportsQuant):
|
||||
"""A model that uses Bert to provide embedding functionalities.
|
||||
|
||||
@@ -486,13 +490,8 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
|
||||
|
||||
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
|
||||
return DispatchPooler({
|
||||
"encode":
|
||||
Pooler.for_encode(pooler_config),
|
||||
"embed":
|
||||
Pooler.for_embed(
|
||||
pooler_config,
|
||||
default_pooling_type=PoolingType.CLS,
|
||||
),
|
||||
"encode": Pooler.for_encode(pooler_config),
|
||||
"embed": Pooler.for_embed(pooler_config),
|
||||
})
|
||||
|
||||
|
||||
@@ -541,6 +540,7 @@ def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return token_type_ids
|
||||
|
||||
|
||||
@default_pooling_type("CLS")
|
||||
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
|
||||
SupportsQuant):
|
||||
"""A model that uses Bert to provide embedding functionalities.
|
||||
|
||||
Reference in New Issue
Block a user