[Model] support MiniMax-VL-01 model (#16328)

Signed-off-by: qingjun <qingjun@minimaxi.com>
This commit is contained in:
qscqesze
2025-04-29 12:05:50 +08:00
committed by GitHub
parent 96e06e3cb7
commit cde384cd92
11 changed files with 954 additions and 19 deletions


@@ -0,0 +1,615 @@
# SPDX-License-Identifier: Apache-2.0
from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
                    Union, cast)
import numpy as np
import torch
import torch.nn as nn
from transformers import BatchFeature, CLIPVisionConfig, PretrainedConfig
from transformers.image_processing_utils import select_best_resolution
from vllm.config import VllmConfig
from vllm.jsontree import json_map_leaves
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.minimax_vl_01 import MiniMaxVL01Config
from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .pixtral import PixtralHFVisionModel
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .vision import get_vision_encoder_info
logger = init_logger(__name__)
# For dummy input only
@dataclass
class MaxImageTokenMeta:
width: int = 1024
height: int = 1024
class MiniMaxVL01ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
pixel_values: torch.Tensor
"""
Shape: `(batch_size * num_images, num_channels, height, width)`
Note that `height` or `width` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
"""
class MiniMaxVL01ImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
    `hidden_size` must match the hidden size of the language model backbone.
"""
def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
if not isinstance(grid_pinpoints, list):
raise TypeError("grid_pinpoints should be a list of tuples or lists")
    # NOTE: if image_size is a tensor, it must be converted to a tuple first;
    # otherwise select_best_resolution computes the wrong value
if not isinstance(image_size, (list, tuple)):
if not isinstance(image_size, (torch.Tensor, np.ndarray)):
raise TypeError("image_size invalid type " +
f"{type(image_size)} with value {image_size}")
image_size = image_size.tolist()
best_resolution = select_best_resolution(image_size, grid_pinpoints)
height, width = best_resolution
num_patches = 0
    # NOTE: this loop plus the base patch below is equivalent to
    # ceil(height / patch_size) * ceil(width / patch_size) + 1
for i in range(0, height, patch_size):
for j in range(0, width, patch_size):
num_patches += 1
# add the base patch
num_patches += 1
return num_patches
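# A worked example of the patch count above (hypothetical values, assuming
# select_best_resolution picks (672, 336) from grid_pinpoints for this image):
#   patch_size = 336
#   range(0, 672, 336) -> i in {0, 336} (2 steps)
#   range(0, 336, 336) -> j in {0}      (1 step)
#   num_patches = 2 * 1 + 1 (base patch) = 3
# which matches the closed form ceil(672 / 336) * ceil(336 / 336) + 1.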
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
if not isinstance(grid_pinpoints, list):
raise TypeError("grid_pinpoints should be a list of tuples or lists")
    # NOTE: if image_size is a tensor, it must be converted to a tuple first;
    # otherwise select_best_resolution computes the wrong value
if not isinstance(image_size, (list, tuple)):
if not isinstance(image_size, (torch.Tensor, np.ndarray)):
            raise TypeError(
                f"image_size has invalid type {type(image_size)}; "
                "should be a list, tuple, np.ndarray or torch.Tensor")
image_size = image_size.tolist()
height, width = select_best_resolution(image_size, grid_pinpoints)
return height // patch_size, width // patch_size
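# Continuing the example above: with a best resolution of (672, 336) and
# patch_size = 336, the anyres grid shape is (672 // 336, 336 // 336) = (2, 1),
# i.e. two patch rows along the height and one patch column along the width.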
def unpad_image(tensor, original_size):
original_height, original_width = original_size
current_height, current_width = tensor.shape[1:]
original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
new_height = int(original_height * current_width) // original_width
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding:current_height - padding, :]
else:
new_width = int(original_width * current_height) // original_height
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding:current_width - padding]
return unpadded_tensor
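# A worked example of unpad_image (hypothetical shapes): for a feature map of
# shape (C, 24, 24) produced from a padded square input, with an original image
# of original_size = (600, 800) as (height, width):
#   original_aspect_ratio = 800 / 600 ~= 1.33 > current_aspect_ratio = 1.0
#   new_height = int(600 * 24) // 800 = 18
#   padding = (24 - 18) // 2 = 3
# so rows 3:21 are kept and the result has shape (C, 18, 24).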
class MiniMaxVL01MultiModalProjector(nn.Module):
def __init__(self,
vision_hidden_size: int,
text_hidden_size: int,
projector_hidden_act: str,
multimodal_projector_bias: bool,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.linear_1 = ColumnParallelLinear(vision_hidden_size,
text_hidden_size,
bias=multimodal_projector_bias,
quant_config=quant_config,
prefix=f"{prefix}.linear_1")
self.act = get_act_fn(projector_hidden_act)
self.linear_2 = RowParallelLinear(text_hidden_size,
text_hidden_size,
bias=multimodal_projector_bias,
quant_config=quant_config,
prefix=f"{prefix}.linear_2")
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.linear_2(hidden_states)
return hidden_states
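# Shape sketch for the projector (hypothetical sizes, e.g.
# vision_hidden_size = 1024 and text_hidden_size = 4096): image features of
# shape (num_image_tokens, 1024) pass through linear_1 to
# (num_image_tokens, 4096), then the activation, then linear_2 keeps
# (num_image_tokens, 4096). Note that the parallel linear layers return an
# (output, bias) tuple, hence the `hidden_states, _ = ...` unpacking above.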
class MiniMaxVL01LikeConfig(Protocol):
vision_config: Final[PretrainedConfig]
image_token_index: Final[int]
vision_feature_select_strategy: Final[str]
vision_feature_layer: Final[Union[int, list[int]]]
class MiniMaxVL01LikeProcessor(Protocol):
image_token: Final[str]
_I = TypeVar("_I", bound=BaseProcessingInfo)
class MiniMaxVL01DummyInputsBuilder(BaseDummyInputsBuilder[_I]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token = processor.image_token
return image_token * num_images
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
return {
"image":
self._get_dummy_images(width=MaxImageTokenMeta.width,
height=MaxImageTokenMeta.height,
num_images=num_images)
}
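# For example, with mm_counts = {"image": 2}, get_dummy_text() returns the
# processor's image token repeated twice (e.g. "<image><image>", assuming the
# token string is "<image>"), and get_dummy_mm_data() pairs it with two
# 1024x1024 dummy images per MaxImageTokenMeta.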
class MiniMaxVL01ProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(MiniMaxVL01Config)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_vision_encoder_info(self):
return get_vision_encoder_info(self.get_hf_config())
def _apply_feature_select_strategy(
self,
strategy: str,
encoder_num_image_tokens: int,
) -> int:
if strategy == "default":
return encoder_num_image_tokens - 1
if strategy == "full":
return encoder_num_image_tokens
msg = f"Unexpected feature select strategy: {strategy!r}"
raise NotImplementedError(msg)
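    # For example, a CLIP ViT-L/14 @ 336px encoder outputs
    # 24 * 24 = 576 patch tokens plus 1 CLS token = 577: the "default"
    # strategy drops the CLS token and yields 576 image tokens, while
    # "full" keeps all 577.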
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self.get_hf_config()
vision_encoder_info = self.get_vision_encoder_info()
return self._apply_feature_select_strategy(
hf_config.vision_feature_select_strategy,
vision_encoder_info.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
),
)
def get_image_size_with_most_features(self) -> ImageSize:
vision_encoder_info = self.get_vision_encoder_info()
width = height = vision_encoder_info.get_image_size()
return ImageSize(width=width, height=height)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
)
class BaseMiniMaxVL01MultiModalProcessor(BaseMultiModalProcessor[_I]):
# Copied from BaseMultiModalProcessor
@abstractmethod
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
raise NotImplementedError
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
def get_replacement(item_idx: int):
images = mm_items.get_items(
"image", (ImageEmbeddingItems, ImageProcessorItems))
if isinstance(images, ImageEmbeddingItems):
num_image_tokens = images.get_feature_size(item_idx)
else:
image_size = images.get_image_size(item_idx)
num_image_tokens = self.info.get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
)
return [image_token_id] * num_image_tokens
return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=get_replacement,
),
]
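# Illustration of the replacement above (hypothetical token values): if
# image_token_index = 32000 and an image resolves to 576 feature tokens, the
# single placeholder [32000] in the tokenized prompt is expanded in place to
# [32000] * 576, so the language model sees one position per image embedding.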
class MiniMaxVL01MultiModalProcessor(
BaseMiniMaxVL01MultiModalProcessor[MiniMaxVL01ProcessingInfo]):
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
pixel_values = processed_outputs.get("pixel_values")
if pixel_values is not None:
image_sizes = processed_outputs["image_sizes"]
min_len = min(len(pixel_values), len(image_sizes))
pixel_values = pixel_values[:min_len]
image_sizes = image_sizes[:min_len]
assert len(pixel_values) == len(image_sizes)
processed_outputs["pixel_values"] = [
p[:, :h, :w] for p, (h, w) in zip(pixel_values, image_sizes)
]
return processed_outputs
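    # For example (hypothetical sizes): if the HF processor returns a padded
    # pixel tensor p of shape (3, 1024, 1024) alongside an image_sizes entry
    # (672, 336), the tensor is cropped back to p[:, :672, :336] so that
    # downstream code sees only the real image area.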
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return {
"pixel_values": MultiModalFieldConfig.batched("image"),
"image_embeds": MultiModalFieldConfig.batched("image"),
}
def _get_num_hidden_layers(hf_config: MiniMaxVL01LikeConfig) -> int:
"""Determine the number of hidden layers to initialize up to in the
visual encoder.
Args:
hf_config: Model config with vision feature layer(s).
"""
feature_layers = hf_config.vision_feature_layer
num_hidden_layers = hf_config.vision_config.num_hidden_layers
# If we have one feature layer, initialize up to that layer
if isinstance(feature_layers, int):
return _get_layer_index(feature_layers, num_hidden_layers)
# If we have multiple feature layers, initialize up to the deepest one
elif isinstance(feature_layers, (list, tuple)):
return max(
_get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
raise TypeError(f"vision_layer_feature type: {type(feature_layers)}"
" is not supported")
def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
"""Given a signed vision feature layer, get the number of hidden layers
needed to leverage it.
Args:
feature_layer_index: Index of a required layer in the visual encoder.
num_hidden_layers: The total number of hidden layers in the visual
encoder.
"""
if feature_layer_index < 0:
return num_hidden_layers + feature_layer_index + 1
return feature_layer_index
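# A worked example of the index math above: with num_hidden_layers = 24,
# feature_layer_index = -2 gives 24 + (-2) + 1 = 23, i.e. the vision tower
# only needs its first 23 layers to expose the second-to-last layer's output;
# a non-negative index is returned unchanged.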
def init_vision_tower_for_MiniMaxVL01(
hf_config: MiniMaxVL01LikeConfig,
quant_config: Optional[QuantizationConfig],
*,
require_post_norm: Optional[bool] = None,
prefix: str = "",
) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
vision_config = hf_config.vision_config
# Initialize the vision tower only up to the deepest required feature layer
num_hidden_layers = _get_num_hidden_layers(hf_config)
if isinstance(vision_config, CLIPVisionConfig):
return CLIPVisionModel(
vision_config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
prefix=prefix,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
@MULTIMODAL_REGISTRY.register_processor(
MiniMaxVL01MultiModalProcessor,
info=MiniMaxVL01ProcessingInfo,
dummy_inputs=MiniMaxVL01DummyInputsBuilder)
class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
        # TODO: Optionally initialize this to support embeddings.
self.vision_tower = init_vision_tower_for_MiniMaxVL01(
config,
quant_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"))
self.multi_modal_projector = MiniMaxVL01MultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act,
multimodal_projector_bias=True,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "multi_modal_projector"))
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size))
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.vision_feature_layer = config.vision_feature_layer
self.vocab_size = config.text_config.vocab_size
self.pad_token_id = -1
if self.config.pad_token_id is not None:
self.pad_token_id = self.config.pad_token_id
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.image_token_index,
)
return inputs_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def _image_pixels_to_features(
self,
        vision_tower: CLIPVisionModel,
pixel_values: Union[torch.Tensor, list[torch.Tensor]],
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values)
def select_features(leaf: torch.Tensor):
return self._select_image_features(
leaf,
strategy=self.config.vision_feature_select_strategy,
)
return cast(
Union[torch.Tensor, tuple[torch.Tensor, ...]],
json_map_leaves(select_features, image_features),
)
def _process_image_pixels(
self,
        inputs: MiniMaxVL01ImagePixelInputs,
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
assert self.vision_tower is not None
pixel_values = inputs["pixel_values"]
return self._image_pixels_to_features(self.vision_tower, pixel_values)
    def _process_image_input(
        self,
        image_input: Union[MiniMaxVL01ImagePixelInputs,
                           MiniMaxVL01ImageEmbeddingInputs],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
if isinstance(image_features, torch.Tensor):
return self.multi_modal_projector(image_features)
feature_sizes = [
image_feature.shape[0] for image_feature in image_features
]
image_embeds = self.multi_modal_projector(torch.cat(image_features))
image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
actual_dims = tuple(data.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("batch_size", *map(str, expected_dims))
raise ValueError(
f"The expected shape of pixel values is {expected_expr}. "
f"You supplied {tuple(data.shape)}.")
return data
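    # For example, with vision_config.image_size = 336 every image must arrive
    # as a (3, 336, 336) tensor: a batch of shape (4, 3, 336, 336) passes,
    # while (4, 3, 224, 224) raises the ValueError above.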
    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Optional[Union[MiniMaxVL01ImagePixelInputs,
                        MiniMaxVL01ImageEmbeddingInputs]]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None and image_embeds is None:
return None
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return MiniMaxVL01ImagePixelInputs(
type="pixel_values",
pixel_values=self._validate_pixel_values(
flatten_bn(pixel_values, concat=True)),
)
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return MiniMaxVL01ImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds, concat=True),
)
raise AssertionError("This line should be unreachable.")
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
return self._process_image_input(image_input)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
if intermediate_tensors is not None:
inputs_embeds = None
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None
hidden_states = self.language_model.model(input_ids,
positions,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
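# A minimal offline-inference sketch for this model (hypothetical usage; the
# checkpoint name and "<image>" placeholder below are assumptions, not taken
# from this commit):
#
#     from vllm import LLM, SamplingParams
#     from PIL import Image
#
#     llm = LLM(model="MiniMaxAI/MiniMax-VL-01", trust_remote_code=True)
#     image = Image.open("example.jpg")
#     outputs = llm.generate(
#         {
#             "prompt": "<image>\nDescribe this image.",
#             "multi_modal_data": {"image": image},
#         },
#         SamplingParams(max_tokens=64),
#     )
#     print(outputs[0].outputs[0].text)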