[Model] Update pooling model interface (#21058)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-07-18 00:05:40 +08:00
committed by GitHub
parent 9fb2d22032
commit 90bd2ab6e3
17 changed files with 247 additions and 345 deletions

View File

@@ -13,9 +13,8 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.sequence import IntermediateTensors
from .interfaces import (SupportsCrossEncoding, SupportsMultiModal,
SupportsScoreTemplate)
@@ -72,6 +71,8 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
SupportsCrossEncoding,
SupportsMultiModal,
SupportsScoreTemplate):
is_pooling_model = True
weight_mapper = WeightsMapper(
orig_to_new_prefix={
"score.0.": "score.dense.",
@@ -95,7 +96,7 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
self.score = JinaVLScorer(config)
self._pooler = Pooler.from_config_with_defaults(
self.pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=False,
@@ -137,14 +138,6 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
logits = self.score(hidden_states) - self.LOGIT_BIAS
return logits
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.weight_mapper)