[Model] Define merge_by_field_config MM interface (R-T) (#26260)
Signed-off-by: Ayush Satyam <ayushsatyam146@gmail.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -87,12 +87,10 @@ def _terratorch_field_factory(
|
||||
if input.type == InputTypeEnum.tensor:
|
||||
fields[input_name] = "image"
|
||||
|
||||
mm_fields_config = {}
|
||||
for field_name, field_modality in fields.items():
|
||||
mm_fields_config[field_name] = MultiModalFieldConfig.shared(
|
||||
batch_size=1, modality=field_modality
|
||||
)
|
||||
return mm_fields_config
|
||||
return {
|
||||
field_name: MultiModalFieldConfig.batched(modality=field_modality)
|
||||
for field_name, field_modality in fields.items()
|
||||
}
|
||||
|
||||
return _terratorch_field_config
|
||||
|
||||
@@ -192,9 +190,12 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
|
||||
) -> MultiModalInputs:
|
||||
if "image" in mm_data:
|
||||
image_data = mm_data["image"]
|
||||
image_data = {k: v.unsqueeze(0) for k, v in image_data.items()}
|
||||
else:
|
||||
image_data = mm_data
|
||||
mm_data = {"image": mm_data}
|
||||
image_data = {k: v.unsqueeze(0) for k, v in image_data.items()}
|
||||
|
||||
mm_data = {"image": image_data}
|
||||
|
||||
mm_items = self._to_mm_items(mm_data)
|
||||
tokenization_kwargs = tokenization_kwargs or {}
|
||||
@@ -226,6 +227,7 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
|
||||
dummy_inputs=TerratorchInputBuilder,
|
||||
)
|
||||
class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
|
||||
merge_by_field_config = True
|
||||
supports_multimodal_raw_input_only = True
|
||||
is_pooling_model = True
|
||||
|
||||
|
||||
Reference in New Issue
Block a user