[Misc] Small perf improvements (#6520)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import functools
|
||||
import importlib
|
||||
from typing import Dict, List, Optional, Type
|
||||
|
||||
@@ -98,6 +99,14 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
|
||||
|
||||
class ModelRegistry:
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def _get_model(model_arch: str):
|
||||
module_name, model_cls_name = _MODELS[model_arch]
|
||||
module = importlib.import_module(
|
||||
f"vllm.model_executor.models.{module_name}")
|
||||
return getattr(module, model_cls_name, None)
|
||||
|
||||
@staticmethod
|
||||
def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
|
||||
if model_arch in _OOT_MODELS:
|
||||
@@ -114,10 +123,7 @@ class ModelRegistry:
|
||||
"Model architecture %s is partially supported by ROCm: %s",
|
||||
model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
|
||||
|
||||
module_name, model_cls_name = _MODELS[model_arch]
|
||||
module = importlib.import_module(
|
||||
f"vllm.model_executor.models.{module_name}")
|
||||
return getattr(module, model_cls_name, None)
|
||||
return ModelRegistry._get_model(model_arch)
|
||||
|
||||
@staticmethod
|
||||
def get_supported_archs() -> List[str]:
|
||||
|
||||
Reference in New Issue
Block a user