Support Cross encoder models (#10400)
Signed-off-by: Max de Bayser <maxdebayser@gmail.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Flavia Beo <flavia.beo@ibm.com> Co-authored-by: Flavia Beo <flavia.beo@ibm.com>
This commit is contained in:
committed by
GitHub
parent
49628fe13e
commit
214efc2c3c
@@ -6,7 +6,10 @@ import torch.cuda
|
||||
from vllm.model_executor.models import (is_embedding_model,
|
||||
is_text_generation_model,
|
||||
supports_multimodal)
|
||||
from vllm.model_executor.models.registry import (_EMBEDDING_MODELS,
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.model_executor.models.registry import (_CROSS_ENCODER_MODELS,
|
||||
_EMBEDDING_MODELS,
|
||||
_MULTIMODAL_MODELS,
|
||||
_SPECULATIVE_DECODING_MODELS,
|
||||
_TEXT_GENERATION_MODELS,
|
||||
@@ -29,22 +32,28 @@ def test_registry_imports(model_arch):
|
||||
model_arch in _TEXT_GENERATION_MODELS
|
||||
or model_arch in _MULTIMODAL_MODELS)
|
||||
|
||||
embedding_models = {**_EMBEDDING_MODELS, **_CROSS_ENCODER_MODELS}
|
||||
assert is_embedding_model(model_cls) is (model_arch
|
||||
in _EMBEDDING_MODELS)
|
||||
in embedding_models)
|
||||
|
||||
assert supports_multimodal(model_cls) is (model_arch
|
||||
in _MULTIMODAL_MODELS)
|
||||
|
||||
|
||||
@fork_new_process_for_each_test
|
||||
@pytest.mark.parametrize("model_arch,is_mm,init_cuda", [
|
||||
("LlamaForCausalLM", False, False),
|
||||
("MllamaForConditionalGeneration", True, False),
|
||||
("LlavaForConditionalGeneration", True, True),
|
||||
@pytest.mark.parametrize("model_arch,is_mm,init_cuda,is_ce", [
|
||||
("LlamaForCausalLM", False, False, False),
|
||||
("MllamaForConditionalGeneration", True, False, False),
|
||||
("LlavaForConditionalGeneration", True, True, False),
|
||||
("BertForSequenceClassification", False, False, True),
|
||||
("RobertaForSequenceClassification", False, False, True),
|
||||
("XLMRobertaForSequenceClassification", False, False, True),
|
||||
])
|
||||
def test_registry_is_multimodal(model_arch, is_mm, init_cuda):
|
||||
def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce):
|
||||
assert ModelRegistry.is_multimodal_model(model_arch) is is_mm
|
||||
|
||||
assert ModelRegistry.is_cross_encoder_model(model_arch) is is_ce
|
||||
|
||||
if init_cuda and current_platform.is_cuda_alike():
|
||||
assert not torch.cuda.is_initialized()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user