[VLM] Separate text-only and vision variants of the same model architecture (#13157)

Author: Cyrus Leung
Date: 2025-02-13 22:19:15 +08:00
Committed by: GitHub
Parent: 02ed8a1fbe
Commit: 1bc3b5e71b
15 changed files with 1728 additions and 1642 deletions

View File

@@ -4,12 +4,14 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 import torch
 from PIL.Image import Image
-from transformers import AutoTokenizer, BatchEncoding, PreTrainedTokenizerBase
+from transformers import BatchEncoding
 from transformers.models.auto.auto_factory import _BaseAutoModelClass
 
 from vllm.config import TaskOption
+from vllm.transformers_utils.tokenizer import AnyTokenizer
 
 from .....conftest import HfRunner, VllmRunner
+from ....registry import HF_EXAMPLE_MODELS
 from .types import RunnerOutput
@@ -31,10 +33,8 @@ def run_test(
     use_tokenizer_eos: bool,
     postprocess_inputs: Callable[[BatchEncoding], BatchEncoding],
     comparator: Callable[..., None],
-    get_stop_token_ids: Optional[Callable[[PreTrainedTokenizerBase],
-                                          List[int]]],
+    get_stop_token_ids: Optional[Callable[[AnyTokenizer], list[int]]],
     stop_str: Optional[List[str]],
-    tokenizer_mode: str,
     limit_mm_per_prompt: Dict[str, int],
     vllm_runner_kwargs: Optional[Dict[str, Any]],
     hf_model_kwargs: Optional[Dict[str, Any]],
@@ -48,7 +48,10 @@ def run_test(
"""Modality agnostic test test executor for comparing HF/vLLM outputs."""
# In the case of embeddings, vLLM takes separate input tensors
vllm_inputs = vllm_embeddings if vllm_embeddings is not None else inputs
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
vllm_outputs_per_mm = []
hf_outputs_per_mm = []
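Instead of loading a tokenizer up front, run_test now resolves the model through the shared example-model registry and bails out early when the model cannot be exercised. A minimal sketch of that lookup pattern, assuming the placeholder model id below and assuming that on_fail="skip" ultimately skips the test via pytest (neither detail is spelled out in this diff):

    from ....registry import HF_EXAMPLE_MODELS

    # Placeholder model id, for illustration only.
    model_info = HF_EXAMPLE_MODELS.find_hf_info("llava-hf/llava-1.5-7b-hf")
    # Presumed to skip the test when the checkpoint is unavailable online
    # or the installed transformers version does not support the model.
    model_info.check_available_online(on_fail="skip")
    model_info.check_transformers_version(on_fail="skip")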
@@ -57,17 +60,19 @@ def run_test(
     # vLLM needs a fresh new process without cuda initialization.
     # if we run HF first, the cuda initialization will be done and it
     # will hurt multiprocessing backend with fork method (the default method).
-    vllm_kwargs: Dict[str, Any] = {}
-    if get_stop_token_ids is not None:
-        vllm_kwargs["stop_token_ids"] = get_stop_token_ids(tokenizer)
-    if stop_str:
-        vllm_kwargs["stop"] = stop_str
-    if vllm_runner_kwargs is None:
-        vllm_runner_kwargs = {}
+    vllm_runner_kwargs_: Dict[str, Any] = {}
+    if model_info.tokenizer:
+        vllm_runner_kwargs_["tokenizer"] = model_info.tokenizer
+    if model_info.tokenizer_mode:
+        vllm_runner_kwargs_["tokenizer_mode"] = model_info.tokenizer_mode
+    if model_info.hf_overrides:
+        vllm_runner_kwargs_["hf_overrides"] = model_info.hf_overrides
+
+    if vllm_runner_kwargs:
+        vllm_runner_kwargs_.update(vllm_runner_kwargs)
 
     with vllm_runner(model,
-                     tokenizer_mode=tokenizer_mode,
                      max_model_len=max_model_len,
                      max_num_seqs=max_num_seqs,
                      dtype=dtype,
@@ -76,7 +81,15 @@ def run_test(
                      distributed_executor_backend=distributed_executor_backend,
                      enforce_eager=enforce_eager,
                      task=task,
-                     **vllm_runner_kwargs) as vllm_model:
+                     **vllm_runner_kwargs_) as vllm_model:
+        tokenizer = vllm_model.model.get_tokenizer()
+
+        vllm_kwargs: Dict[str, Any] = {}
+        if get_stop_token_ids is not None:
+            vllm_kwargs["stop_token_ids"] = get_stop_token_ids(tokenizer)
+        if stop_str:
+            vllm_kwargs["stop"] = stop_str
+
         for prompts, media in vllm_inputs:
             vllm_kwargs[runner_mm_key] = media
             vllm_output = vllm_model.generate_greedy_logprobs(
@@ -93,16 +106,19 @@ def run_test(
     if patch_hf_runner is not None:
         hf_model = patch_hf_runner(hf_model)
 
-    # Some models need to explicitly pass the eos_token_id off the tokenizer or
-    # processor for a good comparison; currently assume processor/tokenizer
-    # agree on the EOS, and pull it off the tokenizer if requested.
-    hf_kwargs = {}
-    if use_tokenizer_eos:
-        hf_kwargs["eos_token_id"] = tokenizer.eos_token_id
-    if stop_str:
-        hf_kwargs["stop_strings"] = stop_str
-
     with hf_model, torch.no_grad():
+        tokenizer = hf_model.tokenizer
+
+        # Some models need to explicitly pass the eos_token_id off the tokenizer
+        # or processor for a good comparison;
+        # currently assume processor/tokenizer agree on the EOS, and pull it off
+        # the tokenizer if requested.
+        hf_kwargs = {}
+        if use_tokenizer_eos:
+            hf_kwargs["eos_token_id"] = tokenizer.eos_token_id
+        if stop_str:
+            hf_kwargs["stop_strings"] = stop_str
+
         for prompts, media in inputs:
             hf_kwargs[runner_mm_key] = media
             hf_output = hf_model.generate_greedy_logprobs_limit(
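Taken together, the hunks above move all tokenizer handling inside the two runner contexts: the vLLM run derives stop tokens from the tokenizer of the model it actually instantiated, and the HF comparison run pulls EOS information from hf_model.tokenizer. A condensed, non-verbatim sketch of the resulting flow (other vllm_runner arguments are elided here; the elision is not implied by this diff):

    # Condensed sketch of run_test after this commit; illustrative only.
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
    model_info.check_available_online(on_fail="skip")
    model_info.check_transformers_version(on_fail="skip")

    vllm_runner_kwargs_: Dict[str, Any] = {}
    if model_info.tokenizer:
        vllm_runner_kwargs_["tokenizer"] = model_info.tokenizer
    if model_info.tokenizer_mode:
        vllm_runner_kwargs_["tokenizer_mode"] = model_info.tokenizer_mode
    if model_info.hf_overrides:
        vllm_runner_kwargs_["hf_overrides"] = model_info.hf_overrides
    if vllm_runner_kwargs:
        vllm_runner_kwargs_.update(vllm_runner_kwargs)

    with vllm_runner(model, **vllm_runner_kwargs_) as vllm_model:
        # The stop-token callback now sees the same tokenizer the engine loaded.
        tokenizer = vllm_model.model.get_tokenizer()
        vllm_kwargs: Dict[str, Any] = {}
        if get_stop_token_ids is not None:
            vllm_kwargs["stop_token_ids"] = get_stop_token_ids(tokenizer)
        if stop_str:
            vllm_kwargs["stop"] = stop_str

    with hf_model, torch.no_grad():
        # The HF run likewise reads EOS handling off its own tokenizer.
        tokenizer = hf_model.tokenizer
        hf_kwargs = {"eos_token_id": tokenizer.eos_token_id} if use_tokenizer_eos else {}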

View File

@@ -8,12 +8,12 @@ from typing import (Any, Callable, Dict, Iterable, List, NamedTuple, Optional,
 import torch
 from PIL.Image import Image
 from pytest import MarkDecorator
-from transformers import (AutoModelForCausalLM, BatchEncoding,
-                          PreTrainedTokenizerBase)
+from transformers import AutoModelForCausalLM, BatchEncoding
 from transformers.models.auto.auto_factory import _BaseAutoModelClass
 
 from vllm.config import TaskOption
 from vllm.sequence import SampleLogprobs
+from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import identity
 
 from .....conftest import IMAGE_ASSETS, HfRunner, ImageAsset, _ImageAssets
@@ -100,8 +100,7 @@ class VLMTestInfo(NamedTuple):
     vllm_runner_kwargs: Optional[Dict[str, Any]] = None
 
     # Optional callable which gets a list of token IDs from the model tokenizer
-    get_stop_token_ids: Optional[Callable[[PreTrainedTokenizerBase],
-                                          List[int]]] = None
+    get_stop_token_ids: Optional[Callable[[AnyTokenizer], list[int]]] = None
     # Optional list of strings to stop generation, useful when stop tokens are
     # not special tokens in the tokenizer
     stop_str: Optional[List[str]] = None
@@ -156,8 +155,6 @@ class VLMTestInfo(NamedTuple):
     marks: Optional[List[MarkDecorator]] = None
 
-    tokenizer_mode: str = "auto"
-
     def get_non_parametrized_runner_kwargs(self):
         """Returns a dictionary of expandable kwargs for items that are used
         in all test types, which are NOT used when creating the parametrized
@@ -180,7 +177,6 @@ class VLMTestInfo(NamedTuple):
"hf_model_kwargs": self.hf_model_kwargs,
"stop_str": self.stop_str,
"patch_hf_runner": self.patch_hf_runner,
"tokenizer_mode": self.tokenizer_mode
}
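With AnyTokenizer in the get_stop_token_ids signature, a per-model VLMTestInfo entry keeps supplying a small callback, but it now receives whichever tokenizer vLLM actually loads for that model. A hedged example of such an entry; the dictionary key, model id, test-type value, and stop token string are placeholders, and field names other than get_stop_token_ids and stop_str follow the surrounding test suite rather than this diff:

    # Hypothetical entry in the VLM test settings table.
    "my_vlm": VLMTestInfo(
        models=["my-org/my-vlm-7b-hf"],  # placeholder checkpoint
        test_type=VLMTestType.IMAGE,     # assumed enum member
        get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(["<|im_end|>"]),
        stop_str=["<|im_end|>"],
    ),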