[Core] Update dtype detection and defaults (#14858)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -4,7 +4,6 @@ from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from PIL.Image import Image
|
||||
from transformers import BatchEncoding
|
||||
from transformers.models.auto.auto_factory import _BaseAutoModelClass
|
||||
|
||||
from vllm.config import TaskOption
|
||||
@@ -31,7 +30,6 @@ def run_test(
|
||||
vllm_output_post_proc: Optional[Callable[[RunnerOutput, str], Any]],
|
||||
auto_cls: type[_BaseAutoModelClass],
|
||||
use_tokenizer_eos: bool,
|
||||
postprocess_inputs: Callable[[BatchEncoding], BatchEncoding],
|
||||
comparator: Callable[..., None],
|
||||
get_stop_token_ids: Optional[Callable[[AnyTokenizer], list[int]]],
|
||||
stop_str: Optional[list[str]],
|
||||
@@ -101,7 +99,6 @@ def run_test(
|
||||
hf_model = hf_runner(model,
|
||||
dtype=dtype,
|
||||
auto_cls=auto_cls,
|
||||
postprocess_inputs=postprocess_inputs,
|
||||
model_kwargs=hf_model_kwargs)
|
||||
|
||||
# Some models need to patch things like the model processor, e.g., internvl
|
||||
|
||||
@@ -6,16 +6,15 @@ typically specific to a small subset of models.
|
||||
import re
|
||||
import types
|
||||
from pathlib import PosixPath
|
||||
from typing import Callable, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from PIL.Image import Image
|
||||
from transformers import (AutoConfig, AutoTokenizer, BatchEncoding,
|
||||
from transformers import (AutoConfig, AutoTokenizer, BatchFeature,
|
||||
GenerationConfig)
|
||||
|
||||
from vllm.sequence import SampleLogprobs
|
||||
from vllm.transformers_utils.tokenizer import patch_padding_side
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
|
||||
from .....conftest import HfRunner, ImageAsset, _ImageAssets
|
||||
from .types import RunnerOutput
|
||||
@@ -211,40 +210,6 @@ def get_llava_embeddings(image_assets: _ImageAssets):
|
||||
return [asset.image_embeds for asset in image_assets]
|
||||
|
||||
|
||||
####### postprocessors to run on HF BatchEncoding
|
||||
def cast_dtype_post_processor(
|
||||
hf_inp_key: str) -> Callable[[BatchEncoding, str], BatchEncoding]:
|
||||
"""Gets a handle to a post processor which converts a given key into a
|
||||
target data type."""
|
||||
|
||||
def process(hf_inputs: BatchEncoding, dtype: str):
|
||||
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
|
||||
hf_inputs[hf_inp_key] = hf_inputs[hf_inp_key].to(torch_dtype)
|
||||
return hf_inputs
|
||||
|
||||
return process
|
||||
|
||||
|
||||
def ignore_inputs_post_processor(
|
||||
hf_inp_key: str) -> Callable[[BatchEncoding, str], BatchEncoding]:
|
||||
"""Gets a handle to a post processor which ignores a given key."""
|
||||
|
||||
def process(hf_inputs: BatchEncoding, dtype: str):
|
||||
del hf_inputs[hf_inp_key]
|
||||
return hf_inputs
|
||||
|
||||
return process
|
||||
|
||||
|
||||
def wrap_inputs_post_processor(hf_inputs: BatchEncoding, dtype: str):
|
||||
return {"model_inputs": hf_inputs}
|
||||
|
||||
|
||||
def molmo_post_processor(hf_inputs: BatchEncoding, dtype: str):
|
||||
hf_inputs = cast_dtype_post_processor("images")(hf_inputs, dtype)
|
||||
return {k: v.unsqueeze(0) for k, v in hf_inputs.items()}
|
||||
|
||||
|
||||
####### Prompt path encoders for models that need models on disk
|
||||
def qwen_prompt_path_encoder(
|
||||
tmp_path: PosixPath, prompt: str, assets: Union[list[ImageAsset],
|
||||
@@ -295,8 +260,7 @@ def deepseekvl2_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
|
||||
for k in inputs.keys() # noqa
|
||||
if k not in ("seq_lens", "sft_format")
|
||||
}
|
||||
inputs = BatchEncoding(data=inputs, tensor_type="pt")
|
||||
return inputs
|
||||
return BatchFeature(data=inputs, tensor_type="pt")
|
||||
|
||||
hf_model.processor = processor
|
||||
hf_model.model.get_output_embeddings = lambda: \
|
||||
@@ -529,10 +493,52 @@ def mantis_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
|
||||
return hf_model
|
||||
|
||||
|
||||
def minicpmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
|
||||
def minicpmv_25_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
|
||||
orig_generate = hf_model.model.generate
|
||||
|
||||
def _generate(self, *args, **kwargs):
|
||||
def _generate(
|
||||
self,
|
||||
*args,
|
||||
input_ids=None,
|
||||
pixel_values=None,
|
||||
image_sizes=None,
|
||||
image_bound=None,
|
||||
tgt_sizes=None,
|
||||
**kwargs,
|
||||
):
|
||||
model_inputs = {
|
||||
"input_ids": input_ids,
|
||||
"pixel_values": pixel_values,
|
||||
"image_sizes": image_sizes,
|
||||
"image_bound": image_bound,
|
||||
"tgt_sizes": tgt_sizes,
|
||||
}
|
||||
for k in list(model_inputs.keys()):
|
||||
if model_inputs[k] is None:
|
||||
model_inputs.pop(k)
|
||||
|
||||
return orig_generate(model_inputs, *args, decode_text=False, **kwargs)
|
||||
|
||||
hf_model.model.generate = types.MethodType(_generate, hf_model.model)
|
||||
|
||||
return hf_model
|
||||
|
||||
|
||||
def minicpmo_26_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
|
||||
orig_generate = hf_model.model.generate
|
||||
|
||||
def _generate(self, *args, image_sizes=None, **kwargs):
|
||||
return orig_generate(*args, decode_text=False, **kwargs)
|
||||
|
||||
hf_model.model.generate = types.MethodType(_generate, hf_model.model)
|
||||
|
||||
return hf_model
|
||||
|
||||
|
||||
def minicpmv_26_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
|
||||
orig_generate = hf_model.model.generate
|
||||
|
||||
def _generate(self, *args, image_sizes=None, **kwargs):
|
||||
return orig_generate(*args, decode_text=False, **kwargs)
|
||||
|
||||
hf_model.model.generate = types.MethodType(_generate, hf_model.model)
|
||||
@@ -551,10 +557,11 @@ def molmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
|
||||
|
||||
def _generate(self, max_new_tokens=None, do_sample=None, **kwargs):
|
||||
batch = {
|
||||
k: kwargs.pop(k)
|
||||
k: kwargs.pop(k).unsqueeze(0)
|
||||
for k in ("input_ids", "images", "image_input_idx", "image_masks")
|
||||
if k in kwargs
|
||||
}
|
||||
batch = BatchFeature(batch).to(dtype=self.dtype)
|
||||
|
||||
return self.generate_from_batch(
|
||||
batch,
|
||||
|
||||
@@ -8,13 +8,12 @@ from typing import Any, Callable, NamedTuple, Optional, Union
|
||||
import torch
|
||||
from PIL.Image import Image
|
||||
from pytest import MarkDecorator
|
||||
from transformers import AutoModelForCausalLM, BatchEncoding
|
||||
from transformers import AutoModelForCausalLM
|
||||
from transformers.models.auto.auto_factory import _BaseAutoModelClass
|
||||
|
||||
from vllm.config import TaskOption
|
||||
from vllm.sequence import SampleLogprobs
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import identity
|
||||
|
||||
from .....conftest import IMAGE_ASSETS, HfRunner, ImageAsset, _ImageAssets
|
||||
from ....utils import check_logprobs_close
|
||||
@@ -110,11 +109,6 @@ class VLMTestInfo(NamedTuple):
|
||||
# Indicates we should explicitly pass the EOS from the tokenizer
|
||||
use_tokenizer_eos: bool = False
|
||||
auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM
|
||||
# Callable to pass to the HF runner to run on inputs; for now, we also pass
|
||||
# the data type to input post processing, because almost all of the uses of
|
||||
# postprocess_inputs are to fix the data types of BatchEncoding values.
|
||||
postprocess_inputs: Callable[[BatchEncoding, str],
|
||||
BatchEncoding] = identity
|
||||
patch_hf_runner: Optional[Callable[[HfRunner], HfRunner]] = None
|
||||
|
||||
# Post processors that if defined, will run oun the outputs of the
|
||||
@@ -130,7 +124,7 @@ class VLMTestInfo(NamedTuple):
|
||||
# is all combinations of .models + all fields below
|
||||
max_tokens: Union[int, tuple[int]] = 128
|
||||
num_logprobs: Union[int, tuple[int]] = 5
|
||||
dtype: Union[str, Iterable[str]] = "half"
|
||||
dtype: Union[str, Union[list[str], tuple[str, ...]]] = "auto"
|
||||
distributed_executor_backend: Optional[Union[str, Iterable[str]]] = None
|
||||
# Only expanded in video tests
|
||||
num_video_frames: Union[int, tuple[int]] = 16
|
||||
@@ -171,7 +165,6 @@ class VLMTestInfo(NamedTuple):
|
||||
"vllm_output_post_proc": self.vllm_output_post_proc,
|
||||
"auto_cls": self.auto_cls,
|
||||
"use_tokenizer_eos": self.use_tokenizer_eos,
|
||||
"postprocess_inputs": self.postprocess_inputs,
|
||||
"comparator": self.comparator,
|
||||
"get_stop_token_ids": self.get_stop_token_ids,
|
||||
"hf_model_kwargs": self.hf_model_kwargs,
|
||||
|
||||
Reference in New Issue
Block a user