[Model] Explicit default_pooling_type interface (#23736)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -144,6 +144,17 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
|
||||
MRO of your model class.
|
||||
"""
|
||||
|
||||
default_pooling_type: ClassVar[str] = "LAST"
|
||||
"""
|
||||
Indicates the
|
||||
[vllm.model_executor.layers.pooler.PoolerConfig.pooling_type][]
|
||||
to use by default.
|
||||
|
||||
You can use the
|
||||
[vllm.model_executor.models.interfaces_base.default_pooling_type][]
|
||||
decorator to conveniently set this field.
|
||||
"""
|
||||
|
||||
pooler: Pooler
|
||||
"""The pooler is only called on TP rank 0."""
|
||||
|
||||
@@ -165,3 +176,20 @@ def is_pooling_model(
|
||||
return False
|
||||
|
||||
return getattr(model, "is_pooling_model", False)
|
||||
|
||||
|
||||
_T = TypeVar("_T", bound=type[nn.Module])
|
||||
|
||||
|
||||
def default_pooling_type(pooling_type: str):
|
||||
"""Decorator to set `VllmModelForPooling.default_pooling_type`."""
|
||||
|
||||
def func(model: _T) -> _T:
|
||||
model.default_pooling_type = pooling_type # type: ignore
|
||||
return model
|
||||
|
||||
return func
|
||||
|
||||
|
||||
def get_default_pooling_type(model: Union[type[object], object]) -> str:
|
||||
return getattr(model, "default_pooling_type", "LAST")
|
||||
|
||||
Reference in New Issue
Block a user