[Hotfix][Pixtral] Fix multiple images bugs (#8415)
commit d31174a4e1
parent b61bd98f90
committed by GitHub
@@ -1,4 +1,3 @@
-import math
 from array import array
 from dataclasses import dataclass, fields
 from itertools import tee
@@ -15,11 +14,12 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalMask
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
-from vllm.inputs import INPUT_REGISTRY, InputContext
+from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.utils import merge_multimodal_embeddings
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
@@ -48,23 +48,29 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
     tokenizer = cached_get_tokenizer(
         ctx.model_config.tokenizer,
         tokenizer_mode=ctx.model_config.tokenizer_mode)
-    mm_encoder = tokenizer.instruct.mm_encoder
+
+    mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
+    patch_size = mm_encoder.mm_config.image_patch_size
+    image_token_id = mm_encoder.special_ids.img
 
     mm_config = ctx.model_config.multimodal_config
-    max_num_images_per_request = mm_config.limit_per_prompt.get("image", 1)
-
-    # approximate image size
-    size = int(math.sqrt(seq_len) * mm_encoder.mm_config.image_patch_size)
+    num_images = mm_config.limit_per_prompt.get("image", 1)
+
+    # dummy size
+    size = 256
     image = Image.new("RGB", (size, size), color=0)
-    img_chunk = ImageChunk(image=image)
-
-    tokens = mm_encoder(img_chunk).tokens
-    token_ids = max_num_images_per_request * array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                                   tokens)
+
+    image_feature_size = (size**2) // (patch_size**2)
+
+    num_image_tokens = image_feature_size * num_images
+
+    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                      [image_token_id]) * num_image_tokens
+    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                       [0]) * (seq_len - num_image_tokens)
 
     seq_data = SequenceData(token_ids)
-    mm_data = {"image": max_num_images_per_request * [image]}
+    mm_data = {"image": num_images * [image]}
     return seq_data, mm_data
 
 
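Note on the new dummy-data arithmetic: the fixed 256x256 dummy image yields (size**2) // (patch_size**2) placeholder image tokens per image, repeated once per image allowed by limit_per_prompt, with the rest of the sequence zero-padded. A minimal sketch of the token math, assuming illustrative values (patch_size=16, image_token_id=10, and the "l" typecode are stand-ins for what mm_encoder and vLLM actually provide):

    from array import array

    VLLM_TOKEN_ID_ARRAY_TYPE = "l"  # stand-in for vLLM's token-id typecode

    # Assumed values; in the diff these come from mm_encoder and the config.
    size = 256             # fixed dummy image edge from the new code
    patch_size = 16        # hypothetical image_patch_size
    image_token_id = 10    # hypothetical special_ids.img
    num_images = 2         # limit_per_prompt.get("image", 1)
    seq_len = 1024

    image_feature_size = (size**2) // (patch_size**2)   # 256 tokens per image
    num_image_tokens = image_feature_size * num_images  # 512 image tokens

    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                      [image_token_id]) * num_image_tokens
    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
                       [0]) * (seq_len - num_image_tokens)
    assert len(token_ids) == seq_len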
@@ -99,32 +105,31 @@ def input_mapper_for_pixtral(ctx: InputContext,
     return MultiModalInputs({"images": images})
 
 
-def merge_multimodal_embeddings(input_ids: torch.Tensor,
-                                inputs_embeds: torch.Tensor,
-                                image_features: Optional[List[torch.Tensor]],
-                                image_id: int) -> torch.Tensor:
-    text_locations = input_ids != image_id
-    image_locations = input_ids == image_id
-
-    seq_len = input_ids.shape[0]
-
-    N_txt = text_locations.sum().item()
-    _, D_txt = inputs_embeds.shape
-    N_img, D_img = image_features.shape
-
-    assert (D_txt == D_img), (f"Text features dim {D_txt} should be equal "
-                              f"to image features dim {D_img}")
-    assert (seq_len == N_txt +
-            N_img), (f"seq_len {seq_len} should be equal to N_txt + N_img "
-                     f"{(N_txt, N_img, image_locations.sum().item())}")
-
-    inputs_embeds[image_locations, :] = image_features
-    return inputs_embeds
+def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs):
+    multi_modal_data = llm_inputs.get("multi_modal_data")
+    if multi_modal_data is not None and "image" in multi_modal_data:
+        tokenizer = cached_get_tokenizer(
+            ctx.model_config.tokenizer,
+            tokenizer_mode=ctx.model_config.tokenizer_mode)
+
+        mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
+        image_token_id = mm_encoder.special_ids.img
+
+        if image_token_id not in llm_inputs['prompt_token_ids']:
+            raise ValueError(
+                (f"You've passed {llm_inputs=} without {image_token_id=}."
+                 " Make sure to process your input via mistral_common's"
+                 " tokenizer or pass a chat completion request."
+                 " For more info, see: "
+                 "https://github.com/vllm-project/vllm/issues/8411."))
+
+    return llm_inputs
 
 
 @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral)
 class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):
 
     def __init__(self,
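The processor registered above only validates; it never rewrites the prompt. A reduced sketch of the guard with made-up token ids, showing which inputs pass and which raise (mistral_common's chat tokenizer is what inserts the image tokens, so their absence means the prompt was tokenized the wrong way):

    image_token_id = 10  # hypothetical <img> special-token id

    # Expanded by mistral_common: image placeholders already present.
    good = {"prompt_token_ids": [1, 10, 10, 10, 42, 2]}
    # Tokenized as plain text: no image placeholders at all.
    bad = {"prompt_token_ids": [1, 42, 2]}

    for llm_inputs in (good, bad):
        try:
            if image_token_id not in llm_inputs["prompt_token_ids"]:
                raise ValueError(f"{llm_inputs=} lacks {image_token_id=}")
            print("accepted:", llm_inputs)
        except ValueError as err:
            print("rejected:", err)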
@@ -201,11 +206,21 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):
             return None
 
         if isinstance(images, torch.Tensor):
-            # always take last images
-            images = [images[-1][i] for i in range(images.size(1))]
+            # if passed as batch take all images
+            N, B, C, W, H = images.shape
+            images = images.reshape(N * B, C, W, H)
+            images = [images[i] for i in range(images.size(0))]
         elif isinstance(images, list):
-            # always take last images
-            images = [images[-1][i] for i in range(len(images[0]))]
+            # if passed as list flatten lists of tensors
+            flatten_images = []
+            for imgs_per_req in images:
+                imgs_per_req = [
+                    imgs_per_req[i] for i in range(imgs_per_req.size(0))
+                ] if isinstance(imgs_per_req, torch.Tensor) else imgs_per_req
+
+                flatten_images.extend(imgs_per_req)
+
+            images = flatten_images
 
         return images
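Both rewritten branches above keep every image from every request rather than only the last one, which is the multiple-images bug this hotfix targets. A small torch sketch of the two paths, with made-up shapes:

    import torch

    # Tensor path: N requests, B equally-sized images each.
    images = torch.randn(2, 3, 3, 64, 64)  # (N, B, C, W, H)
    N, B, C, W, H = images.shape
    flat = images.reshape(N * B, C, W, H)
    flat = [flat[i] for i in range(flat.size(0))]
    assert len(flat) == 6  # all N * B images survive

    # List path: per-request entries may be stacked tensors or lists.
    per_request = [torch.randn(2, 3, 64, 64),   # request 1: 2 images
                   [torch.randn(3, 32, 32)]]    # request 2: 1 image
    flatten_images = []
    for imgs_per_req in per_request:
        if isinstance(imgs_per_req, torch.Tensor):
            imgs_per_req = [imgs_per_req[i]
                            for i in range(imgs_per_req.size(0))]
        flatten_images.extend(imgs_per_req)
    assert len(flatten_images) == 3  # 2 + 1 images, mixed sizes preserved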