[Core][Model] PrithviMAE Enablement on vLLM v1 engine (#20577)
Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
@@ -16,6 +16,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only IBM/NASA Prithvi Geospatial model."""
+
 from collections.abc import Iterable, Mapping, Sequence
 from typing import Optional, Union
 
@@ -27,13 +28,14 @@ from vllm.config import VllmConfig
 from vllm.model_executor.layers.pooler import (AllPool, PoolerHead,
                                                PoolerIdentity, SimplePooler)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import (IsAttentionFree,
-                                                   SupportsMultiModal,
-                                                   SupportsV0Only)
+from vllm.model_executor.models.interfaces import (
+    IsAttentionFree, MultiModalEmbeddings, SupportsMultiModalWithRawInput)
 from vllm.model_executor.models.utils import AutoWeightsLoader
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalInputs, MultiModalKwargs)
+                                    MultiModalFieldElem, MultiModalInputs,
+                                    MultiModalKwargs, MultiModalKwargsItem,
+                                    MultiModalSharedField, PlaceholderRange)
 from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptUpdate)
@@ -62,8 +64,9 @@ class PrithviGeoSpatialMAEInputBuilder(
         # The size of pixel_values might change in the cases where we resize
         # the input but never exceeds the dimensions below.
         return {
-            "pixel_values": torch.full((1, 6, 512, 512), 1.0),
-            "location_coords": torch.full((1, 2), 1.0),
+            "pixel_values": torch.full((6, 512, 512), 1.0,
+                                       dtype=torch.float16),
+            "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
         }
 
 
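Side note on the dummy-input change above: the leading batch dimension is gone and the dtype is pinned to float16. A minimal, vLLM-free sketch of the contract the builder now promises (shapes and dtypes taken from the hunk above):

    import torch

    # One six-band patch: (num_channels, width, height), no batch dim.
    pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
    location_coords = torch.full((1, 2), 1.0, dtype=torch.float16)

    assert pixel_values.shape == (6, 512, 512)
    assert pixel_values.dtype == torch.float16
    assert location_coords.shape == (1, 2)
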
@@ -75,8 +78,10 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         return dict(
-            pixel_values=MultiModalFieldConfig.batched("image"),
-            location_coords=MultiModalFieldConfig.batched("image"),
+            pixel_values=MultiModalFieldConfig.shared(batch_size=1,
+                                                      modality="image"),
+            location_coords=MultiModalFieldConfig.shared(batch_size=1,
+                                                         modality="image"),
         )
 
     def _get_prompt_updates(
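For intuition on the batched-to-shared switch: a batched field treats dim 0 of the tensor as per-item slices, while a shared field keeps the whole tensor as a single element. That matters here because dim 0 of pixel_values is channels, not batch. A plain-torch sketch of the two granularities (illustrative only, not vLLM internals):

    import torch

    pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)

    # Batched semantics would slice along dim 0 into six single-channel
    # "items" -- the wrong granularity for this model.
    batched_view = list(torch.unbind(pixel_values, dim=0))
    assert len(batched_view) == 6

    # Shared semantics keep one element per request: the full patch.
    shared_view = [pixel_values]
    assert shared_view[0].shape == (6, 512, 512)
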
@@ -99,23 +104,48 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
 
         for k, v in mm_data.items():
             mm_kwargs[k] = v
+        mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
+
+        # This model receives in input a multi-dimensional tensor representing
+        # a single image patch and therefore it is not to be split
+        # into multiple elements, but rather to be considered a single one.
+        # Hence, the decision of using a MultiModalSharedField.
+        # The expected shape is (num_channels, width, height).
+
+        # This model however allows the user to also submit multiple image
+        # patches as a batch, adding a further dimension to the above shape.
+        # At this stage we only support submitting one patch per request and
+        # batching is achieved via vLLM batching.
+        # TODO (christian-pinto): enable support for multi patch requests
+        # in tandem with vLLM batching.
+        multimodal_kwargs_items = [
+            MultiModalKwargsItem.from_elems([
+                MultiModalFieldElem(
+                    modality="image",
+                    key=key,
+                    data=data,
+                    field=MultiModalSharedField(1),
+                ) for key, data in mm_kwargs.items()
+            ])
+        ]
 
         return MultiModalInputs(
             type="multimodal",
             prompt=prompt,
             prompt_token_ids=[1],
-            mm_kwargs=MultiModalKwargs(mm_kwargs),
+            mm_kwargs=MultiModalKwargs.from_items(multimodal_kwargs_items),
             mm_hashes=None,
-            mm_placeholders={},
+            mm_placeholders=mm_placeholders,
         )
 
 
 @MULTIMODAL_REGISTRY.register_processor(
     PrithviGeoSpatialMAEMultiModalProcessor,
     info=PrithviGeoSpatialMAEProcessingInfo,
-    dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
-class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
-                           SupportsV0Only):
+    dummy_inputs=PrithviGeoSpatialMAEInputBuilder,
+)
+class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree,
+                           SupportsMultiModalWithRawInput):
     """Prithvi Masked Autoencoder"""
 
+    is_pooling_model = True
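With SupportsV0Only dropped and SupportsMultiModalWithRawInput in place, the raw tensors reach the model under the v1 engine, and is_pooling_model = True routes the output through pooling. A hypothetical offline-inference sketch; the checkpoint name and the exact prompt layout are assumptions rather than something this diff pins down:

    import torch
    from vllm import LLM

    # Hypothetical checkpoint id; any Prithvi geospatial checkpoint would do.
    llm = LLM(model="ibm-nasa-geospatial/Prithvi-EO-V2-300M-TL-Sen1Floods11",
              skip_tokenizer_init=True, dtype="float16")

    outputs = llm.encode({
        # One mandatory token id, mirroring prompt_token_ids=[1] above.
        "prompt_token_ids": [1],
        "multi_modal_data": {
            "pixel_values": torch.full((6, 512, 512), 1.0,
                                       dtype=torch.float16),
            "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
        },
    })
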
@@ -128,10 +158,10 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
         raise ValueError("Only image modality is supported")
 
     def _instantiate_model(self, config: dict) -> Optional[nn.Module]:
-
         # We might be able/need to support different tasks with this same model
         if config["task_args"]["task"] == "SemanticSegmentationTask":
             from terratorch.cli_tools import SemanticSegmentationTask
+
             task = SemanticSegmentationTask(
                 config["model_args"],
                 config["task_args"]["model_factory"],
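The keys _instantiate_model reads can be collected from this hunk and the next one. A sketch of the minimal config shape it expects; the values are placeholders, not TerraTorch defaults:

    # Hypothetical values; only the key layout comes from the code above.
    config = {
        "model_args": {},            # forwarded to the task as-is
        "scheduler_params": {},      # becomes scheduler_hparams
        "task_args": {
            "task": "SemanticSegmentationTask",
            "model_factory": "EncoderDecoderFactory",  # assumed factory name
            "plot_on_val": False,
            "freeze_decoder": True,
            "freeze_backbone": True,
        },
    }
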
@@ -144,7 +174,8 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
                 scheduler_hparams=config["scheduler_params"],
                 plot_on_val=config["task_args"]["plot_on_val"],
                 freeze_decoder=config["task_args"]["freeze_decoder"],
-                freeze_backbone=config["task_args"]["freeze_backbone"])
+                freeze_backbone=config["task_args"]["freeze_backbone"],
+            )
 
             return task.model
         else:
@@ -168,12 +199,10 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
 
     def _parse_and_validate_multimodal_data(
             self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-
         pixel_values = kwargs.pop("pixel_values", None)
         if not isinstance(pixel_values, torch.Tensor):
             raise ValueError(f"Incorrect type of pixel_values. "
                              f"Got type: {type(pixel_values)}")
-        pixel_values = torch.unbind(pixel_values, dim=0)[0]
 
         location_coords = kwargs.pop("location_coords", None)
         if not isinstance(location_coords, torch.Tensor):
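Why the torch.unbind line can go: with the shared field (batch_size=1), the kwargs deliver the patch without an engine-stacked leading dimension, so there is no batch dim left to peel off. A shape sketch of before versus after, using the dummy-input shapes from this commit:

    import torch

    # Before: a stacked (1, 6, 512, 512) tensor arrived, hence the manual
    # unbind to recover the single patch.
    stacked = torch.zeros((1, 6, 512, 512))
    assert torch.unbind(stacked, dim=0)[0].shape == (6, 512, 512)

    # After: the shared field hands over (6, 512, 512) directly.
    direct = torch.zeros((6, 512, 512))
    assert direct.shape == (6, 512, 512)
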
@@ -185,6 +214,17 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
         return pixel_values, location_coords
 
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+    ) -> torch.Tensor:
+        # We do not really use any input tokens and therefore no embeddings
+        # to be calculated. However, due to the mandatory token ids in
+        # the input prompt we pass one token and the size of the dummy
+        # embedding tensors must reflect that.
+        return torch.empty((input_ids.shape[0], 0))
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor],