[Model] Update multi-modal processor to support Mantis(LLaVA) model (#10711)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

commit 39e227c7ae (parent 1c768fe537)
Author: Cyrus Leung
Date: 2024-12-08 01:10:05 +08:00
Committed by: GitHub
14 changed files with 175 additions and 78 deletions


@@ -3,9 +3,11 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 
 import torch
 from PIL.Image import Image
-from transformers import AutoTokenizer, BatchEncoding
+from transformers import AutoTokenizer, BatchEncoding, PreTrainedTokenizerBase
 from transformers.models.auto.auto_factory import _BaseAutoModelClass
 
+from vllm.config import TaskOption
+
 from .....conftest import HfRunner, VllmRunner
 from .types import RunnerOutput
@@ -28,13 +30,15 @@ def run_test(
     use_tokenizer_eos: bool,
     postprocess_inputs: Callable[[BatchEncoding], BatchEncoding],
     comparator: Callable[..., None],
-    get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]],
+    get_stop_token_ids: Optional[Callable[[PreTrainedTokenizerBase],
+                                          List[int]]],
     stop_str: Optional[List[str]],
     tokenizer_mode: str,
     limit_mm_per_prompt: Dict[str, int],
-    model_kwargs: Optional[Dict[str, Any]],
+    vllm_runner_kwargs: Optional[Dict[str, Any]],
+    hf_model_kwargs: Optional[Dict[str, Any]],
     patch_hf_runner: Optional[Callable[[HfRunner], HfRunner]],
-    task: str = "auto",
+    task: TaskOption = "auto",
     runner_mm_key: str = "images",
     distributed_executor_backend: Optional[str] = None,
     tensor_parallel_size: int = 1,
@@ -58,6 +62,9 @@ def run_test(
     if stop_str:
         vllm_kwargs["stop"] = stop_str
 
+    if vllm_runner_kwargs is None:
+        vllm_runner_kwargs = {}
+
     with vllm_runner(model,
                      tokenizer_mode=tokenizer_mode,
                      max_model_len=max_model_len,
@@ -67,7 +74,8 @@ def run_test(
                      tensor_parallel_size=tensor_parallel_size,
                      distributed_executor_backend=distributed_executor_backend,
                      enforce_eager=enforce_eager,
-                     task=task) as vllm_model:
+                     task=task,
+                     **vllm_runner_kwargs) as vllm_model:
         for prompts, media in vllm_inputs:
             vllm_kwargs[runner_mm_key] = media
             vllm_output = vllm_model.generate_greedy_logprobs(
@@ -78,7 +86,7 @@ def run_test(
                    dtype=dtype,
                    auto_cls=auto_cls,
                    postprocess_inputs=postprocess_inputs,
-                   model_kwargs=model_kwargs)
+                   model_kwargs=hf_model_kwargs)
 
     # Some models need to patch things like the model processor, e.g., internvl
     if patch_hf_runner is not None:

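In short, run_test now normalizes vllm_runner_kwargs from None to an empty dict and splats it into the vllm_runner call, so per-model engine options can be threaded through from the test registry. A minimal sketch of that pattern, using a hypothetical vllm_runner_stub in place of the real fixture:

    from contextlib import contextmanager
    from typing import Any, Dict, Optional

    @contextmanager
    def vllm_runner_stub(model: str, **engine_kwargs: Any):
        # Hypothetical stand-in for the vllm_runner fixture; it only records
        # which keyword arguments would reach the engine.
        yield {"model": model, **engine_kwargs}

    def run_test_sketch(model: str,
                        vllm_runner_kwargs: Optional[Dict[str, Any]] = None):
        # Same normalization as run_test: None becomes {} so the unpacking
        # below is always valid.
        if vllm_runner_kwargs is None:
            vllm_runner_kwargs = {}
        with vllm_runner_stub(model, task="auto",
                              **vllm_runner_kwargs) as vllm_model:
            return vllm_model

    assert run_test_sketch("some/model")["task"] == "auto"
    assert "hf_overrides" in run_test_sketch(
        "some/model", vllm_runner_kwargs={"hf_overrides": {}})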

@@ -126,6 +126,16 @@ def llava_onevision_vllm_to_hf_output(vllm_output: RunnerOutput,
     return hf_output_ids, hf_output_str, out_logprobs
 
 
+def mantis_vllm_to_hf_output(vllm_output: RunnerOutput,
+                             model: str) -> RunnerOutput:
+    """Sanitize vllm output [mantis] to compare with hf output."""
+    output_ids, output_str, out_logprobs = vllm_output
+
+    hf_output_str = output_str + "<|eot_id|>"
+
+    return output_ids, hf_output_str, out_logprobs
+
+
 def phi3v_vllm_to_hf_output(vllm_output: RunnerOutput,
                             model: str) -> RunnerOutput:
     """Sanitize vllm output [phi3v] to be comparable with hf output."""
@@ -184,7 +194,7 @@ def get_llava_embeddings(image_assets: _ImageAssets):
 
 
 ####### postprocessors to run on HF BatchEncoding
-def get_key_type_post_processor(
+def cast_dtype_post_processor(
         hf_inp_key: str) -> Callable[[BatchEncoding, str], BatchEncoding]:
     """Gets a handle to a post processor which converts a given key into a
     target data type."""
@@ -418,3 +428,26 @@ def _internvl_generate(
     )
 
     return outputs
+
+
+def mantis_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
+    from mantis.models.mllava import MLlavaProcessor
+
+    hf_model.processor = MLlavaProcessor.from_pretrained(hf_model.model_name)
+
+    orig_generate = hf_model.model.generate
+    tokenizer = hf_model.processor.tokenizer
+
+    def _generate(self, *args, **kwargs):
+        return orig_generate(
+            *args,
+            **kwargs,
+            eos_token_id=[
+                tokenizer.eos_token_id,
+                tokenizer.convert_tokens_to_ids("<|eot_id|>"),
+            ],
+        )
+
+    hf_model.model.generate = types.MethodType(_generate, hf_model.model)
+
+    return hf_model

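mantis_patch_hf_runner re-binds generate with types.MethodType so the wrapper is installed as an instance-level method (receiving the model as self) while delegating to the original bound method it closed over, forcing <|eot_id|> to act as an extra EOS during HF generation. The same patching pattern in isolation, with a hypothetical Greeter class standing in for the HF model:

    import types

    class Greeter:
        def greet(self, name: str) -> str:
            return f"hello {name}"

    g = Greeter()
    orig_greet = g.greet  # bound method: already carries g

    def _greet(self, name: str) -> str:
        # Delegate, then extend: mirrors how the patched generate()
        # injects an extra eos_token_id list into every call.
        return orig_greet(name) + "!"

    g.greet = types.MethodType(_greet, g)  # override on this instance only
    assert g.greet("world") == "hello world!"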

@@ -7,9 +7,11 @@ from typing import (Any, Callable, Dict, Iterable, List, NamedTuple, Optional,
 
 import torch
 from PIL.Image import Image
 from pytest import MarkDecorator
-from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding
+from transformers import (AutoModelForCausalLM, BatchEncoding,
+                          PreTrainedTokenizerBase)
 from transformers.models.auto.auto_factory import _BaseAutoModelClass
 
+from vllm.config import TaskOption
 from vllm.sequence import SampleLogprobs
 from vllm.utils import identity
@@ -66,7 +68,7 @@ class ImageSizeWrapper(NamedTuple):
 class VLMTestInfo(NamedTuple):
     """Holds the configuration for 1+ tests for one model architecture."""
 
-    models: Union[List[str]]
+    models: List[str]
     test_type: Union[VLMTestType, Iterable[VLMTestType]]
 
     # Should be None only if this is a CUSTOM_INPUTS test
@@ -92,18 +94,20 @@
     enforce_eager: bool = True
     max_model_len: int = 1024
     max_num_seqs: int = 256
-    task: str = "auto"
+    task: TaskOption = "auto"
     tensor_parallel_size: int = 1
+    vllm_runner_kwargs: Optional[Dict[str, Any]] = None
 
     # Optional callable which gets a list of token IDs from the model tokenizer
-    get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]] = None
+    get_stop_token_ids: Optional[Callable[[PreTrainedTokenizerBase],
+                                          List[int]]] = None
     # Optional list of strings to stop generation, useful when stop tokens are
     # not special tokens in the tokenizer
     stop_str: Optional[List[str]] = None
     # Exposed options for HF runner
-    model_kwargs: Optional[Dict[str, Any]] = None
-    # Indicates we should explicitly pass the EOS from the tokeniezr
+    hf_model_kwargs: Optional[Dict[str, Any]] = None
+    # Indicates we should explicitly pass the EOS from the tokenizer
     use_tokenizer_eos: bool = False
     auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM
 
     # Callable to pass to the HF runner to run on inputs; for now, we also pass
@@ -164,6 +168,7 @@ class VLMTestInfo(NamedTuple):
             "max_num_seqs": self.max_num_seqs,
             "task": self.task,
             "tensor_parallel_size": self.tensor_parallel_size,
+            "vllm_runner_kwargs": self.vllm_runner_kwargs,
             "hf_output_post_proc": self.hf_output_post_proc,
             "vllm_output_post_proc": self.vllm_output_post_proc,
             "auto_cls": self.auto_cls,
@@ -171,8 +176,8 @@
             "postprocess_inputs": self.postprocess_inputs,
             "comparator": self.comparator,
             "get_stop_token_ids": self.get_stop_token_ids,
+            "hf_model_kwargs": self.hf_model_kwargs,
             "stop_str": self.stop_str,
-            "model_kwargs": self.model_kwargs,
             "patch_hf_runner": self.patch_hf_runner,
             "tokenizer_mode": self.tokenizer_mode
         }
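Putting the pieces together, a registry entry for Mantis can now route engine options through vllm_runner_kwargs and HF-side options through the renamed hf_model_kwargs. A hedged sketch of such an entry, with fields not changed by this commit left at their defaults; the model id, the VLMTestType.IMAGE member, and the hf_overrides payload are assumptions rather than quotes from this diff:

    mantis_test_info = VLMTestInfo(
        models=["TIGER-Lab/Mantis-8B-siglip-llama3"],  # assumed model id
        test_type=VLMTestType.IMAGE,  # assumed enum member
        # New in this commit: extra kwargs forwarded verbatim to vllm_runner.
        vllm_runner_kwargs={
            "hf_overrides": {
                "architectures": ["MantisForConditionalGeneration"],  # assumption
            },
        },
        # Now typed against PreTrainedTokenizerBase instead of AutoTokenizer;
        # mirrors the stop tokens forced by mantis_patch_hf_runner.
        get_stop_token_ids=lambda tok: [
            tok.eos_token_id,
            tok.convert_tokens_to_ids("<|eot_id|>"),
        ],
        vllm_output_post_proc=mantis_vllm_to_hf_output,
        patch_hf_runner=mantis_patch_hf_runner,
    )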