[Model] Define merge_by_field_config MM interface (R-T) (#26260)

Signed-off-by: Ayush Satyam <ayushsatyam146@gmail.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Ayush Satyam
2025-10-07 13:40:55 +05:30
committed by GitHub
parent 185d8ed44f
commit de342585ff
3 changed files with 46 additions and 45 deletions

View File

@@ -87,12 +87,10 @@ def _terratorch_field_factory(
if input.type == InputTypeEnum.tensor:
fields[input_name] = "image"
mm_fields_config = {}
for field_name, field_modality in fields.items():
mm_fields_config[field_name] = MultiModalFieldConfig.shared(
batch_size=1, modality=field_modality
)
return mm_fields_config
return {
field_name: MultiModalFieldConfig.batched(modality=field_modality)
for field_name, field_modality in fields.items()
}
return _terratorch_field_config
@@ -192,9 +190,12 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
) -> MultiModalInputs:
if "image" in mm_data:
image_data = mm_data["image"]
image_data = {k: v.unsqueeze(0) for k, v in image_data.items()}
else:
image_data = mm_data
mm_data = {"image": mm_data}
image_data = {k: v.unsqueeze(0) for k, v in image_data.items()}
mm_data = {"image": image_data}
mm_items = self._to_mm_items(mm_data)
tokenization_kwargs = tokenization_kwargs or {}
@@ -226,6 +227,7 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
dummy_inputs=TerratorchInputBuilder,
)
class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
merge_by_field_config = True
supports_multimodal_raw_input_only = True
is_pooling_model = True