[Model] Consolidate score logic by introduce score_type (#36479)

Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
wang.yuqi
2026-03-10 21:32:25 +08:00
committed by GitHub
parent 409c4e632d
commit a3189a08b0
14 changed files with 213 additions and 194 deletions

View File

@@ -56,21 +56,24 @@ def test_registry_imports(model_arch):
@create_new_process_for_each_test()
@pytest.mark.parametrize(
"model_arch,is_mm,init_cuda,is_ce",
"model_arch,is_mm,init_cuda,score_type",
[
("LlamaForCausalLM", False, False, False),
("LlavaForConditionalGeneration", True, True, False),
("BertForSequenceClassification", False, False, True),
("RobertaForSequenceClassification", False, False, True),
("XLMRobertaForSequenceClassification", False, False, True),
("LlamaForCausalLM", False, False, "bi-encoder"),
("LlavaForConditionalGeneration", True, True, "bi-encoder"),
("BertForSequenceClassification", False, False, "cross-encoder"),
("RobertaForSequenceClassification", False, False, "cross-encoder"),
("XLMRobertaForSequenceClassification", False, False, "cross-encoder"),
("GteNewModel", False, False, "bi-encoder"),
("GteNewForSequenceClassification", False, False, "cross-encoder"),
("HF_ColBERT", False, False, "late-interaction"),
],
)
def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce):
def test_registry_model_property(model_arch, is_mm, init_cuda, score_type):
model_info = ModelRegistry._try_inspect_model_cls(model_arch)
assert model_info is not None
assert model_info.supports_multimodal is is_mm
assert model_info.supports_cross_encoding is is_ce
assert model_info.score_type == score_type
if init_cuda and current_platform.is_cuda_alike():
assert not torch.cuda.is_initialized()