[Model] Update pooling model interface (#21058)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-07-18 00:05:40 +08:00
committed by GitHub
parent 9fb2d22032
commit 90bd2ab6e3
17 changed files with 247 additions and 345 deletions

View File

@@ -1,8 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import (TYPE_CHECKING, Optional, Protocol, Union, overload,
-                    runtime_checkable)
+from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
+                    Union, overload, runtime_checkable)
 import torch
 import torch.nn as nn
@@ -13,8 +12,7 @@ from vllm.utils import supports_kw
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
-    from vllm.model_executor.layers.pooler import PoolerOutput
-    from vllm.model_executor.pooling_metadata import PoolingMetadata
+    from vllm.model_executor.layers.pooler import Pooler
     from vllm.model_executor.sampling_metadata import SamplingMetadata
 logger = init_logger(__name__)
@@ -130,16 +128,20 @@ def is_text_generation_model(
 @runtime_checkable
-class VllmModelForPooling(VllmModel[T], Protocol[T]):
+class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
     """The interface required for all pooling models in vLLM."""
-    def pooler(
-        self,
-        hidden_states: T,
-        pooling_metadata: "PoolingMetadata",
-    ) -> "PoolerOutput":
-        """Only called on TP rank 0."""
-        ...
+    is_pooling_model: ClassVar[Literal[True]] = True
+    """
+    A flag that indicates this model supports pooling.
+    Note:
+        There is no need to redefine this flag if this class is in the
+        MRO of your model class.
+    """
+    pooler: "Pooler"
+    """The pooler is only called on TP rank 0."""
 @overload
@@ -158,7 +160,4 @@ def is_pooling_model(
     if not is_vllm_model(model):
         return False
-    if isinstance(model, type):
-        return isinstance(model, VllmModelForPooling)
-    return isinstance(model, VllmModelForPooling)
+    return getattr(model, "is_pooling_model", False)