[Hotfix][Pixtral] Fix multiple images bugs (#8415)

This commit is contained in:
Patrick von Platen
2024-09-13 00:21:51 +02:00
committed by GitHub
parent b61bd98f90
commit d31174a4e1
5 changed files with 197 additions and 78 deletions

View File

@@ -1,4 +1,3 @@
import math
from array import array
from dataclasses import dataclass, fields
from itertools import tee
@@ -15,11 +14,12 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalMask
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import merge_multimodal_embeddings
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
@@ -48,23 +48,29 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
tokenizer_mode=ctx.model_config.tokenizer_mode)
mm_encoder = tokenizer.instruct.mm_encoder
mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
patch_size = mm_encoder.mm_config.image_patch_size
image_token_id = mm_encoder.special_ids.img
mm_config = ctx.model_config.multimodal_config
max_num_images_per_request = mm_config.limit_per_prompt.get("image", 1)
# approximate image size
size = int(math.sqrt(seq_len) * mm_encoder.mm_config.image_patch_size)
num_images = mm_config.limit_per_prompt.get("image", 1)
# dummy size
size = 256
image = Image.new("RGB", (size, size), color=0)
img_chunk = ImageChunk(image=image)
tokens = mm_encoder(img_chunk).tokens
token_ids = max_num_images_per_request * array(VLLM_TOKEN_ID_ARRAY_TYPE,
tokens)
image_feature_size = (size**2) // (patch_size**2)
num_image_tokens = image_feature_size * num_images
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[image_token_id]) * num_image_tokens
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - num_image_tokens)
seq_data = SequenceData(token_ids)
mm_data = {"image": max_num_images_per_request * [image]}
mm_data = {"image": num_images * [image]}
return seq_data, mm_data
@@ -99,32 +105,31 @@ def input_mapper_for_pixtral(ctx: InputContext,
return MultiModalInputs({"images": images})
def merge_multimodal_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
image_features: Optional[List[torch.Tensor]],
image_id: int) -> torch.Tensor:
text_locations = input_ids != image_id
image_locations = input_ids == image_id
def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is not None and "image" in multi_modal_data:
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
tokenizer_mode=ctx.model_config.tokenizer_mode)
seq_len = input_ids.shape[0]
mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
image_token_id = mm_encoder.special_ids.img
N_txt = text_locations.sum().item()
_, D_txt = inputs_embeds.shape
N_img, D_img = image_features.shape
if image_token_id not in llm_inputs['prompt_token_ids']:
raise ValueError(
(f"You've passed {llm_inputs=} without {image_token_id=}."
" Make sure to process your input via mistral_common's"
" tokenizer or pass a chat completion request."
" For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411."))
assert (D_txt == D_img), (f"Text features dim {D_txt} should be equal "
"to image features dim {D_img}")
assert (seq_len == N_txt +
N_img), (f"seq_len {seq_len} should be equal to N_txt + N_img "
f"{(N_txt, N_img, image_locations.sum().item())}")
inputs_embeds[image_locations, :] = image_features
return inputs_embeds
return llm_inputs
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral)
@INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self,
@@ -201,11 +206,21 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):
return None
if isinstance(images, torch.Tensor):
# always take last images
images = [images[-1][i] for i in range(images.size(1))]
# if passed as batch take all images
N, B, C, W, H = images.shape
images = images.reshape(N * B, C, W, H)
images = [images[i] for i in range(images.size(0))]
elif isinstance(images, list):
# always take last images
images = [images[-1][i] for i in range(len(images[0]))]
# if passed as list flatten lists of tensors
flatten_images = []
for imgs_per_req in images:
imgs_per_req = [
imgs_per_req[i] for i in range(imgs_per_req.size(0))
] if isinstance(imgs_per_req, torch.Tensor) else imgs_per_req
flatten_images.extend(imgs_per_req)
images = flatten_images
return images