[Model] Support Pixtral models in the HF Transformers format (#9036)

Author: Michael Goin
Date: 2024-10-18 15:29:56 -04:00
Committed by: GitHub
Parent: 67a7e5ef38
Commit: 3921a2f29e
7 changed files with 503 additions and 12 deletions
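
For context, this change adds Pixtral support through the existing LLaVA path when the checkpoint ships HF-style configs. A minimal usage sketch follows (not part of the diff; the checkpoint ID comes from the config.json URL cited in the hunks below, and the prompt template is an assumption, not taken from this PR):

    from PIL import Image
    from vllm import LLM

    # Hedged sketch: load the HF-format Pixtral checkpoint this commit targets.
    # vLLM routes it through LlavaForConditionalGeneration (patched below).
    llm = LLM(model="mistral-community/pixtral-12b")

    # Prompt format is assumed for illustration only.
    prompt = "<s>[INST]Describe this image.\n[IMG][/INST]"
    outputs = llm.generate({
        "prompt": prompt,
        "multi_modal_data": {"image": Image.open("example.jpg")},
    })
    print(outputs[0].outputs[0].text)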

vllm/model_executor/models/llava.py

@@ -5,7 +5,8 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
 import torch
 import torch.nn as nn
 from PIL import Image
-from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
+from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
+                          SiglipVisionConfig)
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
@@ -22,6 +23,10 @@ from .clip import (CLIPVisionModel, dummy_image_for_clip,
                    dummy_seq_data_for_clip, get_max_clip_image_tokens,
                    input_processor_for_clip)
 from .interfaces import SupportsMultiModal, SupportsPP
+from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
+                      dummy_seq_data_for_pixtral_hf,
+                      get_max_pixtral_hf_image_tokens,
+                      input_processor_for_pixtral_hf)
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                      input_processor_for_siglip)
@@ -31,8 +36,13 @@ from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
 
 class LlavaImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: torch.Tensor
-    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
+    data: Union[torch.Tensor, List[torch.Tensor]]
+    """
+    Shape: `(batch_size * num_images, num_channels, height, width)`
+
+    Note that `height` or `width` may be different per batch and image,
+    in which case the data is passed as a list instead of a batched tensor.
+    """
 
 
 class LlavaImageEmbeddingInputs(TypedDict):
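
Since Pixtral keeps images at their native resolutions, per-image tensors can differ in height and width, which is why `data` may now be a list. A quick illustration (shapes invented):

    import torch

    # Two images at different resolutions cannot be stacked into one
    # (N, C, H, W) tensor, so they are carried as a list of 3-D tensors.
    images = [
        torch.randn(3, 512, 768),   # (num_channels, height, width)
        torch.randn(3, 1024, 640),  # different spatial size
    ]
    # torch.stack(images) would raise RuntimeError: stack expects each
    # tensor to be equal size.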
@@ -77,6 +87,8 @@ def get_max_llava_image_tokens(ctx: InputContext):
         num_image_tokens = get_max_clip_image_tokens(vision_config)
     elif isinstance(vision_config, SiglipVisionConfig):
         num_image_tokens = get_max_siglip_image_tokens(vision_config)
+    elif isinstance(vision_config, PixtralVisionConfig):
+        num_image_tokens = get_max_pixtral_hf_image_tokens(vision_config)
     else:
         msg = f"Unsupported vision config: {type(vision_config)}"
         raise NotImplementedError(msg)
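
As a rough sketch of where the Pixtral maximum comes from (an assumption about the formula, not code from this PR), the worst case is a square image at the configured maximum resolution:

    from transformers import PixtralVisionConfig

    # Defaults matching the 12B checkpoint: image_size=1024, patch_size=16.
    config = PixtralVisionConfig()
    tokens_per_side = config.image_size // config.patch_size  # 64
    max_patch_tokens = tokens_per_side ** 2  # 4096 patch tokens per image
    # The PR's get_max_pixtral_hf_image_tokens may also count row-break
    # tokens; this sketch covers only the patch grid.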
@@ -120,6 +132,17 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
 
         mm_data = dummy_image_for_siglip(vision_config, num_images)
         return seq_data, mm_data
+    elif isinstance(vision_config, PixtralVisionConfig):
+        seq_data = dummy_seq_data_for_pixtral_hf(
+            vision_config,
+            seq_len,
+            num_images,
+            image_token_id=hf_config.image_token_index,
+            image_feature_size_override=image_feature_size,
+        )
+
+        mm_data = dummy_image_for_pixtral_hf(vision_config, num_images)
+        return seq_data, mm_data
 
     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)
@@ -163,6 +186,15 @@ def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
             image_token_id=hf_config.image_token_index,
             image_feature_size_override=image_feature_size,
         )
+    elif isinstance(vision_config, PixtralVisionConfig):
+        # We ignore image_feature_size_override since we have non-uniform
+        # image sizes for Pixtral
+        return input_processor_for_pixtral_hf(
+            model_config,
+            vision_config,
+            inputs,
+            image_token_id=hf_config.image_token_index,
+        )
 
     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)
@@ -189,6 +221,9 @@ def _init_vision_tower(hf_config: LlavaConfig):
             vision_config,
             num_hidden_layers_override=num_hidden_layers,
         )
+    elif isinstance(vision_config, PixtralVisionConfig):
+        # TODO: allow layer override?
+        return PixtralHFVisionModel(vision_config)
 
     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)
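
For a Pixtral checkpoint in HF format, the dispatch above resolves to the new branch because `LlavaConfig.vision_config` deserializes to a `PixtralVisionConfig`. A small check (assumes network access and a transformers version with Pixtral support):

    from transformers import LlavaConfig

    config = LlavaConfig.from_pretrained("mistral-community/pixtral-12b")
    print(type(config.vision_config).__name__)  # PixtralVisionConfig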
@@ -210,6 +245,15 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         self.config = config
         self.multimodal_config = multimodal_config
 
+        # NOTE: These are special cases for Pixtral-12B in the HF-format
+        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json # noqa
+        if (config.text_config.architectures is None
+                and config.text_config.model_type == "mistral"):
+            config.text_config.architectures = ["MistralForCausalLM"]
+        if (config.projector_hidden_act is None
+                and config.vision_config.hidden_act == "gelu"):
+            config.projector_hidden_act = "gelu"
+
         # TODO: Optionally initializes this for supporting embeddings.
         self.vision_tower = _init_vision_tower(config)
         self.multi_modal_projector = LlavaMultiModalProjector(
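
The NOTE above patches two fields that the linked config.json leaves unset. A paraphrased excerpt of the relevant shape (not copied verbatim from the file):

    # Paraphrased: text_config declares model_type "mistral" without an
    # "architectures" list, and projector_hidden_act is absent, so the two
    # patches above fill in the defaults vLLM needs.
    hf_config_excerpt = {
        "projector_hidden_act": None,              # patched to "gelu"
        "text_config": {"model_type": "mistral"},  # -> ["MistralForCausalLM"]
        "vision_config": {"hidden_act": "gelu"},
    }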
@@ -246,6 +290,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[LlavaImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
+        image_sizes = kwargs.pop("image_sizes", None)
         image_embeds = kwargs.pop("image_embeds", None)
 
         if pixel_values is None and image_embeds is None:
@@ -256,6 +301,26 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
+            # Case for models like PixtralHF that have dynamic image sizes
+            # so we need to produce a list of tensors
+            if image_sizes is not None:
+                images = pixel_values
+                if isinstance(images, torch.Tensor):
+                    # if passed as batch take all images
+                    NN, N, B, C, W, H = images.shape
+                    images = images.reshape(NN * N * B, C, W, H)
+                    images = [images[i] for i in range(images.size(0))]
+                elif isinstance(images, list):
+                    # if passed as list flatten lists of tensors
+                    while isinstance(images, list) and len(images) == 1:
+                        images = images[0]
+
+                # TODO: Add validation based on image_sizes
+                return LlavaImagePixelInputs(
+                    type="pixel_values",
+                    data=images,
+                )
+
             return LlavaImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(
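
To make the tensor branch above concrete, a small sketch of the reshape-and-split (shapes invented for illustration):

    import torch

    # A uniformly-sized batch may arrive stacked with extra leading dims;
    # the branch flattens them and splits into per-image 3-D tensors so the
    # dynamic-size and batched cases both yield a list.
    batched = torch.randn(1, 2, 1, 3, 336, 336)  # (NN, N, B, C, W, H)
    NN, N, B, C, W, H = batched.shape
    flat = batched.reshape(NN * N * B, C, W, H)
    images = [flat[i] for i in range(flat.size(0))]
    assert len(images) == 2 and images[0].shape == (3, 336, 336)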
@@ -286,7 +351,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     def _image_pixels_to_features(
         self,
-        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
+        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
+                            PixtralHFVisionModel],
         pixel_values: torch.Tensor,
     ) -> torch.Tensor: