[Model] Support Pixtral models in the HF Transformers format (#9036)
@@ -5,7 +5,8 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
 import torch
 import torch.nn as nn
 from PIL import Image
-from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
+from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
+                          SiglipVisionConfig)
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
@@ -22,6 +23,10 @@ from .clip import (CLIPVisionModel, dummy_image_for_clip,
                    dummy_seq_data_for_clip, get_max_clip_image_tokens,
                    input_processor_for_clip)
 from .interfaces import SupportsMultiModal, SupportsPP
+from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
+                      dummy_seq_data_for_pixtral_hf,
+                      get_max_pixtral_hf_image_tokens,
+                      input_processor_for_pixtral_hf)
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                      input_processor_for_siglip)
@@ -31,8 +36,13 @@ from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
 
 class LlavaImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: torch.Tensor
-    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
+    data: Union[torch.Tensor, List[torch.Tensor]]
+    """
+    Shape: `(batch_size * num_images, num_channels, height, width)`
+
+    Note that `height` or `width` may be different per batch and image,
+    in which case the data is passed as a list instead of a batched tensor.
+    """
 
 
 class LlavaImageEmbeddingInputs(TypedDict):
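Concretely, both of the following now satisfy the widened `data` field; a toy sketch with hypothetical image sizes, not code from this diff:

import torch

# Uniform image sizes: a single batched 4-D tensor.
uniform = torch.zeros(2, 3, 336, 336)
# Non-uniform sizes (e.g. PixtralHF): one 3-D tensor per image.
ragged = [torch.zeros(3, 512, 768), torch.zeros(3, 336, 336)]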
@@ -77,6 +87,8 @@ def get_max_llava_image_tokens(ctx: InputContext):
         num_image_tokens = get_max_clip_image_tokens(vision_config)
     elif isinstance(vision_config, SiglipVisionConfig):
         num_image_tokens = get_max_siglip_image_tokens(vision_config)
+    elif isinstance(vision_config, PixtralVisionConfig):
+        num_image_tokens = get_max_pixtral_hf_image_tokens(vision_config)
+    else:
         msg = f"Unsupported vision config: {type(vision_config)}"
         raise NotImplementedError(msg)
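For a sense of scale, a back-of-envelope sketch of the Pixtral token ceiling. The `image_size=1024` and `patch_size=16` values are assumed from the Pixtral-12B HF config, and the real `get_max_pixtral_hf_image_tokens` may count additional tokens, so this is illustrative arithmetic only:

# Hypothetical arithmetic, not the library helper itself.
image_size, patch_size = 1024, 16                  # assumed PixtralVisionConfig values
max_patches_per_side = image_size // patch_size    # 64
max_image_tokens = max_patches_per_side ** 2       # 4096 patch tokens per image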
@@ -120,6 +132,17 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
 
         mm_data = dummy_image_for_siglip(vision_config, num_images)
         return seq_data, mm_data
+    elif isinstance(vision_config, PixtralVisionConfig):
+        seq_data = dummy_seq_data_for_pixtral_hf(
+            vision_config,
+            seq_len,
+            num_images,
+            image_token_id=hf_config.image_token_index,
+            image_feature_size_override=image_feature_size,
+        )
+
+        mm_data = dummy_image_for_pixtral_hf(vision_config, num_images)
+        return seq_data, mm_data
 
     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)
@@ -163,6 +186,15 @@ def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
             image_token_id=hf_config.image_token_index,
             image_feature_size_override=image_feature_size,
         )
+    elif isinstance(vision_config, PixtralVisionConfig):
+        # We ignore image_feature_size_override since we have non-uniform
+        # image sizes for Pixtral
+        return input_processor_for_pixtral_hf(
+            model_config,
+            vision_config,
+            inputs,
+            image_token_id=hf_config.image_token_index,
+        )
 
     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)
@@ -189,6 +221,9 @@ def _init_vision_tower(hf_config: LlavaConfig):
             vision_config,
             num_hidden_layers_override=num_hidden_layers,
         )
+    elif isinstance(vision_config, PixtralVisionConfig):
+        # TODO: allow layer override?
+        return PixtralHFVisionModel(vision_config)
 
     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)
@@ -210,6 +245,15 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         self.config = config
         self.multimodal_config = multimodal_config
 
+        # NOTE: These are special cases for Pixtral-12B in the HF-format
+        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json # noqa
+        if (config.text_config.architectures is None
+                and config.text_config.model_type == "mistral"):
+            config.text_config.architectures = ["MistralForCausalLM"]
+        if (config.projector_hidden_act is None
+                and config.vision_config.hidden_act == "gelu"):
+            config.projector_hidden_act = "gelu"
+
         # TODO: Optionally initializes this for supporting embeddings.
         self.vision_tower = _init_vision_tower(config)
         self.multi_modal_projector = LlavaMultiModalProjector(
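The defaulting guard can be exercised in isolation; a toy sketch with stand-in config objects whose field names mirror the diff and whose null values are assumed to match the linked config.json:

from types import SimpleNamespace

# Hypothetical stand-ins for the HF config objects.
config = SimpleNamespace(
    text_config=SimpleNamespace(architectures=None, model_type="mistral"),
    projector_hidden_act=None,
    vision_config=SimpleNamespace(hidden_act="gelu"),
)

if (config.text_config.architectures is None
        and config.text_config.model_type == "mistral"):
    config.text_config.architectures = ["MistralForCausalLM"]
if (config.projector_hidden_act is None
        and config.vision_config.hidden_act == "gelu"):
    config.projector_hidden_act = "gelu"

assert config.text_config.architectures == ["MistralForCausalLM"]
assert config.projector_hidden_act == "gelu"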
@@ -246,6 +290,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[LlavaImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
+        image_sizes = kwargs.pop("image_sizes", None)
         image_embeds = kwargs.pop("image_embeds", None)
 
         if pixel_values is None and image_embeds is None:
@@ -256,6 +301,26 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
+            # Case for models like PixtralHF that have dynamic image sizes
+            # so we need to produce a list of tensors
+            if image_sizes is not None:
+                images = pixel_values
+                if isinstance(images, torch.Tensor):
+                    # if passed as batch take all images
+                    NN, N, B, C, W, H = images.shape
+                    images = images.reshape(NN * N * B, C, W, H)
+                    images = [images[i] for i in range(images.size(0))]
+                elif isinstance(images, list):
+                    # if passed as list flatten lists of tensors
+                    while isinstance(images, list) and len(images) == 1:
+                        images = images[0]
+
+                # TODO: Add validation based on image_sizes
+                return LlavaImagePixelInputs(
+                    type="pixel_values",
+                    data=images,
+                )
+
             return LlavaImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(
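A standalone illustration of the flattening above; the 6-D batch shape and image sizes are hypothetical:

import torch

# Batched case: a (NN, N, B, C, W, H) tensor becomes a flat list of 3-D images.
batched = torch.zeros(1, 1, 2, 3, 336, 336)
NN, N, B, C, W, H = batched.shape
flat = batched.reshape(NN * N * B, C, W, H)
images = [flat[i] for i in range(flat.size(0))]
assert len(images) == 2 and images[0].shape == (3, 336, 336)

# List case: singleton nesting is peeled off until the payload remains.
nested = [[torch.zeros(3, 512, 768)]]
while isinstance(nested, list) and len(nested) == 1:
    nested = nested[0]
assert isinstance(nested, torch.Tensor)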
@@ -286,7 +351,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
 
     def _image_pixels_to_features(
         self,
-        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
+        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
+                            PixtralHFVisionModel],
         pixel_values: torch.Tensor,
     ) -> torch.Tensor:
 
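With the hunks above in place, the HF-format checkpoint is expected to load through the regular LLaVA path. A minimal, unverified usage sketch: the repo id comes from the NOTE in the diff, while the prompt's image placeholder token is an assumption about this checkpoint's tokenizer, not something this diff specifies:

from PIL import Image
from vllm import LLM

# Assumed repo id (see the NOTE in the diff); not part of this change.
llm = LLM(model="mistral-community/pixtral-12b")

image = Image.open("demo.jpg")
outputs = llm.generate({
    # "[IMG]" as the image placeholder is an assumption for this checkpoint.
    "prompt": "<s>[INST][IMG]\nDescribe this image.[/INST]",
    "multi_modal_data": {"image": image},
})
print(outputs[0].outputs[0].text)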