[Model] Update pooling model interface (#21058)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1,8 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import (TYPE_CHECKING, Optional, Protocol, Union, overload,
|
||||
runtime_checkable)
|
||||
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
|
||||
Union, overload, runtime_checkable)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -13,8 +12,7 @@ from vllm.utils import supports_kw
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.pooler import PoolerOutput
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.model_executor.layers.pooler import Pooler
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -130,16 +128,20 @@ def is_text_generation_model(
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class VllmModelForPooling(VllmModel[T], Protocol[T]):
|
||||
class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
|
||||
"""The interface required for all pooling models in vLLM."""
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: T,
|
||||
pooling_metadata: "PoolingMetadata",
|
||||
) -> "PoolerOutput":
|
||||
"""Only called on TP rank 0."""
|
||||
...
|
||||
is_pooling_model: ClassVar[Literal[True]] = True
|
||||
"""
|
||||
A flag that indicates this model supports pooling.
|
||||
|
||||
Note:
|
||||
There is no need to redefine this flag if this class is in the
|
||||
MRO of your model class.
|
||||
"""
|
||||
|
||||
pooler: "Pooler"
|
||||
"""The pooler is only called on TP rank 0."""
|
||||
|
||||
|
||||
@overload
|
||||
@@ -158,7 +160,4 @@ def is_pooling_model(
|
||||
if not is_vllm_model(model):
|
||||
return False
|
||||
|
||||
if isinstance(model, type):
|
||||
return isinstance(model, VllmModelForPooling)
|
||||
|
||||
return isinstance(model, VllmModelForPooling)
|
||||
return getattr(model, "is_pooling_model", False)
|
||||
|
||||
Reference in New Issue
Block a user