[Misc] Small perf improvements (#6520)

This commit is contained in:
Antoni Baum
2024-07-19 12:10:56 -07:00
committed by GitHub
parent 51f8aa90ad
commit 9ed82e7074
7 changed files with 46 additions and 23 deletions

View File

@@ -1,3 +1,4 @@
import functools
import importlib
from typing import Dict, List, Optional, Type
@@ -98,6 +99,14 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
class ModelRegistry:
@staticmethod
@functools.lru_cache(maxsize=128)
def _get_model(model_arch: str):
    """Resolve an architecture name to its model class, importing lazily.

    The ``_MODELS`` table entry for ``model_arch`` is unpacked into a
    module name and a class name. The module
    ``vllm.model_executor.models.<module name>`` is imported, and the
    attribute with the class name is returned from it.

    Results are memoized with ``functools.lru_cache`` (at most 128
    entries), so repeated calls with the same ``model_arch`` return the
    cached value without re-running the lookup and import.

    Args:
        model_arch: Key into the ``_MODELS`` table.

    Returns:
        The attribute named by the table entry, or ``None`` when the
        imported module has no attribute with that name.

    Note:
        No membership check is done here; an ``model_arch`` that is not
        in ``_MODELS`` propagates whatever the subscript lookup raises.
    """
    entry = _MODELS[model_arch]
    submodule, cls_name = entry
    qualified_name = f"vllm.model_executor.models.{submodule}"
    loaded = importlib.import_module(qualified_name)
    # Default of None keeps a missing class from raising AttributeError.
    return getattr(loaded, cls_name, None)
@staticmethod
def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch in _OOT_MODELS:
@@ -114,10 +123,7 @@ class ModelRegistry:
"Model architecture %s is partially supported by ROCm: %s",
model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
module_name, model_cls_name = _MODELS[model_arch]
module = importlib.import_module(
f"vllm.model_executor.models.{module_name}")
return getattr(module, model_cls_name, None)
return ModelRegistry._get_model(model_arch)
@staticmethod
def get_supported_archs() -> List[str]: