[Model] Support SigLIP encoder and alternative decoders for LLaVA models (#7153)

Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
This commit is contained in:
Cyrus Leung
2024-08-06 16:55:31 +08:00
committed by GitHub
parent 9118217f58
commit 1f26efbb3a
14 changed files with 453 additions and 267 deletions

View File

@@ -1,6 +1,6 @@
import functools
import importlib
from typing import Dict, List, Optional, Type
from typing import Dict, List, Optional, Tuple, Type
import torch.nn as nn
@@ -126,7 +126,7 @@ class ModelRegistry:
return getattr(module, model_cls_name, None)
@staticmethod
def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch in _OOT_MODELS:
return _OOT_MODELS[model_arch]
if model_arch not in _MODELS:
@@ -143,6 +143,18 @@ class ModelRegistry:
return ModelRegistry._get_model(model_arch)
@staticmethod
def resolve_model_cls(
    architectures: List[str],
) -> Tuple[Type[nn.Module], str]:
    """Pick the first architecture name that can be loaded.

    The names are tried in the order given; each one is passed to
    ``ModelRegistry._try_load_model_cls`` and the first lookup that does
    not come back as ``None`` wins.

    Args:
        architectures: Candidate architecture names, in priority order.

    Returns:
        A ``(model_cls, arch)`` tuple holding the loaded class and the
        architecture name it was loaded for.

    Raises:
        ValueError: If no name in ``architectures`` yields a class
            (this includes an empty list).
    """
    for candidate in architectures:
        loaded_cls = ModelRegistry._try_load_model_cls(candidate)
        if loaded_cls is None:
            # Not loadable under this name; fall through to the next one.
            continue
        return (loaded_cls, candidate)

    # The supported-architecture list is only computed on the failure path.
    message = (
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {ModelRegistry.get_supported_archs()}")
    raise ValueError(message)
@staticmethod
def get_supported_archs() -> List[str]:
    """Return the keys of the module-level ``_MODELS`` table as a new list."""
    # Unpacking the key view builds a fresh list, so callers may mutate
    # the result without touching ``_MODELS``.
    return [*_MODELS.keys()]