Support Cross encoder models (#10400)

Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Flavia Beo <flavia.beo@ibm.com>
Co-authored-by: Flavia Beo <flavia.beo@ibm.com>
This commit is contained in:
Maximilien de Bayser
2024-11-24 23:56:20 -03:00
committed by GitHub
parent 49628fe13e
commit 214efc2c3c
28 changed files with 1370 additions and 62 deletions

View File

@@ -6,7 +6,10 @@ import torch.cuda
from vllm.model_executor.models import (is_embedding_model,
is_text_generation_model,
supports_multimodal)
from vllm.model_executor.models.registry import (_EMBEDDING_MODELS,
# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.models.registry import (_CROSS_ENCODER_MODELS,
_EMBEDDING_MODELS,
_MULTIMODAL_MODELS,
_SPECULATIVE_DECODING_MODELS,
_TEXT_GENERATION_MODELS,
@@ -29,22 +32,28 @@ def test_registry_imports(model_arch):
model_arch in _TEXT_GENERATION_MODELS
or model_arch in _MULTIMODAL_MODELS)
embedding_models = {**_EMBEDDING_MODELS, **_CROSS_ENCODER_MODELS}
assert is_embedding_model(model_cls) is (model_arch
in _EMBEDDING_MODELS)
in embedding_models)
assert supports_multimodal(model_cls) is (model_arch
in _MULTIMODAL_MODELS)
@fork_new_process_for_each_test
@pytest.mark.parametrize("model_arch,is_mm,init_cuda", [
("LlamaForCausalLM", False, False),
("MllamaForConditionalGeneration", True, False),
("LlavaForConditionalGeneration", True, True),
@pytest.mark.parametrize("model_arch,is_mm,init_cuda,is_ce", [
("LlamaForCausalLM", False, False, False),
("MllamaForConditionalGeneration", True, False, False),
("LlavaForConditionalGeneration", True, True, False),
("BertForSequenceClassification", False, False, True),
("RobertaForSequenceClassification", False, False, True),
("XLMRobertaForSequenceClassification", False, False, True),
])
def test_registry_is_multimodal(model_arch, is_mm, init_cuda):
def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce):
assert ModelRegistry.is_multimodal_model(model_arch) is is_mm
assert ModelRegistry.is_cross_encoder_model(model_arch) is is_ce
if init_cuda and current_platform.is_cuda_alike():
assert not torch.cuda.is_initialized()