[Model] Update pooling model interface (#21058)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-07-18 00:05:40 +08:00
committed by GitHub
parent 9fb2d22032
commit 90bd2ab6e3
17 changed files with 247 additions and 345 deletions

View File

@@ -11,11 +11,13 @@ from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.models.gemma2 import Gemma2Model from vllm.model_executor.models.gemma2 import Gemma2Model
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors
from vllm.sequence import IntermediateTensors, PoolerOutput
class MyGemma2Embedding(nn.Module): class MyGemma2Embedding(nn.Module):
is_pooling_model = True
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -24,7 +26,7 @@ class MyGemma2Embedding(nn.Module):
self.model = Gemma2Model(vllm_config=vllm_config, self.model = Gemma2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
self._pooler = Pooler.from_config_with_defaults( self.pooler = Pooler.from_config_with_defaults(
vllm_config.model_config.pooler_config, vllm_config.model_config.pooler_config,
pooling_type=PoolingType.LAST, pooling_type=PoolingType.LAST,
normalize=True, normalize=True,
@@ -54,13 +56,6 @@ class MyGemma2Embedding(nn.Module):
# Return all-zero embeddings # Return all-zero embeddings
return torch.zeros_like(hidden_states) return torch.zeros_like(hidden_states)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
weights = self.hf_to_vllm_mapper.apply(weights) weights = self.hf_to_vllm_mapper.apply(weights)

View File

@@ -1237,10 +1237,6 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
user: Optional[str] = None user: Optional[str] = None
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
# --8<-- [start:embedding-pooling-params]
additional_data: Optional[Any] = None
# --8<-- [end:embedding-pooling-params]
# --8<-- [start:embedding-extra-params] # --8<-- [start:embedding-extra-params]
add_special_tokens: bool = Field( add_special_tokens: bool = Field(
default=True, default=True,
@@ -1259,8 +1255,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
# --8<-- [end:embedding-extra-params] # --8<-- [end:embedding-extra-params]
def to_pooling_params(self): def to_pooling_params(self):
return PoolingParams(dimensions=self.dimensions, return PoolingParams(dimensions=self.dimensions)
additional_data=self.additional_data)
class EmbeddingChatRequest(OpenAIBaseModel): class EmbeddingChatRequest(OpenAIBaseModel):
@@ -1272,10 +1267,6 @@ class EmbeddingChatRequest(OpenAIBaseModel):
user: Optional[str] = None user: Optional[str] = None
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
# --8<-- [start:chat-embedding-pooling-params]
additional_data: Optional[Any] = None
# --8<-- [end:chat-embedding-pooling-params]
# --8<-- [start:chat-embedding-extra-params] # --8<-- [start:chat-embedding-extra-params]
add_special_tokens: bool = Field( add_special_tokens: bool = Field(
default=False, default=False,
@@ -1323,8 +1314,7 @@ class EmbeddingChatRequest(OpenAIBaseModel):
return data return data
def to_pooling_params(self): def to_pooling_params(self):
return PoolingParams(dimensions=self.dimensions, return PoolingParams(dimensions=self.dimensions)
additional_data=self.additional_data)
EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest] EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
@@ -1340,10 +1330,6 @@ class ScoreRequest(OpenAIBaseModel):
text_2: Union[list[str], str, ScoreMultiModalParam] text_2: Union[list[str], str, ScoreMultiModalParam]
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
# --8<-- [start:score-pooling-params]
additional_data: Optional[Any] = None
# --8<-- [end:score-pooling-params]
# --8<-- [start:score-extra-params] # --8<-- [start:score-extra-params]
mm_processor_kwargs: Optional[dict[str, Any]] = Field( mm_processor_kwargs: Optional[dict[str, Any]] = Field(
@@ -1362,8 +1348,7 @@ class ScoreRequest(OpenAIBaseModel):
# --8<-- [end:score-extra-params] # --8<-- [end:score-extra-params]
def to_pooling_params(self, *, use_cross_encoder: bool = False): def to_pooling_params(self, *, use_cross_encoder: bool = False):
return PoolingParams(use_cross_encoder=use_cross_encoder, return PoolingParams(use_cross_encoder=use_cross_encoder)
additional_data=self.additional_data)
class RerankRequest(OpenAIBaseModel): class RerankRequest(OpenAIBaseModel):
@@ -1373,10 +1358,6 @@ class RerankRequest(OpenAIBaseModel):
top_n: int = Field(default_factory=lambda: 0) top_n: int = Field(default_factory=lambda: 0)
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
# --8<-- [start:rerank-pooling-params]
additional_data: Optional[Any] = None
# --8<-- [end:rerank-pooling-params]
# --8<-- [start:rerank-extra-params] # --8<-- [start:rerank-extra-params]
mm_processor_kwargs: Optional[dict[str, Any]] = Field( mm_processor_kwargs: Optional[dict[str, Any]] = Field(
@@ -1395,8 +1376,7 @@ class RerankRequest(OpenAIBaseModel):
# --8<-- [end:rerank-extra-params] # --8<-- [end:rerank-extra-params]
def to_pooling_params(self, *, use_cross_encoder: bool = False): def to_pooling_params(self, *, use_cross_encoder: bool = False):
return PoolingParams(use_cross_encoder=use_cross_encoder, return PoolingParams(use_cross_encoder=use_cross_encoder)
additional_data=self.additional_data)
class RerankDocument(BaseModel): class RerankDocument(BaseModel):
@@ -1534,10 +1514,6 @@ class ClassificationRequest(OpenAIBaseModel):
truncate_prompt_tokens: Optional[int] = None truncate_prompt_tokens: Optional[int] = None
user: Optional[str] = None user: Optional[str] = None
# --8<-- [start:classification-pooling-params]
additional_data: Optional[Any] = None
# --8<-- [end:classification-pooling-params]
# --8<-- [start:classification-extra-params] # --8<-- [start:classification-extra-params]
priority: int = Field( priority: int = Field(
default=0, default=0,
@@ -1550,7 +1526,7 @@ class ClassificationRequest(OpenAIBaseModel):
# --8<-- [end:classification-extra-params] # --8<-- [end:classification-extra-params]
def to_pooling_params(self): def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data) return PoolingParams()
class ClassificationData(OpenAIBaseModel): class ClassificationData(OpenAIBaseModel):

View File

@@ -3,22 +3,25 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from enum import IntEnum from enum import IntEnum
from typing import Callable, Optional, TypeVar, Union from typing import Callable, Literal, Optional, TypeVar, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from transformers import PretrainedConfig from transformers import PretrainedConfig
from typing_extensions import assert_never
from vllm.config import ModelConfig, PoolerConfig from vllm.config import ModelConfig, PoolerConfig
from vllm.model_executor.pooling_metadata import ( # noqa: E501 from vllm.model_executor.pooling_metadata import ( # noqa: E501
PoolingMetadata as V0PoolingMetadata) PoolingMetadata as V0PoolingMetadata)
from vllm.model_executor.pooling_metadata import PoolingTensors from vllm.model_executor.pooling_metadata import PoolingTensors
from vllm.pooling_params import PoolingParams
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
from vllm.utils import resolve_obj_by_qualname from vllm.utils import resolve_obj_by_qualname
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata] PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
PoolingTask = Literal["encode", "embed", "classify", "score"]
class PoolingType(IntEnum): class PoolingType(IntEnum):
@@ -64,6 +67,48 @@ class ResolvedPoolingConfig:
) )
class Pooler(nn.Module, ABC):
"""The interface required for all poolers used in pooling models in vLLM."""
@staticmethod
def from_config_with_defaults(
pooler_config: PoolerConfig,
pooling_type: PoolingType,
normalize: bool,
softmax: bool,
step_tag_id: Optional[int] = None,
returned_token_ids: Optional[list[int]] = None,
) -> "Pooler":
resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
pooler_config=pooler_config,
pooling_type=pooling_type,
normalize=normalize,
softmax=softmax,
step_tag_id=step_tag_id,
returned_token_ids=returned_token_ids,
)
if pooling_type == PoolingType.STEP:
return StepPooler.from_config(resolved_config)
return SimplePooler.from_config(resolved_config)
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
"""
Construct the pooling parameters to use for a task,
or `None` if the task is not supported.
"""
return None
@abstractmethod
def forward(
self,
hidden_states: Union[list[torch.Tensor], torch.Tensor],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
raise NotImplementedError
def get_prompt_lens( def get_prompt_lens(
hidden_states: Union[torch.Tensor, list[torch.Tensor]], hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
@@ -104,17 +149,6 @@ def build_output(all_data: torch.Tensor) -> PoolerOutput:
return PoolerOutput(outputs=all_outputs) return PoolerOutput(outputs=all_outputs)
class BasePooler(nn.Module):
@abstractmethod
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
raise NotImplementedError
class PoolingMethod(nn.Module, ABC): class PoolingMethod(nn.Module, ABC):
@staticmethod @staticmethod
@@ -130,6 +164,10 @@ class PoolingMethod(nn.Module, ABC):
raise NotImplementedError(f"Unsupported method: {pooling_type}") raise NotImplementedError(f"Unsupported method: {pooling_type}")
@abstractmethod
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
raise NotImplementedError
@abstractmethod @abstractmethod
def forward_one( def forward_one(
self, self,
@@ -168,6 +206,14 @@ class PoolingMethod(nn.Module, ABC):
class CLSPool(PoolingMethod): class CLSPool(PoolingMethod):
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
# The equalities are split up to keep mypy happy
if (task == "encode" or task == "embed" or task == "classify"
or task == "score"):
return PoolingParams()
assert_never(task)
def forward_one( def forward_one(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@@ -190,6 +236,14 @@ class CLSPool(PoolingMethod):
class LastPool(PoolingMethod): class LastPool(PoolingMethod):
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
# The equalities are split up to keep mypy happy
if (task == "encode" or task == "embed" or task == "classify"
or task == "score"):
return PoolingParams()
assert_never(task)
def forward_one( def forward_one(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@@ -208,6 +262,16 @@ class LastPool(PoolingMethod):
class AllPool(PoolingMethod): class AllPool(PoolingMethod):
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
if task == "encode":
return PoolingParams()
# The equalities are split up to keep mypy happy
if task == "embed" or task == "classify" or task == "score":
return None
assert_never(task)
def forward_one( def forward_one(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@@ -235,6 +299,14 @@ class AllPool(PoolingMethod):
class MeanPool(PoolingMethod): class MeanPool(PoolingMethod):
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
# The equalities are split up to keep mypy happy
if (task == "encode" or task == "embed" or task == "classify"
or task == "score"):
return PoolingParams()
assert_never(task)
def forward_one( def forward_one(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@@ -345,25 +417,6 @@ class LambdaPoolerActivation(PoolerActivation):
class PoolerHead(nn.Module): class PoolerHead(nn.Module):
@classmethod
def from_config_with_defaults(
cls,
pooler_config: PoolerConfig,
pooling_type: PoolingType,
normalize: bool,
softmax: bool,
) -> "PoolerHead":
resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
pooler_config=pooler_config,
pooling_type=pooling_type,
normalize=normalize,
softmax=softmax,
step_tag_id=None,
returned_token_ids=None,
)
return cls.from_config(resolved_config)
@classmethod @classmethod
def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "PoolerHead": def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "PoolerHead":
if pooler_config.normalize and pooler_config.softmax: if pooler_config.normalize and pooler_config.softmax:
@@ -424,21 +477,17 @@ class PoolerHead(nn.Module):
return self.activation(pooled_data) return self.activation(pooled_data)
class SimplePooler(BasePooler): class SimplePooler(Pooler):
"""A layer that pools specific information from hidden states. """A layer that pools specific information from hidden states.
This layer does the following: This layer does the following:
1. Extracts specific tokens or aggregates data based on pooling method. 1. Extracts specific tokens or aggregates data based on pooling method.
2. Normalizes output if specified. 2. Normalizes output if specified.
3. Returns structured results as `PoolerOutput`. 3. Returns structured results as `PoolerOutput`.
Attributes:
pooling_type: The type of pooling to use.
normalize: Whether to normalize the pooled data.
""" """
@classmethod @classmethod
def from_config_with_defaults( def from_config_with_defaults( # type: ignore[override]
cls, cls,
pooler_config: PoolerConfig, pooler_config: PoolerConfig,
pooling_type: PoolingType, pooling_type: PoolingType,
@@ -471,6 +520,9 @@ class SimplePooler(BasePooler):
self.pooling = pooling self.pooling = pooling
self.head = head self.head = head
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
return self.pooling.get_pooling_params(task)
def forward( def forward(
self, self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]], hidden_states: Union[torch.Tensor, list[torch.Tensor]],
@@ -481,7 +533,7 @@ class SimplePooler(BasePooler):
return build_output(pooled_data) return build_output(pooled_data)
class StepPooler(BasePooler): class StepPooler(Pooler):
@classmethod @classmethod
def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "StepPooler": def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "StepPooler":
@@ -543,6 +595,16 @@ class StepPooler(BasePooler):
return pooled_data return pooled_data
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
if task == "encode":
return PoolingParams(logits_processing_needs_token_ids=True)
# The equalities are split up to keep mypy happy
if task == "embed" or task == "classify" or task == "score":
return None
assert_never(task)
def forward( def forward(
self, self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]], hidden_states: Union[torch.Tensor, list[torch.Tensor]],
@@ -553,32 +615,6 @@ class StepPooler(BasePooler):
return build_output(pooled_data) return build_output(pooled_data)
class Pooler(nn.Module):
@staticmethod
def from_config_with_defaults(
pooler_config: PoolerConfig,
pooling_type: PoolingType,
normalize: bool,
softmax: bool,
step_tag_id: Optional[int] = None,
returned_token_ids: Optional[list[int]] = None,
) -> BasePooler:
resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
pooler_config=pooler_config,
pooling_type=pooling_type,
normalize=normalize,
softmax=softmax,
step_tag_id=step_tag_id,
returned_token_ids=returned_token_ids,
)
if pooling_type == PoolingType.STEP:
return StepPooler.from_config(resolved_config)
return SimplePooler.from_config(resolved_config)
PoolingFn = Callable[ PoolingFn = Callable[
[Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata], [Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
Union[torch.Tensor, list[torch.Tensor]]] Union[torch.Tensor, list[torch.Tensor]]]
@@ -618,6 +654,18 @@ class ClassifierPooler(nn.Module):
return (self.cross_encoder_act_fn return (self.cross_encoder_act_fn
if use_cross_encoder else self.classification_act_fn) if use_cross_encoder else self.classification_act_fn)
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
if task == "encode":
return PoolingParams()
if task == "embed":
return None
if task == "classify":
return PoolingParams()
if task == "score":
return PoolingParams(use_cross_encoder=True)
assert_never(task)
def forward( def forward(
self, self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]], hidden_states: Union[torch.Tensor, list[torch.Tensor]],

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union, cast from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -42,13 +42,14 @@ def _create_pooling_model_cls(
default_softmax: bool, default_softmax: bool,
) -> _T: ) -> _T:
# Lazy import # Lazy import
from vllm.model_executor.layers.pooler import Pooler, PoolerOutput from vllm.model_executor.layers.pooler import Pooler
from vllm.model_executor.pooling_metadata import PoolingMetadata
from .utils import AutoWeightsLoader, WeightsMapper from .utils import AutoWeightsLoader, WeightsMapper
class ModelForPooling(orig_cls, VllmModelForPooling): class ModelForPooling(orig_cls, VllmModelForPooling):
is_pooling_model = True
def __init__( def __init__(
self, self,
*, *,
@@ -66,27 +67,20 @@ def _create_pooling_model_cls(
delattr(self, attr) delattr(self, attr)
# If the model already defines a pooler instance, don't overwrite it # If the model already defines a pooler instance, don't overwrite it
if not getattr(self, "_pooler", None): if not getattr(self, "pooler", None):
self._init_pooler(vllm_config, prefix=prefix) self._init_pooler(vllm_config, prefix=prefix)
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
self._pooler = Pooler.from_config_with_defaults( self.pooler = Pooler.from_config_with_defaults(
pooler_config, pooler_config,
pooling_type=default_pooling_type, pooling_type=default_pooling_type,
normalize=default_normalize, normalize=default_normalize,
softmax=default_softmax, softmax=default_softmax,
) )
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
# TODO: Support uninitialized params tracking # TODO: Support uninitialized params tracking
@@ -171,10 +165,8 @@ def as_seq_cls_model(cls: _T) -> _T:
# Lazy import # Lazy import
from vllm.model_executor.layers.linear import RowParallelLinear from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.pooler import (ClassifierPooler, from vllm.model_executor.layers.pooler import (ClassifierPooler,
PoolerOutput, PoolingType, PoolingType, SimplePooler)
SimplePooler)
from vllm.model_executor.models.interfaces import SupportsCrossEncoding from vllm.model_executor.models.interfaces import SupportsCrossEncoding
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .utils import maybe_prefix from .utils import maybe_prefix
@@ -213,7 +205,7 @@ def as_seq_cls_model(cls: _T) -> _T:
softmax=True, softmax=True,
) )
self._pooler = ClassifierPooler( self.pooler = ClassifierPooler(
vllm_config.model_config, vllm_config.model_config,
pooling=pooler.pooling, pooling=pooler.pooling,
classifier=self._classifier, classifier=self._classifier,
@@ -234,13 +226,6 @@ def as_seq_cls_model(cls: _T) -> _T:
return super().forward(input_ids, positions, intermediate_tensors, return super().forward(input_ids, positions, intermediate_tensors,
inputs_embeds) inputs_embeds)
def pooler(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
tokens = getattr(self.config, "classifier_from_token", None) tokens = getattr(self.config, "classifier_from_token", None)
method = getattr(self.config, "method", None) method = getattr(self.config, "method", None)

View File

@@ -18,12 +18,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler, from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
PoolingMethod, PoolingType) PoolingMethod, PoolingTask,
PoolingType)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.pooling_params import PoolingParams
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
@@ -80,7 +82,7 @@ class BertEmbedding(nn.Module):
return embeddings return embeddings
class BertPooler(nn.Module): class BertPooler(Pooler):
def __init__(self, config: BertConfig): def __init__(self, config: BertConfig):
super().__init__() super().__init__()
@@ -89,6 +91,9 @@ class BertPooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
return self.pooling.get_pooling_params(task)
def forward( def forward(
self, self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]], hidden_states: Union[torch.Tensor, list[torch.Tensor]],
@@ -319,6 +324,9 @@ class BertOutput(nn.Module):
class BertModel(nn.Module, SupportsQuant): class BertModel(nn.Module, SupportsQuant):
is_pooling_model = True
packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]} packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}
def __init__(self, def __init__(self,
@@ -403,12 +411,15 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
_pooler: An instance of Pooler used for pooling operations. _pooler: An instance of Pooler used for pooling operations.
""" """
is_pooling_model = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
self.model = self._build_model(vllm_config=vllm_config, self.model = self._build_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
self._pooler = self._build_pooler(pooler_config) self.pooler = self._build_pooler(pooler_config)
def forward( def forward(
self, self,
@@ -422,13 +433,6 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors) intermediate_tensors=intermediate_tensors)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
weights_list = list(weights) weights_list = list(weights)
@@ -466,6 +470,8 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
_pooler: An instance of Pooler used for pooling operations. _pooler: An instance of Pooler used for pooling operations.
""" """
is_pooling_model = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
@@ -476,7 +482,7 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
embedding_class=BertEmbedding, embedding_class=BertEmbedding,
add_pooling_layer=True) add_pooling_layer=True)
self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self._pooler = ClassifierPooler( self.pooler = ClassifierPooler(
vllm_config.model_config, vllm_config.model_config,
pooling=self.bert.pooler, pooling=self.bert.pooler,
classifier=self.classifier, classifier=self.classifier,
@@ -487,13 +493,6 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
loaded_params = loader.load_weights(weights) loaded_params = loader.load_weights(weights)
return loaded_params return loaded_params
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def forward( def forward(
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],

View File

@@ -40,9 +40,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.sequence import IntermediateTensors
from ..layers.pooler import Pooler, PoolingType from ..layers.pooler import Pooler, PoolingType
from .interfaces import SupportsPP from .interfaces import SupportsPP
@@ -332,6 +331,8 @@ class GPT2ForSequenceClassification(nn.Module):
_pooler: An instance of Pooler used for pooling operations. _pooler: An instance of Pooler used for pooling operations.
""" """
is_pooling_model = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
@@ -339,7 +340,7 @@ class GPT2ForSequenceClassification(nn.Module):
prefix=maybe_prefix(prefix, "gpt2")) prefix=maybe_prefix(prefix, "gpt2"))
self.score = nn.Linear(config.n_embd, config.num_labels, bias=False) self.score = nn.Linear(config.n_embd, config.num_labels, bias=False)
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
self._pooler = Pooler.from_config_with_defaults( self.pooler = Pooler.from_config_with_defaults(
pooler_config, pooler_config,
pooling_type=PoolingType.LAST, pooling_type=PoolingType.LAST,
normalize=False, normalize=False,
@@ -349,13 +350,6 @@ class GPT2ForSequenceClassification(nn.Module):
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,

View File

@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from array import array from array import array
from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -195,6 +194,8 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
- "<|user|>\nPROMPT\n<|assistant|>\n" - "<|user|>\nPROMPT\n<|assistant|>\n"
""" """
is_pooling_model = True
def __init__( def __init__(
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
@@ -214,11 +215,4 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
self._pooler = GritLMPooler(vllm_config.model_config) self.pooler = GritLMPooler(vllm_config.model_config)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)

View File

@@ -119,13 +119,6 @@ class SupportsMultiModal(Protocol):
... ...
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsMultiModalType(Protocol):
supports_multimodal: Literal[True]
@overload @overload
def supports_multimodal( def supports_multimodal(
model: type[object]) -> TypeIs[type[SupportsMultiModal]]: model: type[object]) -> TypeIs[type[SupportsMultiModal]]:
@@ -140,10 +133,7 @@ def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
def supports_multimodal( def supports_multimodal(
model: Union[type[object], object], model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]: ) -> Union[TypeIs[type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]:
if isinstance(model, type): return getattr(model, "supports_multimodal", False)
return isinstance(model, _SupportsMultiModalType)
return isinstance(model, SupportsMultiModal)
@runtime_checkable @runtime_checkable
@@ -174,13 +164,6 @@ class SupportsScoreTemplate(Protocol):
... ...
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsScoreTemplateType(Protocol):
supports_score_template: Literal[True]
@overload @overload
def supports_score_template( def supports_score_template(
model: type[object]) -> TypeIs[type[SupportsScoreTemplate]]: model: type[object]) -> TypeIs[type[SupportsScoreTemplate]]:
@@ -195,11 +178,7 @@ def supports_score_template(model: object) -> TypeIs[SupportsScoreTemplate]:
def supports_score_template( def supports_score_template(
model: Union[type[object], object], model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsScoreTemplate]], TypeIs[SupportsScoreTemplate]]: ) -> Union[TypeIs[type[SupportsScoreTemplate]], TypeIs[SupportsScoreTemplate]]:
return getattr(model, "supports_score_template", False)
if isinstance(model, type):
return isinstance(model, _SupportsScoreTemplateType)
return isinstance(model, SupportsScoreTemplate)
@runtime_checkable @runtime_checkable
@@ -409,11 +388,6 @@ class HasInnerState(Protocol):
""" """
@runtime_checkable
class _HasInnerStateType(Protocol):
has_inner_state: ClassVar[Literal[True]]
@overload @overload
def has_inner_state(model: object) -> TypeIs[HasInnerState]: def has_inner_state(model: object) -> TypeIs[HasInnerState]:
... ...
@@ -427,10 +401,7 @@ def has_inner_state(model: type[object]) -> TypeIs[type[HasInnerState]]:
def has_inner_state( def has_inner_state(
model: Union[type[object], object] model: Union[type[object], object]
) -> Union[TypeIs[type[HasInnerState]], TypeIs[HasInnerState]]: ) -> Union[TypeIs[type[HasInnerState]], TypeIs[HasInnerState]]:
if isinstance(model, type): return getattr(model, "has_inner_state", False)
return isinstance(model, _HasInnerStateType)
return isinstance(model, HasInnerState)
@runtime_checkable @runtime_checkable
@@ -446,11 +417,6 @@ class IsAttentionFree(Protocol):
""" """
@runtime_checkable
class _IsAttentionFreeType(Protocol):
is_attention_free: ClassVar[Literal[True]]
@overload @overload
def is_attention_free(model: object) -> TypeIs[IsAttentionFree]: def is_attention_free(model: object) -> TypeIs[IsAttentionFree]:
... ...
@@ -464,10 +430,7 @@ def is_attention_free(model: type[object]) -> TypeIs[type[IsAttentionFree]]:
def is_attention_free( def is_attention_free(
model: Union[type[object], object] model: Union[type[object], object]
) -> Union[TypeIs[type[IsAttentionFree]], TypeIs[IsAttentionFree]]: ) -> Union[TypeIs[type[IsAttentionFree]], TypeIs[IsAttentionFree]]:
if isinstance(model, type): return getattr(model, "is_attention_free", False)
return isinstance(model, _IsAttentionFreeType)
return isinstance(model, IsAttentionFree)
@runtime_checkable @runtime_checkable
@@ -502,11 +465,6 @@ class IsHybrid(Protocol):
... ...
@runtime_checkable
class _IsHybridType(Protocol):
is_hybrid: ClassVar[Literal[True]]
@overload @overload
def is_hybrid(model: object) -> TypeIs[IsHybrid]: def is_hybrid(model: object) -> TypeIs[IsHybrid]:
... ...
@@ -520,10 +478,7 @@ def is_hybrid(model: type[object]) -> TypeIs[type[IsHybrid]]:
def is_hybrid( def is_hybrid(
model: Union[type[object], object] model: Union[type[object], object]
) -> Union[TypeIs[type[IsHybrid]], TypeIs[IsHybrid]]: ) -> Union[TypeIs[type[IsHybrid]], TypeIs[IsHybrid]]:
if isinstance(model, type): return getattr(model, "is_hybrid", False)
return isinstance(model, _IsHybridType)
return isinstance(model, IsHybrid)
@runtime_checkable @runtime_checkable
@@ -598,11 +553,6 @@ class HasNoOps(Protocol):
has_noops: ClassVar[Literal[True]] = True has_noops: ClassVar[Literal[True]] = True
@runtime_checkable
class _HasNoOpsType(Protocol):
has_noops: ClassVar[Literal[True]]
@overload @overload
def has_noops(model: object) -> TypeIs[HasNoOps]: def has_noops(model: object) -> TypeIs[HasNoOps]:
... ...
@@ -616,10 +566,7 @@ def has_noops(model: type[object]) -> TypeIs[type[HasNoOps]]:
def has_noops( def has_noops(
model: Union[type[object], object] model: Union[type[object], object]
) -> Union[TypeIs[type[HasNoOps]], TypeIs[HasNoOps]]: ) -> Union[TypeIs[type[HasNoOps]], TypeIs[HasNoOps]]:
if isinstance(model, type): return getattr(model, "has_noops", False)
return isinstance(model, _HasNoOpsType)
return isinstance(model, HasNoOps)
@runtime_checkable @runtime_checkable
@@ -643,11 +590,7 @@ def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]:
def _supports_cross_encoding( def _supports_cross_encoding(
model: Union[type[object], object], model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: ) -> Union[TypeIs[type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
return getattr(model, "supports_cross_encoding", False)
if isinstance(model, type):
return isinstance(model, SupportsCrossEncoding)
return isinstance(model, SupportsCrossEncoding)
def supports_cross_encoding( def supports_cross_encoding(
@@ -658,8 +601,9 @@ def supports_cross_encoding(
def has_step_pooler(model: Union[type[object], object]) -> bool: def has_step_pooler(model: Union[type[object], object]) -> bool:
"""Check if the model uses step pooler.""" """Check if the model uses step pooler."""
return is_pooling_model(model) and any( from vllm.model_executor.layers.pooler import StepPooler
type(module).__name__ == "StepPooler" for module in model.modules())
return is_pooling_model(model) and isinstance(model.pooler, StepPooler)
class SupportsQuant: class SupportsQuant:
@@ -770,10 +714,7 @@ def supports_transcription(model: object) -> TypeIs[SupportsTranscription]:
def supports_transcription( def supports_transcription(
model: Union[type[object], object], model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsTranscription]], TypeIs[SupportsTranscription]]: ) -> Union[TypeIs[type[SupportsTranscription]], TypeIs[SupportsTranscription]]:
if isinstance(model, type): return getattr(model, "supports_transcription", False)
return isinstance(model, SupportsTranscription)
return isinstance(model, SupportsTranscription)
@runtime_checkable @runtime_checkable
@@ -796,7 +737,4 @@ def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]:
def supports_v0_only( def supports_v0_only(
model: Union[type[object], object], model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsV0Only]], TypeIs[SupportsV0Only]]: ) -> Union[TypeIs[type[SupportsV0Only]], TypeIs[SupportsV0Only]]:
if isinstance(model, type): return getattr(model, "supports_v0_only", False)
return isinstance(model, SupportsV0Only)
return isinstance(model, SupportsV0Only)

View File

@@ -1,8 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
from typing import (TYPE_CHECKING, Optional, Protocol, Union, overload, Union, overload, runtime_checkable)
runtime_checkable)
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -13,8 +12,7 @@ from vllm.utils import supports_kw
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import PoolerOutput from vllm.model_executor.layers.pooler import Pooler
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -130,16 +128,20 @@ def is_text_generation_model(
@runtime_checkable @runtime_checkable
class VllmModelForPooling(VllmModel[T], Protocol[T]): class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
"""The interface required for all pooling models in vLLM.""" """The interface required for all pooling models in vLLM."""
def pooler( is_pooling_model: ClassVar[Literal[True]] = True
self, """
hidden_states: T, A flag that indicates this model supports pooling.
pooling_metadata: "PoolingMetadata",
) -> "PoolerOutput": Note:
"""Only called on TP rank 0.""" There is no need to redefine this flag if this class is in the
... MRO of your model class.
"""
pooler: "Pooler"
"""The pooler is only called on TP rank 0."""
@overload @overload
@@ -158,7 +160,4 @@ def is_pooling_model(
if not is_vllm_model(model): if not is_vllm_model(model):
return False return False
if isinstance(model, type): return getattr(model, "is_pooling_model", False)
return isinstance(model, VllmModelForPooling)
return isinstance(model, VllmModelForPooling)

View File

@@ -28,9 +28,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
@@ -404,6 +403,8 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
class InternLM2ForRewardModel(InternLM2ForCausalLM): class InternLM2ForRewardModel(InternLM2ForCausalLM):
is_pooling_model = True
def __init__( def __init__(
self, self,
*, *,
@@ -428,7 +429,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
) )
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
self._pooler = Pooler.from_config_with_defaults( self.pooler = Pooler.from_config_with_defaults(
pooler_config, pooler_config,
pooling_type=PoolingType.ALL, pooling_type=PoolingType.ALL,
normalize=False, normalize=False,
@@ -446,10 +447,3 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
inputs_embeds) inputs_embeds)
logits, _ = self.v_head(hidden_states) logits, _ = self.v_head(hidden_states)
return logits return logits
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)

View File

@@ -27,9 +27,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.mamba_cache import (MambaCacheManager, from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams) MambaCacheParams)
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType from vllm.utils import LayerBlockType
from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
@@ -563,6 +562,8 @@ def _is_moe_layer(name: str):
class JambaForSequenceClassification(JambaForCausalLM): class JambaForSequenceClassification(JambaForCausalLM):
is_pooling_model = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix) super().__init__(vllm_config=vllm_config, prefix=prefix)
@@ -590,16 +591,9 @@ class JambaForSequenceClassification(JambaForCausalLM):
softmax=False, softmax=False,
) )
self._pooler = ClassifierPooler( self.pooler = ClassifierPooler(
vllm_config.model_config, vllm_config.model_config,
pooling=pooler.pooling, pooling=pooler.pooling,
classifier=self.score, classifier=self.score,
act_fn=pooler.head.activation, act_fn=pooler.head.activation,
) )
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)

View File

@@ -13,9 +13,8 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.sequence import IntermediateTensors
from .interfaces import (SupportsCrossEncoding, SupportsMultiModal, from .interfaces import (SupportsCrossEncoding, SupportsMultiModal,
SupportsScoreTemplate) SupportsScoreTemplate)
@@ -72,6 +71,8 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
SupportsCrossEncoding, SupportsCrossEncoding,
SupportsMultiModal, SupportsMultiModal,
SupportsScoreTemplate): SupportsScoreTemplate):
is_pooling_model = True
weight_mapper = WeightsMapper( weight_mapper = WeightsMapper(
orig_to_new_prefix={ orig_to_new_prefix={
"score.0.": "score.dense.", "score.0.": "score.dense.",
@@ -95,7 +96,7 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
self.score = JinaVLScorer(config) self.score = JinaVLScorer(config)
self._pooler = Pooler.from_config_with_defaults( self.pooler = Pooler.from_config_with_defaults(
pooler_config, pooler_config,
pooling_type=PoolingType.LAST, pooling_type=PoolingType.LAST,
normalize=False, normalize=False,
@@ -137,14 +138,6 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
logits = self.score(hidden_states) - self.LOGIT_BIAS logits = self.score(hidden_states) - self.LOGIT_BIAS
return logits return logits
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.weight_mapper) return loader.load_weights(weights, mapper=self.weight_mapper)

View File

@@ -13,14 +13,16 @@ from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.pooler import (BasePooler, ClassifierPooler, from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
PoolingMethod, PoolingType) PoolingMethod, PoolingTask,
PoolingType)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.pooling_params import PoolingParams
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsCrossEncoding, SupportsV0Only from .interfaces import SupportsCrossEncoding, SupportsV0Only
from .utils import WeightsMapper, maybe_prefix from .utils import WeightsMapper, maybe_prefix
@@ -253,7 +255,7 @@ class ModernBertModel(nn.Module):
return norm_outputs return norm_outputs
class ModernBertPooler(BasePooler): class ModernBertPooler(Pooler):
def __init__(self, config: ModernBertConfig): def __init__(self, config: ModernBertConfig):
super().__init__() super().__init__()
@@ -268,6 +270,9 @@ class ModernBertPooler(BasePooler):
eps=config.norm_eps, eps=config.norm_eps,
bias=config.norm_bias) bias=config.norm_bias)
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
return self.pooling.get_pooling_params(task)
def forward( def forward(
self, self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]], hidden_states: Union[torch.Tensor, list[torch.Tensor]],
@@ -281,6 +286,8 @@ class ModernBertPooler(BasePooler):
class ModernBertForSequenceClassification(nn.Module, SupportsV0Only, class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
SupportsCrossEncoding): SupportsCrossEncoding):
is_pooling_model = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
@@ -288,7 +295,7 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
self.model = ModernBertModel(vllm_config=vllm_config, self.model = ModernBertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "modernbert")) prefix=maybe_prefix(prefix, "modernbert"))
self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self._pooler = ClassifierPooler( self.pooler = ClassifierPooler(
vllm_config.model_config, vllm_config.model_config,
pooling=ModernBertPooler(config), pooling=ModernBertPooler(config),
classifier=self.classifier, classifier=self.classifier,
@@ -321,13 +328,6 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def forward( def forward(
self, self,
input_ids: Optional[torch.LongTensor], input_ids: Optional[torch.LongTensor],

View File

@@ -24,12 +24,13 @@ import torch.nn as nn
from transformers import BatchFeature from transformers import BatchFeature
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import (AllPool, PoolerHead,
PoolerIdentity, SimplePooler)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (IsAttentionFree, from vllm.model_executor.models.interfaces import (IsAttentionFree,
SupportsMultiModal, SupportsMultiModal,
SupportsV0Only) SupportsV0Only)
from vllm.model_executor.models.utils import AutoWeightsLoader from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs) MultiModalInputs, MultiModalKwargs)
@@ -37,8 +38,7 @@ from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptUpdate) BaseProcessingInfo, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import (IntermediateTensors, PoolerOutput, from vllm.sequence import IntermediateTensors
PoolingSequenceGroupOutput)
class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo): class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
@@ -118,6 +118,8 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
SupportsV0Only): SupportsV0Only):
"""Prithvi Masked Autoencoder""" """Prithvi Masked Autoencoder"""
is_pooling_model = True
@classmethod @classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
if modality.startswith("image"): if modality.startswith("image"):
@@ -162,6 +164,8 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
"Only SemanticSegmentationTask is supported for now " "Only SemanticSegmentationTask is supported for now "
"by PrithviGeospatialMAE.") "by PrithviGeospatialMAE.")
self.pooler = SimplePooler(AllPool(), PoolerHead(PoolerIdentity()))
def _parse_and_validate_multimodal_data( def _parse_and_validate_multimodal_data(
self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]: self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -189,7 +193,6 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
): ):
pixel_values, location_coords = ( pixel_values, location_coords = (
self._parse_and_validate_multimodal_data(**kwargs)) self._parse_and_validate_multimodal_data(**kwargs))
model_output = self.model(pixel_values, model_output = self.model(pixel_values,
@@ -197,13 +200,6 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
return model_output.output return model_output.output
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return PoolerOutput([PoolingSequenceGroupOutput(hidden_states)])
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
params_list = [] params_list = []

View File

@@ -16,8 +16,7 @@ from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.pooler import Pooler, PoolingType, SimplePooler from vllm.model_executor.layers.pooler import Pooler, PoolingType, SimplePooler
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors
from vllm.sequence import IntermediateTensors, PoolerOutput
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .qwen2 import Qwen2Model from .qwen2 import Qwen2Model
@@ -25,6 +24,10 @@ from .utils import AutoWeightsLoader, maybe_prefix
class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP): class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
is_pooling_model = True
pooler: SimplePooler
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
@@ -61,7 +64,6 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
quant_config=quant_config, quant_config=quant_config,
return_bias=False), return_bias=False),
) )
self._pooler: SimplePooler
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
@@ -80,13 +82,6 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
logits = self.score(hidden_states) logits = self.score(hidden_states)
return logits return logits
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self, loader = AutoWeightsLoader(self,
@@ -96,11 +91,11 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
class Qwen2ForRewardModel(Qwen2RewardBaseModel): class Qwen2ForRewardModel(Qwen2RewardBaseModel):
def __init__(self, *, vllm_config, prefix=""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
vllm_config.model_config.hf_config.num_labels = 1 vllm_config.model_config.hf_config.num_labels = 1
super().__init__(vllm_config=vllm_config, prefix=prefix) super().__init__(vllm_config=vllm_config, prefix=prefix)
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
self._pooler = Pooler.from_config_with_defaults( self.pooler = Pooler.from_config_with_defaults(
pooler_config, pooler_config,
pooling_type=PoolingType.ALL, pooling_type=PoolingType.ALL,
normalize=False, normalize=False,
@@ -109,11 +104,11 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel): class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
def __init__(self, *, vllm_config, prefix=""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
vllm_config.model_config.hf_config.num_labels = 2 vllm_config.model_config.hf_config.num_labels = 2
super().__init__(vllm_config=vllm_config, prefix=prefix) super().__init__(vllm_config=vllm_config, prefix=prefix)
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
self._pooler = Pooler.from_config_with_defaults( self.pooler = Pooler.from_config_with_defaults(
pooler_config, pooler_config,
pooling_type=PoolingType.STEP, pooling_type=PoolingType.STEP,
normalize=False, normalize=False,

View File

@@ -15,8 +15,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
maybe_prefix) maybe_prefix)
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors
from vllm.sequence import IntermediateTensors, PoolerOutput
from .bert_with_rope import BertWithRope, JinaRobertaModel from .bert_with_rope import BertWithRope, JinaRobertaModel
from .interfaces import SupportsCrossEncoding, SupportsV0Only from .interfaces import SupportsCrossEncoding, SupportsV0Only
@@ -165,6 +164,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
_pooler: An instance of Pooler used for pooling operations. _pooler: An instance of Pooler used for pooling operations.
""" """
is_pooling_model = True
jina_to_vllm_mapper = WeightsMapper( jina_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={ orig_to_new_substr={
'emb_ln': "embeddings.LayerNorm", 'emb_ln': "embeddings.LayerNorm",
@@ -188,7 +188,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
add_pooling_layer=False) add_pooling_layer=False)
self.classifier = RobertaClassificationHead(config) self.classifier = RobertaClassificationHead(config)
self._pooler = ClassifierPooler( self.pooler = ClassifierPooler(
vllm_config.model_config, vllm_config.model_config,
pooling=CLSPool(), pooling=CLSPool(),
classifier=self.classifier, classifier=self.classifier,
@@ -198,13 +198,6 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper) return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def forward( def forward(
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Optional
import msgspec import msgspec
@@ -15,24 +15,31 @@ class PoolingParams(
msgspec.Struct, msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg] omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg] array_like=True): # type: ignore[call-arg]
"""API parameters for pooling models. This is currently a placeholder. """API parameters for pooling models. This
Attributes: Attributes:
dimensions: Reduce the dimensions of embeddings dimensions: Reduce the dimensions of embeddings
if model support matryoshka representation. if model support matryoshka representation.
additional_data: Any additional data needed for pooling.
""" """
dimensions: Optional[int] = None dimensions: Optional[int] = None
use_cross_encoder: bool = False use_cross_encoder: bool = False
additional_data: Optional[Any] = None """Internal use only."""
logits_processing_needs_token_ids: bool = False
"""Internal use only."""
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
def clone(self) -> "PoolingParams": def clone(self) -> "PoolingParams":
"""Returns a deep copy of the PoolingParams instance.""" """Returns a deep copy of the PoolingParams instance."""
return PoolingParams(dimensions=self.dimensions, return PoolingParams(
dimensions=self.dimensions,
use_cross_encoder=self.use_cross_encoder, use_cross_encoder=self.use_cross_encoder,
additional_data=self.additional_data) logits_processing_needs_token_ids=self.
logits_processing_needs_token_ids,
)
def verify(self, model_config: "ModelConfig") -> None: def verify(self, model_config: "ModelConfig") -> None:
if self.dimensions is not None: if self.dimensions is not None:
@@ -54,10 +61,12 @@ class PoolingParams(
raise ValueError("Dimensions must be greater than 0") raise ValueError("Dimensions must be greater than 0")
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"PoolingParams(" return (
f"PoolingParams("
f"dimensions={self.dimensions}, " f"dimensions={self.dimensions}, "
f"use_cross_encoder={self.use_cross_encoder}, " f"use_cross_encoder={self.use_cross_encoder}, "
f"additional_metadata={self.additional_data})") f"logits_processing_needs_token_ids={self.logits_processing_needs_token_ids})"
)
def __post_init__(self) -> None: def __post_init__(self) -> None:
assert self.output_kind == RequestOutputKind.FINAL_ONLY,\ assert self.output_kind == RequestOutputKind.FINAL_ONLY,\