[Model] Update pooling model interface (#21058)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-07-18 00:05:40 +08:00
committed by GitHub
parent 9fb2d22032
commit 90bd2ab6e3
17 changed files with 247 additions and 345 deletions

View File

@@ -24,12 +24,13 @@ import torch.nn as nn
from transformers import BatchFeature
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import (AllPool, PoolerHead,
PoolerIdentity, SimplePooler)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (IsAttentionFree,
SupportsMultiModal,
SupportsV0Only)
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs)
@@ -37,8 +38,7 @@ from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import (IntermediateTensors, PoolerOutput,
PoolingSequenceGroupOutput)
from vllm.sequence import IntermediateTensors
class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
@@ -116,7 +116,9 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
SupportsV0Only):
""" Prithvi Masked Autoencoder"""
"""Prithvi Masked Autoencoder"""
is_pooling_model = True
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
@@ -162,6 +164,8 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
"Only SemanticSegmentationTask is supported for now "
"by PrithviGeospatialMAE.")
self.pooler = SimplePooler(AllPool(), PoolerHead(PoolerIdentity()))
def _parse_and_validate_multimodal_data(
self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -189,7 +193,6 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
):
pixel_values, location_coords = (
self._parse_and_validate_multimodal_data(**kwargs))
model_output = self.model(pixel_values,
@@ -197,13 +200,6 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
return model_output.output
# Pooling hook: wraps `hidden_states`, unmodified, in a single
# PoolingSequenceGroupOutput and returns that as a one-element PoolerOutput.
# `pooling_metadata` is accepted but never read by this body.
# The return annotation is Optional, yet this body always returns a
# PoolerOutput (there is no path that returns None).
# NOTE(review): this text is a flattened diff hunk with its +/- markers lost;
# this method looks like the pre-change implementation that the
# `self.pooler = SimplePooler(...)` assignment replaces -- confirm against the
# actual commit before relying on it.
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return PoolerOutput([PoolingSequenceGroupOutput(hidden_states)])
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_list = []