[Core][Model] PrithviMAE Enablement on vLLM v1 engine (#20577)

Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
Authored by Christian Pinto, 2025-07-23 19:00:23 +01:00; committed by GitHub
parent 316b1bf706
commit 8560a5b258
15 changed files with 704 additions and 238 deletions


@@ -16,6 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only IBM/NASA Prithvi Geospatial model."""
from collections.abc import Iterable, Mapping, Sequence
from typing import Optional, Union
@@ -27,13 +28,14 @@ from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import (AllPool, PoolerHead,
PoolerIdentity, SimplePooler)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import (IsAttentionFree,
-                                                   SupportsMultiModal,
-                                                   SupportsV0Only)
+from vllm.model_executor.models.interfaces import (
+    IsAttentionFree, MultiModalEmbeddings, SupportsMultiModalWithRawInput)
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalInputs, MultiModalKwargs)
+                                    MultiModalFieldElem, MultiModalInputs,
+                                    MultiModalKwargs, MultiModalKwargsItem,
+                                    MultiModalSharedField, PlaceholderRange)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptUpdate)
@@ -62,8 +64,9 @@ class PrithviGeoSpatialMAEInputBuilder(
        # The size of pixel_values might change in cases where we resize
        # the input, but it never exceeds the dimensions below.
return {
"pixel_values": torch.full((1, 6, 512, 512), 1.0),
"location_coords": torch.full((1, 2), 1.0),
"pixel_values": torch.full((6, 512, 512), 1.0,
dtype=torch.float16),
"location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
}
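The new dummy inputs drop the leading batch dimension and pin the dtype. A minimal torch-only sketch of the resulting shapes (tensor values are placeholders, as in the profiling path above):

import torch

# One image patch per request: (num_channels, width, height), no batch dim.
dummy_mm_data = {
    "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
    "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
}

assert dummy_mm_data["pixel_values"].ndim == 3  # batch dim is gone
assert dummy_mm_data["pixel_values"].dtype == torch.float16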
@@ -75,8 +78,10 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
-            pixel_values=MultiModalFieldConfig.batched("image"),
-            location_coords=MultiModalFieldConfig.batched("image"),
+            pixel_values=MultiModalFieldConfig.shared(batch_size=1,
+                                                      modality="image"),
+            location_coords=MultiModalFieldConfig.shared(batch_size=1,
+                                                         modality="image"),
)
def _get_prompt_updates(
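Switching from MultiModalFieldConfig.batched to MultiModalFieldConfig.shared(batch_size=1) changes how the processor treats the keyword tensors: batched treats dim 0 as the batch axis and splits along it, while shared keeps the whole tensor as a single element referenced by the one item in the batch. A plain-torch sketch of the difference (illustrative only; the field semantics follow the comment in the hunk below):

import torch

pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)

# batched("image") semantics: dim 0 would be the batch axis, so the
# (6, 512, 512) tensor would be split into six per-channel elements.
batched_elems = list(torch.unbind(pixel_values, dim=0))
assert len(batched_elems) == 6

# shared(batch_size=1) semantics: the tensor stays whole and is the
# single element shared by the one item in the batch.
shared_elems = [pixel_values]
assert len(shared_elems) == 1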
@@ -99,23 +104,48 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
for k, v in mm_data.items():
mm_kwargs[k] = v
+        mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
+        # This model takes as input a multi-dimensional tensor representing
+        # a single image patch, so the tensor must not be split into
+        # multiple elements; it is treated as a single element instead.
+        # Hence the decision to use a MultiModalSharedField.
+        # The expected shape is (num_channels, width, height).
+        # The model also allows the user to submit multiple image patches
+        # as a batch, which adds a further dimension to the shape above.
+        # At this stage we only support one patch per request, and batching
+        # is achieved via vLLM batching.
+        # TODO (christian-pinto): enable support for multi-patch requests
+        # in tandem with vLLM batching.
+        multimodal_kwargs_items = [
+            MultiModalKwargsItem.from_elems([
+                MultiModalFieldElem(
+                    modality="image",
+                    key=key,
+                    data=data,
+                    field=MultiModalSharedField(1),
+                ) for key, data in mm_kwargs.items()
+            ])
+        ]
return MultiModalInputs(
type="multimodal",
prompt=prompt,
prompt_token_ids=[1],
-            mm_kwargs=MultiModalKwargs(mm_kwargs),
+            mm_kwargs=MultiModalKwargs.from_items(multimodal_kwargs_items),
mm_hashes=None,
-            mm_placeholders={},
+            mm_placeholders=mm_placeholders,
)
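Condensed, the new kwargs construction wraps each tensor in a MultiModalSharedField of size 1 and assembles the MultiModalKwargs through the item API; a sketch using only the calls visible in this diff:

import torch
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs,
                                    MultiModalKwargsItem, MultiModalSharedField)

def build_mm_kwargs(mm_kwargs: dict[str, torch.Tensor]) -> MultiModalKwargs:
    # Each tensor becomes one indivisible element of the single item.
    item = MultiModalKwargsItem.from_elems([
        MultiModalFieldElem(modality="image", key=key, data=data,
                            field=MultiModalSharedField(1))
        for key, data in mm_kwargs.items()
    ])
    return MultiModalKwargs.from_items([item])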
@MULTIMODAL_REGISTRY.register_processor(
PrithviGeoSpatialMAEMultiModalProcessor,
info=PrithviGeoSpatialMAEProcessingInfo,
-    dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
-class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
-                           SupportsV0Only):
+    dummy_inputs=PrithviGeoSpatialMAEInputBuilder,
+)
+class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree,
+                           SupportsMultiModalWithRawInput):
"""Prithvi Masked Autoencoder"""
is_pooling_model = True
@@ -128,10 +158,10 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
raise ValueError("Only image modality is supported")
def _instantiate_model(self, config: dict) -> Optional[nn.Module]:
# We might be able/need to support different tasks with this same model
if config["task_args"]["task"] == "SemanticSegmentationTask":
from terratorch.cli_tools import SemanticSegmentationTask
task = SemanticSegmentationTask(
config["model_args"],
config["task_args"]["model_factory"],
@@ -144,7 +174,8 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
scheduler_hparams=config["scheduler_params"],
plot_on_val=config["task_args"]["plot_on_val"],
freeze_decoder=config["task_args"]["freeze_decoder"],
-            freeze_backbone=config["task_args"]["freeze_backbone"])
+            freeze_backbone=config["task_args"]["freeze_backbone"],
+        )
return task.model
else:
@@ -168,12 +199,10 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
def _parse_and_validate_multimodal_data(
self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
pixel_values = kwargs.pop("pixel_values", None)
if not isinstance(pixel_values, torch.Tensor):
raise ValueError(f"Incorrect type of pixel_values. "
f"Got type: {type(pixel_values)}")
-        pixel_values = torch.unbind(pixel_values, dim=0)[0]
location_coords = kwargs.pop("location_coords", None)
if not isinstance(location_coords, torch.Tensor):
@@ -185,6 +214,17 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
return pixel_values, location_coords
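Because the shared field delivers pixel_values already in (num_channels, width, height) form, the old torch.unbind(pixel_values, dim=0)[0] that peeled off a batch dimension is dropped, and validation reduces to a type check. A condensed sketch of the remaining logic:

import torch

def parse_pixel_values(pixel_values: object) -> torch.Tensor:
    # The tensor arrives unbatched, so only the type needs checking.
    if not isinstance(pixel_values, torch.Tensor):
        raise ValueError(f"Incorrect type of pixel_values. "
                         f"Got type: {type(pixel_values)}")
    return pixel_values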
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+    ) -> torch.Tensor:
+        # We do not really use any input tokens, so there are no embeddings
+        # to compute. However, because token ids are mandatory in the input
+        # prompt, we pass a single token, and the size of the dummy
+        # embedding tensor must reflect that.
+        return torch.empty((input_ids.shape[0], 0))
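The zero-width embedding keeps the tensor's first dimension consistent with the single dummy token id while carrying no data; a quick check of the shape contract:

import torch

input_ids = torch.tensor([1])            # the single mandatory dummy token
embeddings = torch.empty((input_ids.shape[0], 0))
assert embeddings.shape == (1, 0)        # one token, zero-width embedding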
def forward(
self,
input_ids: Optional[torch.Tensor],