[New Model]: Support GteNewModelForSequenceClassification (#23524)

Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
wang.yuqi
2025-08-28 15:36:42 +08:00
committed by GitHub
parent 186aced5ff
commit 11a7fafaa8
13 changed files with 157 additions and 76 deletions

View File

@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import pytest
@@ -33,12 +32,15 @@ MODELS = [
########### NewModel
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-multilingual-base",
architecture="GteNewModel",
hf_overrides={"architectures": ["GteNewModel"]},
enable_test=True),
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5",
architecture="GteNewModel",
hf_overrides={"architectures": ["GteNewModel"]},
enable_test=True),
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5",
architecture="GteNewModel",
hf_overrides={"architectures": ["GteNewModel"]},
enable_test=True),
########### Qwen2ForCausalLM
LASTPoolingEmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
@@ -60,11 +62,16 @@ MODELS = [
]
RERANK_MODELS = [
# classifier_pooling: mean
CLSPoolingRerankModelInfo(
# classifier_pooling: mean
"Alibaba-NLP/gte-reranker-modernbert-base",
architecture="ModernBertForSequenceClassification",
enable_test=True),
CLSPoolingRerankModelInfo(
"Alibaba-NLP/gte-multilingual-reranker-base",
architecture="GteNewForSequenceClassification",
hf_overrides={"architectures": ["GteNewForSequenceClassification"]},
enable_test=True),
]
@@ -75,12 +82,7 @@ def test_embed_models_mteb(hf_runner, vllm_runner,
check_transformers_version(model_info.name,
max_transformers_version="4.53.2")
vllm_extra_kwargs: dict[str, Any] = {}
if model_info.architecture == "GteNewModel":
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
mteb_test_embed_models(hf_runner, vllm_runner, model_info,
vllm_extra_kwargs)
mteb_test_embed_models(hf_runner, vllm_runner, model_info)
@pytest.mark.parametrize("model_info", MODELS)
@@ -91,12 +93,8 @@ def test_embed_models_correctness(hf_runner, vllm_runner,
check_transformers_version(model_info.name,
max_transformers_version="4.53.2")
vllm_extra_kwargs: dict[str, Any] = {}
if model_info.architecture == "GteNewModel":
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
correctness_test_embed_models(hf_runner, vllm_runner, model_info,
example_prompts, vllm_extra_kwargs)
example_prompts)
@pytest.mark.parametrize("model_info", RERANK_MODELS)