[Model] Explicit interface for vLLM models and support OOT embedding models (#9108)
This commit is contained in:
@@ -3,7 +3,14 @@ import warnings
|
||||
import pytest
|
||||
import torch.cuda
|
||||
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.model_executor.models import (is_embedding_model,
|
||||
is_text_generation_model,
|
||||
supports_multimodal)
|
||||
from vllm.model_executor.models.registry import (_EMBEDDING_MODELS,
|
||||
_MULTIMODAL_MODELS,
|
||||
_SPECULATIVE_DECODING_MODELS,
|
||||
_TEXT_GENERATION_MODELS,
|
||||
ModelRegistry)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..utils import fork_new_process_for_each_test
|
||||
@@ -12,7 +19,20 @@ from ..utils import fork_new_process_for_each_test
|
||||
@pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs())
def test_registry_imports(model_arch):
    """Check that a registered architecture resolves and matches its table.

    Resolves ``model_arch`` through ``ModelRegistry`` and then asserts that
    the interface predicates (``is_text_generation_model``,
    ``is_embedding_model``, ``supports_multimodal``) agree with the registry
    table the architecture name is listed in.

    Args:
        model_arch: Architecture name supplied by
            ``ModelRegistry.get_supported_archs()``.
    """
    # Ensure all model classes can be imported successfully.
    # Resolve once and keep the class; the second tuple element is unused.
    model_cls, _ = ModelRegistry.resolve_model_cls(model_arch)

    if model_arch in _SPECULATIVE_DECODING_MODELS:
        # Ignore these models which do not have a unified format
        return

    # Multimodal architectures are also expected to count as text generation.
    assert is_text_generation_model(model_cls) is (
        model_arch in _TEXT_GENERATION_MODELS
        or model_arch in _MULTIMODAL_MODELS)

    assert is_embedding_model(model_cls) is (model_arch in _EMBEDDING_MODELS)

    assert supports_multimodal(model_cls) is (model_arch
                                              in _MULTIMODAL_MODELS)
|
||||
|
||||
|
||||
@fork_new_process_for_each_test
|
||||
|
||||
Reference in New Issue
Block a user