Add explicit pooling classes for the Transformers backend (#25322)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
@@ -9,9 +9,16 @@ from vllm.platforms import current_platform
 from ..conftest import HfRunner, VllmRunner
 from ..utils import multi_gpu_test, prep_prompts
+from .registry import HF_EXAMPLE_MODELS
 from .utils import check_embeddings_close, check_logprobs_close
+
+
+def get_model(arch: str) -> str:
+    model_info = HF_EXAMPLE_MODELS.get_hf_info(arch)
+    model_info.check_transformers_version(on_fail="skip")
+    return model_info.default
 
 
 def check_implementation(
     runner_ref: type[Union[HfRunner, VllmRunner]],
     runner_test: type[VllmRunner],
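For orientation, a minimal sketch of how the new helper is meant to be used in the tests below. The architecture string comes from this diff; the resolved checkpoint is an assumption for illustration only:

# Hypothetical illustration: HF_EXAMPLE_MODELS maps an architecture name to
# example-model metadata. get_model() skips the test when the installed
# transformers version is too old, then returns the registered default.
model = get_model("TransformersEmbeddingModel")
# e.g. an embedding checkpoint such as "BAAI/bge-base-en-v1.5"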
@@ -170,71 +177,47 @@ def test_embed_loading(vllm_runner, model):
 @pytest.mark.parametrize(
-    "model",
-    [
-        # Encoder model
-        "BAAI/bge-base-en-v1.5",
-    ])
-def test_embed_correctness(hf_runner, vllm_runner, example_prompts, model):
-    import transformers
-    from packaging.version import Version
-    installed = Version(transformers.__version__)
-    required = Version("4.57.0.dev0")
-    if installed < required:
-        pytest.skip("Encoder models with the Transformers backend require "
-                    f"transformers>={required}, but got {installed}")
-
-    with vllm_runner(model, max_model_len=512,
-                     model_impl="transformers") as vllm_model:
+    "arch",
+    ["TransformersEmbeddingModel", "TransformersForSequenceClassification"])
+def test_pooling(hf_runner, vllm_runner, example_prompts, arch):
+    model = get_model(arch)
+
+    vllm_kwargs = dict(
+        max_model_len=None,
+        model_impl="transformers",
+        compilation_config=dict(cudagraph_capture_sizes=[8]),
+    )
+
+    hf_kwargs = dict()
+    if arch == "TransformersEmbeddingModel":
+        hf_kwargs["is_sentence_transformer"] = True
+    elif arch == "TransformersForSequenceClassification":
+        from transformers import AutoModelForSequenceClassification
+        hf_kwargs["auto_cls"] = AutoModelForSequenceClassification
+
+    # The example_prompts has ending "\n", for example:
+    # "Write a short story about a robot that dreams for the first time.\n"
+    # sentence_transformers will strip the input texts, see:
+    # https://github.com/UKPLab/sentence-transformers/blob/v3.1.1/sentence_transformers/models/Transformer.py#L159
+    # This makes the input_ids different between hf_model and vllm_model.
+    # So we need to strip the input texts to avoid test failing.
+    example_prompts = [str(s).strip() for s in example_prompts]
+
+    with (vllm_runner(model, **vllm_kwargs) as
+          vllm_model, hf_runner(model, **hf_kwargs) as hf_model):
         model_config = vllm_model.llm.llm_engine.model_config
         assert model_config.using_transformers_backend()
 
-        vllm_outputs = vllm_model.embed(example_prompts)
-
-    with hf_runner(model, is_sentence_transformer=True) as hf_model:
-        hf_outputs = hf_model.encode(example_prompts)
+        if arch == "TransformersEmbeddingModel":
+            vllm_outputs = vllm_model.embed(example_prompts)
+            hf_outputs = hf_model.encode(example_prompts)
+        elif arch == "TransformersForSequenceClassification":
+            vllm_outputs = vllm_model.classify(example_prompts)
+            hf_outputs = hf_model.classify(example_prompts)
 
     check_embeddings_close(
         embeddings_0_lst=hf_outputs,
         embeddings_1_lst=vllm_outputs,
         name_0="hf",
         name_1="vllm",
         tol=1e-2,
     )
-
-
-@pytest.mark.parametrize(
-    "model",
-    ["jason9693/Qwen2.5-1.5B-apeach"],
-)
-@pytest.mark.parametrize("dtype", ["float"])
-def test_classify(
-    hf_runner,
-    vllm_runner,
-    example_prompts,
-    model: str,
-    dtype: str,
-) -> None:
-    import torch
-    from transformers import AutoModelForSequenceClassification
-
-    with vllm_runner(model,
-                     max_model_len=512,
-                     dtype=dtype,
-                     model_impl="transformers") as vllm_model:
-        model_config = vllm_model.llm.llm_engine.model_config
-        assert model_config.using_transformers_backend()
-
-        vllm_outputs = vllm_model.classify(example_prompts)
-
-    with hf_runner(model,
-                   dtype=dtype,
-                   auto_cls=AutoModelForSequenceClassification) as hf_model:
-        hf_outputs = hf_model.classify(example_prompts)
-
-    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
-        hf_output = torch.tensor(hf_output)
-        vllm_output = torch.tensor(vllm_output)
-
-        assert torch.allclose(hf_output, vllm_output,
-                              1e-3 if dtype == "float" else 1e-2)
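Outside the test suite, the code path this commit touches would be exercised roughly as follows. This is a minimal sketch, assuming a recent vLLM in which LLM accepts model_impl and exposes embed() for pooling models; the checkpoint name is reused from the removed test purely as an example:

# Minimal usage sketch (not part of this commit); kwargs are assumptions.
from vllm import LLM

llm = LLM(
    model="BAAI/bge-base-en-v1.5",  # example embedding checkpoint
    model_impl="transformers",      # force the Transformers backend
    max_model_len=512,
)
outputs = llm.embed(["vLLM pooling via the Transformers backend"])
print(len(outputs[0].outputs.embedding))  # embedding dimension

A sequence-classification checkpoint would go through the analogous classify() path instead of embed(), mirroring the two arches the new test parametrizes over.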