[Core][Model] Terratorch backend integration (#23513)

Signed-off-by: Michele Gazzetti <michele.gazzetti1@ibm.com>
Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
Co-authored-by: Christian Pinto <christian.pinto@ibm.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Authored by mgazz on 2025-09-04 08:22:41 +01:00, committed by GitHub
commit 51d5e9be7d (parent e7fc70016f)
23 changed files with 305 additions and 208 deletions


@@ -184,10 +184,11 @@ _EMBEDDING_MODELS = {
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
# Technically PrithviGeoSpatialMAE is a model that works on images, both in
# input and output. I am adding it here because it piggybacks on embedding
# Technically Terratorch models work on images, both in
# input and output. I am adding it here because it piggy-backs on embedding
# models for the time being.
"PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
"PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
"Terratorch": ("terratorch", "Terratorch"),
}
_CROSS_ENCODER_MODELS = {
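Both architecture names now resolve to the same lazily imported wrapper class, so existing Prithvi checkpoints keep working while Terratorch-packaged checkpoints advertise the generic name. A minimal sketch of what the mapping implies, assuming `get_supported_archs()` is the public lookup this registry exposes:

from vllm.model_executor.models.registry import ModelRegistry

# Both the legacy and the generic architecture name should be listed,
# and both map to vllm's terratorch module per the tuples above.
archs = ModelRegistry.get_supported_archs()
assert "PrithviGeoSpatialMAE" in archs and "Terratorch" in archs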
@@ -639,6 +640,9 @@ class _ModelRegistry:
model_info = self._try_inspect_model_cls(arch)
if model_info is not None:
return (model_info, arch)
elif model_config.model_impl == ModelImpl.TERRATORCH:
model_info = self._try_inspect_model_cls("Terratorch")
return (model_info, "Terratorch")
# Fallback to transformers impl (after resolving convert_type)
if (all(arch not in self.models for arch in architectures)
@@ -687,6 +691,11 @@ class _ModelRegistry:
model_cls = self._try_load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)
elif model_config.model_impl == ModelImpl.TERRATORCH:
arch = "Terratorch"
model_cls = self._try_load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)
# Fallback to transformers impl (after resolving convert_type)
if (all(arch not in self.models for arch in architectures)
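With the new ModelImpl.TERRATORCH value, resolution can be forced explicitly rather than inferred from the checkpoint's architectures list. A minimal usage sketch, assuming a local Terratorch-packaged checkpoint (the path is a placeholder and the extra engine arguments are not prescribed by this change):

from vllm import LLM

llm = LLM(
    model="/path/to/terratorch-checkpoint",  # placeholder path
    model_impl="terratorch",  # forces the ModelImpl.TERRATORCH branch above
    skip_tokenizer_init=True,  # these models consume images, not text
)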


@@ -15,13 +15,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only IBM/NASA Prithvi Geospatial model."""
"""Wrapper around `Terratorch` models"""
from collections import OrderedDict
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, Union
import torch
import torch.nn as nn
from terratorch.vllm import (DummyDataGenerator, InferenceRunner,
InputDefinition, InputTypeEnum)
from transformers import BatchFeature
from vllm.config import VllmConfig
@@ -29,6 +32,7 @@ from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargsItems,
@@ -45,52 +49,46 @@ from .interfaces import (IsAttentionFree, MultiModalEmbeddings,
from .interfaces_base import default_pooling_type
def _prithvi_field_config(hf_inputs: Mapping[str, torch.Tensor]):
    # This model receives in input a multi-dimensional tensor representing
    # a single image patch and therefore it is not to be split
    # into multiple elements, but rather to be considered a single one.
    # Hence, the decision of using a MultiModalSharedField.
    # The expected shape is (num_channels, width, height).
    # This model however allows the user to also submit multiple image
    # patches as a batch, adding a further dimension to the above shape.
    # At this stage we only support submitting one patch per request and
    # batching is achieved via vLLM batching.
    # TODO (christian-pinto): enable support for multi patch requests
    # in tandem with vLLM batching.
    return dict(
        pixel_values=MultiModalFieldConfig.shared(batch_size=1,
                                                  modality="image"),
        location_coords=MultiModalFieldConfig.shared(batch_size=1,
                                                     modality="image"),
    )


class PrithviGeoSpatialMAEMultiModalDataParser(MultiModalDataParser):

    def _parse_image_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
    ) -> Optional[ModalityDataItems[Any, Any]]:
        if isinstance(data, dict):
            return DictEmbeddingItems(
                data,
                modality="image",
                required_fields={"pixel_values", "location_coords"},
                fields_factory=_prithvi_field_config,
            )

        return super()._parse_image_data(data)


def _terratorch_field_names(pretrained_cfg: dict):
    input_definition = InputDefinition(**pretrained_cfg["input"])
    return set(input_definition.data.keys())


def _terratorch_field_factory(
    pretrained_cfg: dict
) -> Callable[
        [Mapping[str, torch.Tensor]],
        Mapping[str, MultiModalFieldConfig],
]:

    def _terratorch_field_config(hf_inputs: Mapping[str, torch.Tensor]):
        input_definition = InputDefinition(**pretrained_cfg["input"])
        fields = {}
        for input_name, input in input_definition.data.items():
            if input.type == InputTypeEnum.tensor:
                fields[input_name] = "image"

        mm_fields_config = {}
        for field_name, field_modality in fields.items():
            mm_fields_config[field_name] = MultiModalFieldConfig.shared(
                batch_size=1, modality=field_modality)
        return mm_fields_config

    return _terratorch_field_config
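To make the factory concrete: every tensor input declared in the checkpoint's input definition becomes a shared image field, i.e. one patch per request, exactly as the old hard-coded Prithvi config did for pixel_values and location_coords. A sketch with an invented configuration (the real InputDefinition schema belongs to terratorch.vllm and will differ per model):

# Hypothetical pretrained_cfg fragment, used only for illustration.
pretrained_cfg = {
    "input": {
        "data": {
            "pixel_values": {"type": "tensor"},
            "location_coords": {"type": "tensor"},
        },
    },
}

fields_factory = _terratorch_field_factory(pretrained_cfg)
# Each declared tensor input becomes
# MultiModalFieldConfig.shared(batch_size=1, modality="image").
mm_fields = fields_factory({})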
class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
class TerratorchProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
class PrithviGeoSpatialMAEInputBuilder(
BaseDummyInputsBuilder[PrithviGeoSpatialMAEProcessingInfo]):
class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
def __init__(self, info: TerratorchProcessingInfo):
super().__init__(info)
self.dummy_data_generator = DummyDataGenerator(
self.info.get_hf_config().to_dict()["pretrained_cfg"])
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
return ""
@@ -100,29 +98,57 @@ class PrithviGeoSpatialMAEInputBuilder(
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
# This model input is fixed and is in the form of a torch Tensor.
# The size of pixel_values might change in the cases where we resize
# the input but never exceeds the dimensions below.
image_data = {
"pixel_values": torch.full((6, 512, 512), 1.0,
dtype=torch.float16),
"location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
}
return {"image": image_data}
# Dummy data is generated based on the 'input' section
# defined in the HF configuration file
return self.dummy_data_generator.get_dummy_mm_data()
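The hard-coded (6, 512, 512) dummy tensor from the Prithvi-only builder is gone; shapes now come from the checkpoint itself. A sketch of the contract relied on here, using only the API surface imported at the top of this file:

from terratorch.vllm import DummyDataGenerator

def build_dummy_mm_data(pretrained_cfg: dict):
    # DummyDataGenerator reads the "input" section of pretrained_cfg and
    # synthesizes tensors matching each declared input, so profiling data
    # always agrees with the model's real input definition.
    generator = DummyDataGenerator(pretrained_cfg)
    return generator.get_dummy_mm_data()  # a MultiModalDataDict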
class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
class TerratorchMultiModalDataParser(MultiModalDataParser):
def __init__(self, pretrained_cfg: dict, *args, **kwargs):
self._pretrained_cfg = pretrained_cfg
super().__init__(*args, **kwargs)
def _parse_image_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
) -> Optional[ModalityDataItems[Any, Any]]:
if isinstance(data, dict):
terratorch_fields = _terratorch_field_names(self._pretrained_cfg)
return DictEmbeddingItems(
data,
modality="image",
required_fields=terratorch_fields,
fields_factory=_terratorch_field_factory(self._pretrained_cfg),
)
return super()._parse_image_data(data)
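On the request side, the dict branch above means callers pass raw tensors keyed by the checkpoint's declared input names. A hedged example using the two field names of the old Prithvi configuration (other checkpoints will declare different names):

import torch

# Field names must match _terratorch_field_names(pretrained_cfg);
# "pixel_values" and "location_coords" are the Prithvi ones.
mm_data = {
    "image": {
        "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
        "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
    }
}
# Passed as the multi_modal_data of a pooling request, e.g. via llm.encode().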
class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
def __init__(
self,
info: TerratorchProcessingInfo,
dummy_inputs: "BaseDummyInputsBuilder[TerratorchProcessingInfo]",
*,
cache: Optional[MultiModalProcessorOnlyCache] = None) -> None:
self.pretrained_cfg = info.get_hf_config().to_dict()["pretrained_cfg"]
super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache)
def _get_data_parser(self) -> MultiModalDataParser:
return PrithviGeoSpatialMAEMultiModalDataParser()
return TerratorchMultiModalDataParser(
pretrained_cfg=self.pretrained_cfg)
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return _prithvi_field_config(hf_inputs)
return _terratorch_field_factory(self.pretrained_cfg)(hf_inputs)
def _get_prompt_updates(
self,
@@ -173,13 +199,11 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
@default_pooling_type("All")
@MULTIMODAL_REGISTRY.register_processor(
PrithviGeoSpatialMAEMultiModalProcessor,
info=PrithviGeoSpatialMAEProcessingInfo,
dummy_inputs=PrithviGeoSpatialMAEInputBuilder,
TerratorchMultiModalProcessor,
info=TerratorchProcessingInfo,
dummy_inputs=TerratorchInputBuilder,
)
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
"""Prithvi Masked Autoencoder"""
class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
supports_multimodal_raw_input_only = True
is_pooling_model = True
@@ -190,43 +214,13 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
raise ValueError("Only image modality is supported")
def _instantiate_model(self, config: dict) -> Optional[nn.Module]:
# We might be able/need to support different tasks with this same model
if config["task_args"]["task"] == "SemanticSegmentationTask":
from terratorch.cli_tools import SemanticSegmentationTask
task = SemanticSegmentationTask(
config["model_args"],
config["task_args"]["model_factory"],
loss=config["task_args"]["loss"],
lr=config["task_args"]["lr"],
ignore_index=config["task_args"]["ignore_index"],
optimizer=config["task_args"]["optimizer"],
optimizer_hparams=config["optimizer_params"],
scheduler=config["task_args"]["scheduler"],
scheduler_hparams=config["scheduler_params"],
plot_on_val=config["task_args"]["plot_on_val"],
freeze_decoder=config["task_args"]["freeze_decoder"],
freeze_backbone=config["task_args"]["freeze_backbone"],
)
return task.model
else:
return None
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
# the actual model is dynamically instantiated using terratorch
# allowing us to perform changes to the model architecture
# at startup time (e.g., change the model decoder class.)
self.model = self._instantiate_model(
vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"])
if self.model is None:
raise ValueError(
"Unsupported task. "
"Only SemanticSegmentationTask is supported for now "
"by PrithviGeospatialMAE.")
config = vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"]
self.inference_runner = InferenceRunner(config)
self.model = self.inference_runner.model
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
@@ -234,23 +228,6 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
self.pooler = DispatchPooler(
{"encode": Pooler.for_encode(pooler_config)}, )
def _parse_and_validate_multimodal_data(
self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
pixel_values = kwargs.pop("pixel_values", None)
if not isinstance(pixel_values, torch.Tensor):
raise ValueError(f"Incorrect type of pixel_values. "
f"Got type: {type(pixel_values)}")
location_coords = kwargs.pop("location_coords", None)
if not isinstance(location_coords, torch.Tensor):
raise ValueError(f"Incorrect type of location_coords. "
f"Got type: {type(location_coords)}")
location_coords = torch.unbind(location_coords, dim=0)[0]
if location_coords.shape == torch.Size([0]):
location_coords = None
return pixel_values, location_coords
def get_input_embeddings(
self,
input_ids: torch.Tensor,
@@ -270,10 +247,7 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
):
pixel_values, location_coords = (
self._parse_and_validate_multimodal_data(**kwargs))
model_output = self.model(pixel_values,
location_coords=location_coords)
model_output = self.inference_runner.forward(**kwargs)
return model_output.output
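The removed _parse_and_validate_multimodal_data is not lost logic: validation now lives inside terratorch, and forward reduces to a single dispatch. A sketch of the resulting call path, using only the InferenceRunner surface visible in this diff:

from terratorch.vllm import InferenceRunner

def run_terratorch(pretrained_cfg: dict, **mm_kwargs):
    # Mirrors Terratorch.__init__/forward above: the runner builds the
    # model from the HF "pretrained_cfg", validates the multimodal kwargs,
    # and returns an object exposing the model output via .output.
    runner = InferenceRunner(pretrained_cfg)
    return runner.forward(**mm_kwargs).output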
@@ -283,28 +257,34 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
model_buffers = dict(self.named_buffers())
loaded_buffers = []
for key, value in weights:
            if key == "state_dict":
                weights_to_parse = value

                for name, weight in weights_to_parse.items():
                    if "pos_embed" in name:
                        continue

                    if "_timm_module." in name:
                        name = name.replace("_timm_module.", "")

                    # this model requires a couple of buffers to be loaded
                    # that are not loadable with the AutoWeightsLoader
                    if name in model_buffers:
                        buffer = model_buffers[name]
                        weight_loader = getattr(buffer, "weight_loader",
                                                default_weight_loader)
                        weight_loader(buffer, weight)
                        loaded_buffers.append(name)
                    else:
                        params_list.append((name, weight))
                break
            if isinstance(value, (dict, OrderedDict)):
                if key == "state_dict":
                    weights_to_parse = value

                    for name, weight in weights_to_parse.items():
                        name = f"inference_runner.{name}"

                        if "_timm_module." in name:
                            name = name.replace("_timm_module.", "")

                        if "pos_embed" in name:
                            continue

                        # this model requires a couple of buffers to be loaded
                        # that are not loadable with the AutoWeightsLoader
                        if name in model_buffers:
                            buffer = model_buffers[name]
                            weight_loader = getattr(buffer, "weight_loader",
                                                    default_weight_loader)
                            weight_loader(buffer, weight)
                            loaded_buffers.append(name)
                        else:
                            params_list.append((name, weight))
                    break
            elif isinstance(value, torch.Tensor):
                params_list.append(
                    (f"inference_runner.model.{key}", value))
# Load the remaining model parameters
loader = AutoWeightsLoader(self)
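For reference, the rewritten branch structure means load_weights now accepts two checkpoint layouts. A sketch of both, with illustrative key names:

import torch
from collections import OrderedDict

# Layout 1: a single ("state_dict", dict) entry, as produced by
# terratorch/Lightning checkpoints; names gain the "inference_runner."
# prefix and any "_timm_module." segment is stripped before matching.
nested = [("state_dict",
           OrderedDict({"encoder.blocks.0.weight": torch.zeros(1)}))]

# Layout 2: flat (name, tensor) pairs, e.g. plain tensor files; names
# gain the "inference_runner.model." prefix instead.
flat = [("encoder.blocks.0.weight", torch.zeros(1))]
# Either iterable can be fed to Terratorch.load_weights(...).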