[VLM] Generalized prompt updates for multi-modal processor (#13964)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-02-28 01:44:25 +08:00
committed by GitHub
parent 7864875879
commit f1579b229d
29 changed files with 629 additions and 486 deletions

View File

@@ -15,7 +15,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only IBM/NASA Prithvi Geospatial model."""
from typing import Iterable, Mapping, Optional, Set, Tuple, Union
from collections.abc import Iterable, Mapping, Sequence
from typing import Optional, Set, Tuple, Union
import torch
import torch.nn as nn
@@ -32,7 +33,7 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement)
BaseProcessingInfo, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import (IntermediateTensors, PoolerOutput,
PoolingSequenceGroupOutput)
@@ -44,7 +45,7 @@ class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
pass
return {"image": 0}
class PrithviGeoSpatialMAEInputBuilder(
@@ -78,20 +79,13 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
location_coords=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
pass
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
pass
) -> Sequence[PromptUpdate]:
return []
def apply(
self,
@@ -120,7 +114,7 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
""" Prithvi Masked Autoencoder"""
def _instantiate_model(self, config: dict) -> nn.Module | None:
def _instantiate_model(self, config: dict) -> Optional[nn.Module]:
# We might be able/need to support different tasks with this same model
if config["task_args"]["task"] == "SemanticSegmentationTask":
@@ -158,7 +152,7 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
"by PrithviGeospatialMAE.")
def _parse_and_validate_multimodal_data(
self, **kwargs) -> Tuple[torch.Tensor, torch.Tensor | None]:
self, **kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
pixel_values = kwargs.pop("pixel_values", None)
if not isinstance(pixel_values, torch.Tensor):