[Bugfix][V1] Fix molmo text-only inputs (#11676)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -5,17 +5,20 @@ typically specific to a small subset of models.
|
||||
import re
|
||||
import types
|
||||
from pathlib import PosixPath
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from PIL.Image import Image
|
||||
from transformers import AutoConfig, AutoTokenizer, BatchEncoding
|
||||
from transformers import (AutoConfig, AutoTokenizer, BatchEncoding,
|
||||
GenerationConfig)
|
||||
|
||||
from vllm.sequence import SampleLogprobs
|
||||
from vllm.transformers_utils.tokenizer import patch_padding_side
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
|
||||
from .....conftest import HfRunner, ImageAsset, _ImageAssets
|
||||
from .....conftest import (HfRunner, ImageAsset, PromptAudioInput,
|
||||
PromptImageInput, PromptVideoInput, _ImageAssets)
|
||||
from ....utils import TokensTextLogprobs
|
||||
from .types import RunnerOutput
|
||||
|
||||
|
||||
@@ -222,6 +225,11 @@ def wrap_inputs_post_processor(hf_inputs: BatchEncoding, dtype: str):
|
||||
return {"model_inputs": hf_inputs}
|
||||
|
||||
|
||||
def molmo_post_processor(hf_inputs: BatchEncoding, dtype: str):
|
||||
hf_inputs = cast_dtype_post_processor("images")(hf_inputs, dtype)
|
||||
return {k: v.unsqueeze(0) for k, v in hf_inputs.items()}
|
||||
|
||||
|
||||
####### Prompt path encoders for models that need models on disk
|
||||
def qwen_prompt_path_encoder(
|
||||
tmp_path: PosixPath, prompt: str, assets: Union[List[ImageAsset],
|
||||
@@ -451,3 +459,88 @@ def mantis_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
|
||||
hf_model.model.generate = types.MethodType(_generate, hf_model.model)
|
||||
|
||||
return hf_model
|
||||
|
||||
|
||||
def _generate_greedy_logprobs_limit(
|
||||
self,
|
||||
prompts: List[str],
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
images: Optional[PromptImageInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[TokensTextLogprobs]:
|
||||
all_inputs = self.get_inputs(prompts,
|
||||
images=images,
|
||||
videos=videos,
|
||||
audios=audios)
|
||||
|
||||
# Process in batches for inference.
|
||||
if len(all_inputs):
|
||||
input_ids_lst = []
|
||||
images_lst = []
|
||||
images_input_idx_lst = []
|
||||
imges_masks_lst = []
|
||||
for inputs in all_inputs:
|
||||
input_ids_lst.append(inputs["input_ids"])
|
||||
images_lst.append(inputs["images"])
|
||||
images_input_idx_lst.append(inputs["image_input_idx"])
|
||||
imges_masks_lst.append(inputs["image_masks"])
|
||||
batch_inputs = {}
|
||||
batch_inputs['input_ids'] = torch.cat(input_ids_lst, dim=0)
|
||||
batch_inputs['images'] = torch.cat(images_lst, dim=0)
|
||||
batch_inputs['image_input_idx'] = torch.cat(images_input_idx_lst,
|
||||
dim=0)
|
||||
batch_inputs['image_masks'] = torch.cat(imges_masks_lst, dim=0)
|
||||
|
||||
outputs = self.model.generate_from_batch(
|
||||
batch=self.wrap_device(batch_inputs,
|
||||
device=self.model.device.type),
|
||||
generation_config=GenerationConfig(
|
||||
max_new_tokens=max_tokens,
|
||||
stop_strings="<|endoftext|>",
|
||||
do_sample=False,
|
||||
),
|
||||
tokenizer=self.tokenizer,
|
||||
output_hidden_states=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
all_logprobs: List[List[Dict[int, float]]] = []
|
||||
all_output_ids: List[List[int]] = []
|
||||
all_output_strs: List[str] = []
|
||||
|
||||
for index in range(len(all_inputs)):
|
||||
(
|
||||
seq_logprobs_lst,
|
||||
output_len,
|
||||
) = self._hidden_states_to_logprobs(outputs.hidden_states,
|
||||
num_logprobs)
|
||||
all_logprobs.append(seq_logprobs_lst)
|
||||
seq_ids = outputs.sequences[index]
|
||||
output_ids = seq_ids[-output_len:]
|
||||
all_output_ids.append(output_ids.tolist())
|
||||
all_output_strs.append(self.tokenizer.decode(output_ids))
|
||||
outputs = zip(all_output_ids, all_output_strs, all_logprobs)
|
||||
return [(output_ids, output_str, output_logprobs)
|
||||
for output_ids, output_str, output_logprobs in outputs]
|
||||
|
||||
|
||||
####### Molmo-specific HuggingFace runner patchers
|
||||
def mlomo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
|
||||
"""Patches and returns an instance of the HfRunner to use for Molmo."""
|
||||
hf_processor = hf_model.processor
|
||||
|
||||
def _processor(*args, **kwargs):
|
||||
return hf_processor.process(*args, **kwargs)
|
||||
|
||||
hf_model.processor = _processor
|
||||
|
||||
setattr( # noqa: B010
|
||||
hf_model,
|
||||
"generate_greedy_logprobs_limit",
|
||||
types.MethodType(_generate_greedy_logprobs_limit, hf_model),
|
||||
)
|
||||
|
||||
return hf_model
|
||||
|
||||
Reference in New Issue
Block a user