[Model] IBM/NASA Prithvi Geospatial model (#12830)
This commit is contained in:
@@ -320,9 +320,14 @@ class PlaceholderAttentionMetadataBuilder(
|
||||
-1 if cuda graph is not used.
|
||||
batch_size: The maybe padded batch size.
|
||||
"""
|
||||
for inter_data in self.input_builder.inter_data_list:
|
||||
self._add_seq_group(inter_data,
|
||||
self.input_builder.chunked_prefill_enabled)
|
||||
|
||||
# Some input builders such as ModelInputForCPUBuilder do not have the
|
||||
# "inter_data_list" attribute.
|
||||
# Let's check inter_data_list exists before we reference it.
|
||||
if hasattr(self.input_builder, "inter_data_list"):
|
||||
for inter_data in self.input_builder.inter_data_list:
|
||||
self._add_seq_group(inter_data,
|
||||
self.input_builder.chunked_prefill_enabled)
|
||||
|
||||
device = self.runner.device
|
||||
use_captured_graph = cuda_graph_pad_size != -1
|
||||
|
||||
@@ -254,8 +254,14 @@ class InputPreprocessor:
|
||||
Apply the model's multi-modal processor to a multi-modal prompt,
|
||||
returning the corresponding token IDs and metadata.
|
||||
"""
|
||||
tokenizer_group = self.get_tokenizer_group()
|
||||
tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
|
||||
# At the moment one model (PrithviGeoSpatialMAE) requires to be
|
||||
# initialized without a tokenizer while using also multi-modal
|
||||
# input.
|
||||
if not self.tokenizer:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer_group = self.get_tokenizer_group()
|
||||
tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
|
||||
|
||||
mm_processor = self.mm_registry.create_processor(
|
||||
self.model_config, tokenizer)
|
||||
@@ -273,9 +279,15 @@ class InputPreprocessor:
|
||||
lora_request: Optional[LoRARequest],
|
||||
) -> MultiModalInputs:
|
||||
"""Async version of :meth:`_process_multimodal`."""
|
||||
tokenizer_group = self.get_tokenizer_group()
|
||||
tokenizer = await tokenizer_group.get_lora_tokenizer_async(lora_request
|
||||
)
|
||||
# At the moment one model (PrithviGeoSpatialMAE) requires to be
|
||||
# initialized without a tokenizer while using also multi-modal
|
||||
# input.
|
||||
if not self.tokenizer:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer_group = self.get_tokenizer_group()
|
||||
tokenizer = await tokenizer_group.get_lora_tokenizer_async(
|
||||
lora_request)
|
||||
|
||||
mm_processor = self.mm_registry.create_processor(
|
||||
self.model_config, tokenizer)
|
||||
|
||||
238
vllm/model_executor/models/prithvi_geospatial_mae.py
Normal file
238
vllm/model_executor/models/prithvi_geospatial_mae.py
Normal file
@@ -0,0 +1,238 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Copyright 2025 The vLLM team.
|
||||
# Copyright 2025 IBM.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only IBM/NASA Prithvi Geospatial model."""
|
||||
from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import BatchFeature
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import (IsAttentionFree,
|
||||
SupportsMultiModal)
|
||||
from vllm.model_executor.models.utils import AutoWeightsLoader
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalInputs, MultiModalKwargs)
|
||||
from vllm.multimodal.parse import MultiModalDataItems
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import (IntermediateTensors, PoolerOutput,
|
||||
PoolingSequenceGroupOutput)
|
||||
|
||||
|
||||
class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
    """Multi-modal processing info for the Prithvi geospatial MAE model."""

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the per-modality limits: ``None`` is reported for "image"."""
        limits: Mapping[str, Optional[int]] = {"image": None}
        return limits

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        """Placeholder implementation that ignores ``seq_len``.

        Returns:
            ``None`` -- no mapping is built.
        """
        # NOTE(review): the annotation promises a Mapping but nothing is
        # returned; presumably this is never consumed for this model --
        # confirm against the callers before relying on it.
        return None
|
||||
|
||||
|
||||
class PrithviGeoSpatialMAEInputBuilder(
        BaseDummyInputsBuilder[PrithviGeoSpatialMAEProcessingInfo]):
    """Builds the dummy processor inputs for the Prithvi geospatial model."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Return a fixed dummy input.

        Neither ``seq_len`` nor ``mm_counts`` is consulted: the same two
        tensors, filled with ones, are produced on every call together with
        an empty prompt text.
        """
        # This model input is fixed and is in the form of a torch Tensor.
        # The size of pixel_values might change in the cases where we resize
        # the input but never exceeds the dimensions below.
        dummy_pixel_values = torch.full((1, 6, 512, 512), 1.0)
        dummy_location_coords = torch.full((1, 2), 1.0)
        dummy_mm_data = {
            "pixel_values": dummy_pixel_values,
            "location_coords": dummy_location_coords,
        }
        return ProcessorInputs(prompt_text="", mm_data=dummy_mm_data)
|
||||
|
||||
|
||||
class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
    """Multi-modal processor for the Prithvi geospatial MAE model.

    ``apply`` is overridden to forward the caller-provided multi-modal data
    untouched, with a fixed one-token prompt and no placeholders.
    """

    # NOTE: this method used to be defined twice in this class; the second
    # definition was a ``pass`` stub that silently replaced this one, so the
    # method returned ``None``. Only the real implementation is kept.
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare ``pixel_values`` and ``location_coords`` as batched
        fields of the "image" modality. Both arguments are unused."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            location_coords=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """Return no prompt replacements.

        An empty list is returned (instead of the previous implicit
        ``None``) so the result matches the declared return type.
        """
        return []

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputs:
        """Wrap the raw multi-modal data into ``MultiModalInputs``.

        Args:
            prompt: Passed through unchanged as the ``prompt`` field.
            mm_data: Every key/value pair is copied as-is into the
                multi-modal kwargs.
            hf_processor_mm_kwargs: Unused.

        Returns:
            A "multimodal" input whose token ids are the fixed list ``[1]``
            and whose placeholder mapping is empty.
        """
        # Shallow copy so the caller's mapping is not shared with the kwargs.
        mm_kwargs = dict(mm_data.items())

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=[1],
            mm_kwargs=MultiModalKwargs(mm_kwargs),
            mm_placeholders={},
        )
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
PrithviGeoSpatialMAEMultiModalProcessor,
|
||||
info=PrithviGeoSpatialMAEProcessingInfo,
|
||||
dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
|
||||
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
|
||||
""" Prithvi Masked Autoencoder"""
|
||||
|
||||
def _instantiate_model(self, config: dict) -> nn.Module | None:
|
||||
|
||||
# We might be able/need to support different tasks with this same model
|
||||
if config["task_args"]["task"] == "SemanticSegmentationTask":
|
||||
from terratorch.cli_tools import SemanticSegmentationTask
|
||||
task = SemanticSegmentationTask(
|
||||
config["model_args"],
|
||||
config["task_args"]["model_factory"],
|
||||
loss=config["task_args"]["loss"],
|
||||
lr=config["task_args"]["lr"],
|
||||
ignore_index=config["task_args"]["ignore_index"],
|
||||
optimizer=config["task_args"]["optimizer"],
|
||||
optimizer_hparams=config["optimizer_params"],
|
||||
scheduler=config["task_args"]["scheduler"],
|
||||
scheduler_hparams=config["scheduler_params"],
|
||||
plot_on_val=config["task_args"]["plot_on_val"],
|
||||
freeze_decoder=config["task_args"]["freeze_decoder"],
|
||||
freeze_backbone=config["task_args"]["freeze_backbone"])
|
||||
|
||||
return task.model
|
||||
else:
|
||||
return None
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
    """Instantiate the wrapped model from the HF ``pretrained_cfg`` entry.

    Args:
        vllm_config: Its ``model_config.hf_config`` is converted to a dict
            and the "pretrained_cfg" entry drives model instantiation.
        prefix: Unused.

    Raises:
        ValueError: If the configured task is not supported, i.e.
            ``_instantiate_model`` returned ``None``.
    """
    super().__init__()

    # the actual model is dynamically instantiated using terratorch
    # allowing us to perform changes to the model architecture
    # at startup time (e.g., change the model decoder class.)
    self.model = self._instantiate_model(
        vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"])
    if self.model is None:
        # Adjacent string literals are concatenated verbatim, so each
        # fragment carries its own trailing space.
        raise ValueError(
            "Unsupported task. "
            "Only SemanticSegmentationTask is supported for now "
            "by PrithviGeospatialMAE.")
|
||||
|
||||
def _parse_and_validate_multimodal_data(
|
||||
self, **kwargs) -> Tuple[torch.Tensor, torch.Tensor | None]:
|
||||
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
if not isinstance(pixel_values, torch.Tensor):
|
||||
raise ValueError(f"Incorrect type of pixel_values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
pixel_values = torch.unbind(pixel_values, dim=0)[0]
|
||||
|
||||
location_coords = kwargs.pop("location_coords", None)
|
||||
if not isinstance(location_coords, torch.Tensor):
|
||||
raise ValueError(f"Incorrect type of location_coords. "
|
||||
f"Got type: {type(location_coords)}")
|
||||
location_coords = torch.unbind(location_coords, dim=0)[0]
|
||||
if location_coords.shape == torch.Size([0]):
|
||||
location_coords = None
|
||||
|
||||
return pixel_values, location_coords
|
||||
|
||||
def forward(
    self,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    kv_caches: List[torch.Tensor],
    attn_metadata: AttentionMetadata,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    **kwargs: object,
):
    """Run the wrapped model on the multi-modal inputs.

    Only ``kwargs`` is consumed: it must carry "pixel_values" and
    "location_coords". All the other parameters are ignored by this
    method.

    Returns:
        The ``output`` attribute of the wrapped model's result.
    """
    parsed_inputs = self._parse_and_validate_multimodal_data(**kwargs)
    pixel_values, location_coords = parsed_inputs

    result = self.model(pixel_values, location_coords=location_coords)
    return result.output
|
||||
|
||||
def pooler(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
    """Wrap ``hidden_states`` unchanged into a single-entry pooler output.

    ``pooling_metadata`` is not used.
    """
    sequence_group_output = PoolingSequenceGroupOutput(hidden_states)
    return PoolerOutput([sequence_group_output])
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
                                               torch.Tensor]]) -> Set[str]:
    """Load the checkpoint weights into the wrapped model.

    Only the first entry of ``weights`` whose key is "state_dict" is
    consumed; its value is expected to map names to tensors. Entries with
    "pos_embed" in their name are skipped and every "_timm_module."
    occurrence is stripped from the names.

    Args:
        weights: Iterable of ``(key, value)`` pairs.

    Returns:
        The names loaded by ``AutoWeightsLoader`` together with the names
        of the buffers loaded directly by this method.
    """
    params_list = []
    model_buffers = dict(self.named_buffers())
    loaded_buffers = []
    for key, value in weights:
        if key != "state_dict":
            continue

        for name, weight in value.items():
            if "pos_embed" in name:
                continue

            # str.replace is a no-op when the marker is absent, so no
            # membership check is needed; stripping once here also covers
            # the buffer lookup below.
            name = name.replace("_timm_module.", "")

            # this model requires a couple of buffers to be loaded
            # that are not loadable with the AutoWeightsLoader
            if name in model_buffers:
                buffer = model_buffers[name]
                weight_loader = getattr(buffer, "weight_loader",
                                        default_weight_loader)
                weight_loader(buffer, weight)
                loaded_buffers.append(name)
            else:
                params_list.append((name, weight))
        # Only the first "state_dict" entry is processed.
        break

    # Load the remaining model parameters
    loader = AutoWeightsLoader(self)
    autoloaded_weights = loader.load_weights(params_list)

    return autoloaded_weights.union(set(loaded_buffers))
|
||||
@@ -137,6 +137,10 @@ _EMBEDDING_MODELS = {
|
||||
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
|
||||
# [Auto-converted (see adapters.py)]
|
||||
"Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
|
||||
# Technically PrithviGeoSpatialMAE is a model that works on images, both in
|
||||
# input and output. I am adding it here because it piggy-backs on embedding
|
||||
# models for the time being.
|
||||
"PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
|
||||
}
|
||||
|
||||
_CROSS_ENCODER_MODELS = {
|
||||
|
||||
@@ -74,7 +74,16 @@ class PoolingModelRunner(
|
||||
prefill_meta = model_input.attn_metadata.prefill_metadata
|
||||
decode_meta = model_input.attn_metadata.decode_metadata
|
||||
virtual_engine = model_input.virtual_engine
|
||||
if prefill_meta is None and decode_meta.use_cuda_graph:
|
||||
# Pooling models are (ab-)used also to integrate non text models that
|
||||
# are not autoregressive (PrithviGeoSpatialMAE).
|
||||
# These models might not use attention and do not really have a prefill
|
||||
# and decode phase. The model input is processed in one shot and both
|
||||
# decode_metadata and prefill_metadata would be None for such models.
|
||||
# See the PlaceholderAttentionMetadata class.
|
||||
# TODO: Figure out if cuda_graph is of any use for these models and
|
||||
# explore how to leverage it.
|
||||
if (prefill_meta is None and decode_meta is not None
|
||||
and decode_meta.use_cuda_graph):
|
||||
assert model_input.input_tokens is not None
|
||||
graph_batch_size = model_input.input_tokens.shape[0]
|
||||
model_executable = self.graph_runners[virtual_engine][
|
||||
|
||||
Reference in New Issue
Block a user