[Model] Support SigLIP encoder and alternative decoders for LLaVA models (#7153)
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import functools
|
||||
import importlib
|
||||
from typing import Dict, List, Optional, Type
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
@@ -126,7 +126,7 @@ class ModelRegistry:
|
||||
return getattr(module, model_cls_name, None)
|
||||
|
||||
@staticmethod
|
||||
def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
|
||||
def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
|
||||
if model_arch in _OOT_MODELS:
|
||||
return _OOT_MODELS[model_arch]
|
||||
if model_arch not in _MODELS:
|
||||
@@ -143,6 +143,18 @@ class ModelRegistry:
|
||||
|
||||
return ModelRegistry._get_model(model_arch)
|
||||
|
||||
@staticmethod
|
||||
def resolve_model_cls(
|
||||
architectures: List[str]) -> Tuple[Type[nn.Module], str]:
|
||||
for arch in architectures:
|
||||
model_cls = ModelRegistry._try_load_model_cls(arch)
|
||||
if model_cls is not None:
|
||||
return (model_cls, arch)
|
||||
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} are not supported for now. "
|
||||
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
|
||||
|
||||
@staticmethod
|
||||
def get_supported_archs() -> List[str]:
|
||||
return list(_MODELS.keys())
|
||||
|
||||
Reference in New Issue
Block a user