[Core][Model] Terratorch backend integration (#23513)

Signed-off-by: Michele Gazzetti <michele.gazzetti1@ibm.com>
Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
Co-authored-by: Christian Pinto <christian.pinto@ibm.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
mgazz
2025-09-04 08:22:41 +01:00
committed by GitHub
parent e7fc70016f
commit 51d5e9be7d
23 changed files with 305 additions and 208 deletions

View File

@@ -184,10 +184,11 @@ _EMBEDDING_MODELS = {
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
# Technically PrithviGeoSpatialMAE is a model that works on images, both in
# input and output. I am adding it here because it piggybacks on embedding
# Technically Terratorch models work on images, both in
# input and output. I am adding it here because it piggy-backs on embedding
# models for the time being.
"PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
"PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
"Terratorch": ("terratorch", "Terratorch"),
}
_CROSS_ENCODER_MODELS = {
@@ -639,6 +640,9 @@ class _ModelRegistry:
model_info = self._try_inspect_model_cls(arch)
if model_info is not None:
return (model_info, arch)
elif model_config.model_impl == ModelImpl.TERRATORCH:
model_info = self._try_inspect_model_cls("Terratorch")
return (model_info, "Terratorch")
# Fallback to transformers impl (after resolving convert_type)
if (all(arch not in self.models for arch in architectures)
@@ -687,6 +691,11 @@ class _ModelRegistry:
model_cls = self._try_load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)
elif model_config.model_impl == ModelImpl.TERRATORCH:
arch = "Terratorch"
model_cls = self._try_load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)
# Fallback to transformers impl (after resolving convert_type)
if (all(arch not in self.models for arch in architectures)