[Core][Model] Terratorch backend integration (#23513)
Signed-off-by: Michele Gazzetti <michele.gazzetti1@ibm.com>
Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
Co-authored-by: Christian Pinto <christian.pinto@ibm.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
@@ -184,10 +184,11 @@ _EMBEDDING_MODELS = {
     "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
     "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
-    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
-    # input and output. I am adding it here because it piggybacks on embedding
+    # Technically Terratorch models work on images, both in
+    # input and output. I am adding it here because it piggy-backs on embedding
     # models for the time being.
-    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
+    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
+    "Terratorch": ("terratorch", "Terratorch"),
 }

 _CROSS_ENCODER_MODELS = {
@@ -639,6 +640,9 @@ class _ModelRegistry:
             model_info = self._try_inspect_model_cls(arch)
             if model_info is not None:
                 return (model_info, arch)
+            elif model_config.model_impl == ModelImpl.TERRATORCH:
+                model_info = self._try_inspect_model_cls("Terratorch")
+                return (model_info, "Terratorch")

         # Fallback to transformers impl (after resolving convert_type)
         if (all(arch not in self.models for arch in architectures)
@@ -687,6 +691,11 @@ class _ModelRegistry:
             model_cls = self._try_load_model_cls(arch)
             if model_cls is not None:
                 return (model_cls, arch)
+            elif model_config.model_impl == ModelImpl.TERRATORCH:
+                arch = "Terratorch"
+                model_cls = self._try_load_model_cls(arch)
+                if model_cls is not None:
+                    return (model_cls, arch)

         # Fallback to transformers impl (after resolving convert_type)
         if (all(arch not in self.models for arch in architectures)
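Note: a minimal offline-inference sketch of how this new resolution path is selected. The checkpoint path is hypothetical, and `skip_tokenizer_init`/`enforce_eager` are assumptions about how these image-to-image models are typically run; `model_impl="terratorch"` is the switch the two registry hunks above enable:

    from vllm import LLM

    llm = LLM(
        model="path/to/terratorch-checkpoint",  # hypothetical local checkpoint
        model_impl="terratorch",  # resolves the architecture to the Terratorch wrapper
        skip_tokenizer_init=True,  # assumption: no tokenizer for geospatial models
        enforce_eager=True,
    )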
@@ -15,13 +15,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Inference-only IBM/NASA Prithvi Geospatial model."""
+"""Wrapper around `Terratorch` models"""

+from collections import OrderedDict
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union

 import torch
 import torch.nn as nn
+from terratorch.vllm import (DummyDataGenerator, InferenceRunner,
+                             InputDefinition, InputTypeEnum)
 from transformers import BatchFeature

 from vllm.config import VllmConfig
@@ -29,6 +32,7 @@ from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.utils import AutoWeightsLoader
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.cache import MultiModalProcessorOnlyCache
 from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                     MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalInputs, MultiModalKwargsItems,
@@ -45,52 +49,46 @@ from .interfaces import (IsAttentionFree, MultiModalEmbeddings,
 from .interfaces_base import default_pooling_type


-def _prithvi_field_config(hf_inputs: Mapping[str, torch.Tensor]):
-    # This model receives in input a multi-dimensional tensor representing
-    # a single image patch and therefore it is not to be split
-    # into multiple elements, but rather to be considered a single one.
-    # Hence, the decision of using a MultiModalSharedField.
-    # The expected shape is (num_channels, width, height).
-
-    # This model however allows the user to also submit multiple image
-    # patches as a batch, adding a further dimension to the above shape.
-    # At this stage we only support submitting one patch per request and
-    # batching is achieved via vLLM batching.
-    # TODO (christian-pinto): enable support for multi patch requests
-    # in tandem with vLLM batching.
-    return dict(
-        pixel_values=MultiModalFieldConfig.shared(batch_size=1,
-                                                  modality="image"),
-        location_coords=MultiModalFieldConfig.shared(batch_size=1,
-                                                     modality="image"),
-    )
-
-
-class PrithviGeoSpatialMAEMultiModalDataParser(MultiModalDataParser):
-
-    def _parse_image_data(
-        self,
-        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
-    ) -> Optional[ModalityDataItems[Any, Any]]:
-        if isinstance(data, dict):
-            return DictEmbeddingItems(
-                data,
-                modality="image",
-                required_fields={"pixel_values", "location_coords"},
-                fields_factory=_prithvi_field_config,
-            )
-
-        return super()._parse_image_data(data)
+def _terratorch_field_names(pretrained_cfg: dict):
+    input_definition = InputDefinition(**pretrained_cfg["input"])
+    return set(input_definition.data.keys())
+
+
+def _terratorch_field_factory(
+    pretrained_cfg: dict
+) -> Callable[
+    [Mapping[str, torch.Tensor]],
+    Mapping[str, MultiModalFieldConfig],
+]:
+
+    def _terratorch_field_config(hf_inputs: Mapping[str, torch.Tensor]):
+        input_definition = InputDefinition(**pretrained_cfg["input"])
+        fields = {}
+        for input_name, input in input_definition.data.items():
+            if input.type == InputTypeEnum.tensor:
+                fields[input_name] = "image"
+
+        mm_fields_config = {}
+        for field_name, field_modality in fields.items():
+            mm_fields_config[field_name] = MultiModalFieldConfig.shared(
+                batch_size=1, modality=field_modality)
+        return mm_fields_config
+
+    return _terratorch_field_config


-class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
+class TerratorchProcessingInfo(BaseProcessingInfo):

     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None}


-class PrithviGeoSpatialMAEInputBuilder(
-        BaseDummyInputsBuilder[PrithviGeoSpatialMAEProcessingInfo]):
+class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
+
+    def __init__(self, info: TerratorchProcessingInfo):
+        super().__init__(info)
+        self.dummy_data_generator = DummyDataGenerator(
+            self.info.get_hf_config().to_dict()["pretrained_cfg"])

     def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
         return ""
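Note: to illustrate the factory above, a sketch of how it is consumed. The names `pretrained_cfg` and `hf_inputs` stand for the checkpoint's `pretrained_cfg` dict and the processed batch; the concrete `InputDefinition` schema is owned by `terratorch.vllm`, not vLLM:

    # Build the per-checkpoint field-config closure once...
    fields_factory = _terratorch_field_factory(pretrained_cfg)

    # ...then apply it to processed inputs. For a config declaring a single
    # tensor input "pixel_values", this returns roughly:
    # {"pixel_values": MultiModalFieldConfig.shared(batch_size=1, modality="image")}
    mm_fields = fields_factory(hf_inputs)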
@@ -100,29 +98,57 @@ class PrithviGeoSpatialMAEInputBuilder(
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> MultiModalDataDict:
-        # This model input is fixed and is in the form of a torch Tensor.
-        # The size of pixel_values might change in the cases where we resize
-        # the input but never exceeds the dimensions below.
-        image_data = {
-            "pixel_values": torch.full((6, 512, 512), 1.0,
-                                       dtype=torch.float16),
-            "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
-        }
-
-        return {"image": image_data}
+        # Dummy data is generated based on the 'input' section
+        # defined in the HF configuration file
+        return self.dummy_data_generator.get_dummy_mm_data()


-class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
+class TerratorchMultiModalDataParser(MultiModalDataParser):
+
+    def __init__(self, pretrained_cfg: dict, *args, **kwargs):
+        self._pretrained_cfg = pretrained_cfg
+        super().__init__(*args, **kwargs)
+
+    def _parse_image_data(
+        self,
+        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
+    ) -> Optional[ModalityDataItems[Any, Any]]:
+        if isinstance(data, dict):
+
+            terratorch_fields = _terratorch_field_names(self._pretrained_cfg)
+
+            return DictEmbeddingItems(
+                data,
+                modality="image",
+                required_fields=terratorch_fields,
+                fields_factory=_terratorch_field_factory(self._pretrained_cfg),
+            )
+
+        return super()._parse_image_data(data)
+
+
+class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
+
+    def __init__(
+            self,
+            info: TerratorchProcessingInfo,
+            dummy_inputs: "BaseDummyInputsBuilder[TerratorchProcessingInfo]",
+            *,
+            cache: Optional[MultiModalProcessorOnlyCache] = None) -> None:
+
+        self.pretrained_cfg = info.get_hf_config().to_dict()["pretrained_cfg"]
+        super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache)

     def _get_data_parser(self) -> MultiModalDataParser:
-        return PrithviGeoSpatialMAEMultiModalDataParser()
+        return TerratorchMultiModalDataParser(
+            pretrained_cfg=self.pretrained_cfg)

     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        return _prithvi_field_config(hf_inputs)
+        return _terratorch_field_factory(self.pretrained_cfg)(hf_inputs)

     def _get_prompt_updates(
         self,
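Note: the parser above expects image data as a plain dict of tensors keyed by the field names declared in the checkpoint's `input` section. A sketch using the field names and shapes from the removed Prithvi dummy data (hypothetical for other checkpoints):

    import torch

    image_data = {
        "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
        "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
    }
    # Routed through TerratorchMultiModalDataParser._parse_image_data as
    # DictEmbeddingItems: one shared item per request, batching left to vLLM.
    mm_data = {"image": image_data}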
@@ -173,13 +199,11 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):

 @default_pooling_type("All")
 @MULTIMODAL_REGISTRY.register_processor(
-    PrithviGeoSpatialMAEMultiModalProcessor,
-    info=PrithviGeoSpatialMAEProcessingInfo,
-    dummy_inputs=PrithviGeoSpatialMAEInputBuilder,
+    TerratorchMultiModalProcessor,
+    info=TerratorchProcessingInfo,
+    dummy_inputs=TerratorchInputBuilder,
 )
-class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
-    """Prithvi Masked Autoencoder"""
-
+class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
     supports_multimodal_raw_input_only = True
     is_pooling_model = True

@@ -190,43 +214,13 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):

         raise ValueError("Only image modality is supported")

-    def _instantiate_model(self, config: dict) -> Optional[nn.Module]:
-        # We might be able/need to support different tasks with this same model
-        if config["task_args"]["task"] == "SemanticSegmentationTask":
-            from terratorch.cli_tools import SemanticSegmentationTask
-
-            task = SemanticSegmentationTask(
-                config["model_args"],
-                config["task_args"]["model_factory"],
-                loss=config["task_args"]["loss"],
-                lr=config["task_args"]["lr"],
-                ignore_index=config["task_args"]["ignore_index"],
-                optimizer=config["task_args"]["optimizer"],
-                optimizer_hparams=config["optimizer_params"],
-                scheduler=config["task_args"]["scheduler"],
-                scheduler_hparams=config["scheduler_params"],
-                plot_on_val=config["task_args"]["plot_on_val"],
-                freeze_decoder=config["task_args"]["freeze_decoder"],
-                freeze_backbone=config["task_args"]["freeze_backbone"],
-            )
-
-            return task.model
-        else:
-            return None
-
     def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()

-        # the actual model is dynamically instantiated using terratorch
-        # allowing us to perform changes to the model architecture
-        # at startup time (e.g., change the model decoder class.)
-        self.model = self._instantiate_model(
-            vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"])
-        if self.model is None:
-            raise ValueError(
-                "Unsupported task. "
-                "Only SemanticSegmentationTask is supported for now "
-                "by PrithviGeospatialMAE.")
+        config = vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"]
+
+        self.inference_runner = InferenceRunner(config)
+        self.model = self.inference_runner.model

         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
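Note: the new construction path in isolation, as a sketch. `pretrained_cfg` stands for the `pretrained_cfg` section of the checkpoint's HF config; its schema (task, model, and input definitions) is owned by Terratorch rather than hard-coded per task in vLLM:

    from terratorch.vllm import InferenceRunner

    runner = InferenceRunner(pretrained_cfg)  # builds the task/model described by the config
    model = runner.model  # the underlying nn.Module that vLLM executes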
@@ -234,23 +228,6 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
         self.pooler = DispatchPooler(
             {"encode": Pooler.for_encode(pooler_config)}, )

-    def _parse_and_validate_multimodal_data(
-            self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-        pixel_values = kwargs.pop("pixel_values", None)
-        if not isinstance(pixel_values, torch.Tensor):
-            raise ValueError(f"Incorrect type of pixel_values. "
-                             f"Got type: {type(pixel_values)}")
-
-        location_coords = kwargs.pop("location_coords", None)
-        if not isinstance(location_coords, torch.Tensor):
-            raise ValueError(f"Incorrect type of location_coords. "
-                             f"Got type: {type(location_coords)}")
-        location_coords = torch.unbind(location_coords, dim=0)[0]
-        if location_coords.shape == torch.Size([0]):
-            location_coords = None
-
-        return pixel_values, location_coords
-
     def get_input_embeddings(
         self,
         input_ids: torch.Tensor,
@@ -270,10 +247,7 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
     ):
-        pixel_values, location_coords = (
-            self._parse_and_validate_multimodal_data(**kwargs))
-        model_output = self.model(pixel_values,
-                                  location_coords=location_coords)
+        model_output = self.inference_runner.forward(**kwargs)

         return model_output.output
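Note: end to end, the forward path above receives the raw multimodal kwargs unchanged, so per-field validation now lives in the InferenceRunner. A hedged usage sketch, with `llm` and `image_data` as defined in the earlier sketches; the single placeholder token is an assumption carried over from the Prithvi examples, since these attention-free models ignore the text prompt:

    outputs = llm.encode({
        "prompt_token_ids": [1],  # placeholder; the text prompt is unused
        "multi_modal_data": {"image": image_data},
    })
    prediction = outputs[0].outputs.data  # the tensor returned as model_output.output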
@@ -283,28 +257,34 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
         model_buffers = dict(self.named_buffers())
         loaded_buffers = []
         for key, value in weights:
-            if key == "state_dict":
-                weights_to_parse = value
-                for name, weight in weights_to_parse.items():
-                    if "pos_embed" in name:
-                        continue
-
-                    # this model requires a couple of buffers to be loaded
-                    # that are not loadable with the AutoWeightsLoader
-                    if name in model_buffers:
-                        if "_timm_module." in name:
-                            name = name.replace("_timm_module.", "")
-                        buffer = model_buffers[name]
-                        weight_loader = getattr(buffer, "weight_loader",
-                                                default_weight_loader)
-                        weight_loader(buffer, weight)
-                        loaded_buffers.append(name)
-                    else:
-                        params_list.append((name, weight))
-                break
+            if isinstance(value, (dict, OrderedDict)):
+                if key == "state_dict":
+                    weights_to_parse = value
+                    for name, weight in weights_to_parse.items():
+                        name = f"inference_runner.{name}"
+
+                        if "_timm_module." in name:
+                            name = name.replace("_timm_module.", "")
+                        if "pos_embed" in name:
+                            continue
+
+                        # this model requires a couple of buffers to be loaded
+                        # that are not loadable with the AutoWeightsLoader
+                        if name in model_buffers:
+                            if "_timm_module." in name:
+                                name = name.replace("_timm_module.", "")
+                            buffer = model_buffers[name]
+                            weight_loader = getattr(buffer, "weight_loader",
+                                                    default_weight_loader)
+                            weight_loader(buffer, weight)
+                            loaded_buffers.append(name)
+                        else:
+                            params_list.append((name, weight))
+                    break
+
+            elif isinstance(value, torch.Tensor):
+                params_list.append((f"inference_runner.model.{key}", value))

         # Load the remaining model parameters
         loader = AutoWeightsLoader(self)
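Note: the name remapping applied in the loop above, traced on a hypothetical checkpoint key:

    name = "encoder._timm_module.blocks.0.attn.qkv.weight"
    name = f"inference_runner.{name}"         # prefix with the wrapper module path
    name = name.replace("_timm_module.", "")  # drop the timm wrapper segment
    # -> "inference_runner.encoder.blocks.0.attn.qkv.weight"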