diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index 5433805b6..95e7d5d60 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -979,6 +979,13 @@ See [this page](#generative-models) for more information on how to use generativ
* ✅︎
* ✅︎
* ✅︎
+- * `MiniMaxVL01ForConditionalGeneration`
+ * MiniMax-VL
+ * T + I<sup>E+</sup>
+ * `MiniMaxAI/MiniMax-VL-01`, etc.
+ *
+ * ✅︎
+ * ✅︎
- * `Mistral3ForConditionalGeneration`
* Mistral3
* T + I<sup>+</sup>
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index b3c56e18b..4dc49d18c 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -270,6 +270,7 @@ def _test_processing_correctness_mistral(
"openbmb/MiniCPM-Llama3-V-2_5",
"openbmb/MiniCPM-o-2_6",
"openbmb/MiniCPM-V-2_6",
+ "MiniMaxAI/MiniMax-VL-01",
"allenai/Molmo-7B-D-0924",
"allenai/Molmo-7B-O-0924",
"nvidia/NVLM-D-72B",
diff --git a/tests/models/multimodal/processing/test_minimax_vl_01.py b/tests/models/multimodal/processing/test_minimax_vl_01.py
index d333c32dc..10de28ab5 100644
--- a/tests/models/multimodal/processing/test_minimax_vl_01.py
+++ b/tests/models/multimodal/processing/test_minimax_vl_01.py
@@ -12,7 +12,6 @@ from ...utils import build_model_context
@pytest.mark.parametrize("model_id", ["MiniMaxAI/MiniMax-VL-01"])
-# yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override(
image_assets: _ImageAssets,
diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py
index 14e105586..4ac60f97b 100644
--- a/vllm/model_executor/models/minimax_vl_01.py
+++ b/vllm/model_executor/models/minimax_vl_01.py
@@ -1,52 +1,32 @@
# SPDX-License-Identifier: Apache-2.0
+from collections.abc import Iterable, Mapping
+from typing import Literal, Optional, Set, Tuple, TypedDict, Union, cast
-from abc import abstractmethod
-from collections.abc import Iterable, Mapping, Sequence
-from dataclasses import dataclass
-from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict,
- TypeVar, Union, cast)
-
-import numpy as np
import torch
import torch.nn as nn
-from transformers import BatchFeature, CLIPVisionConfig, PretrainedConfig
-from transformers.image_processing_utils import select_best_resolution
+from transformers import BatchFeature
from vllm.config import VllmConfig
from vllm.jsontree import json_map_leaves
-from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
-from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
-from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
- ImageSize, MultiModalDataItems)
-from vllm.multimodal.processing import (BaseMultiModalProcessor,
- BaseProcessingInfo, PromptReplacement,
- PromptUpdate)
-from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.minimax_vl_01 import MiniMaxVL01Config
from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .llava import (BaseLlavaMultiModalProcessor, LlavaDummyInputsBuilder,
+ init_vision_tower_for_llava)
+from .llava_next import LlavaNextProcessingInfo
from .pixtral import PixtralHFVisionModel
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
-from .vision import get_vision_encoder_info
-
-logger = init_logger(__name__)
-
-
-# For dummy input only
-@dataclass
-class MaxImageTokenMeta:
- width: int = 1024
- height: int = 1024
class MiniMaxVL01ImagePixelInputs(TypedDict):
@@ -69,66 +49,8 @@ class MiniMaxVL01ImageEmbeddingInputs(TypedDict):
"""
-def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
- if not isinstance(grid_pinpoints, list):
- raise TypeError("grid_pinpoints should be a list of tuples or lists")
-
- # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple,
- # otherwise it will cause wrong calculate
- if not isinstance(image_size, (list, tuple)):
- if not isinstance(image_size, (torch.Tensor, np.ndarray)):
- raise TypeError("image_size invalid type " +
- f"{type(image_size)} with value {image_size}")
- image_size = image_size.tolist()
-
- best_resolution = select_best_resolution(image_size, grid_pinpoints)
- height, width = best_resolution
- num_patches = 0
- # consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1
- for i in range(0, height, patch_size):
- for j in range(0, width, patch_size):
- num_patches += 1
- # add the base patch
- num_patches += 1
- return num_patches
-
-
-def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
- if not isinstance(grid_pinpoints, list):
- raise TypeError("grid_pinpoints should be a list of tuples or lists")
-
- # ! VERY IMPORTANT if image_size is tensor,
- # must convert to into tuple,
- # otherwise it will cause wrong calculate
- if not isinstance(image_size, (list, tuple)):
- if not isinstance(image_size, (torch.Tensor, np.ndarray)):
- raise TypeError(
- "image_size invalid type " +
- f"{type(image_size)} not valid, " +
- "should be either list, tuple, np.ndarray or tensor")
- image_size = image_size.tolist()
-
- height, width = select_best_resolution(image_size, grid_pinpoints)
- return height // patch_size, width // patch_size
-
-
-def unpad_image(tensor, original_size):
- original_height, original_width = original_size
- current_height, current_width = tensor.shape[1:]
-
- original_aspect_ratio = original_width / original_height
- current_aspect_ratio = current_width / current_height
-
- if original_aspect_ratio > current_aspect_ratio:
- new_height = int(original_height * current_width) // original_width
- padding = (current_height - new_height) // 2
- unpadded_tensor = tensor[:, padding:current_height - padding, :]
- else:
- new_width = int(original_width * current_height) // original_height
- padding = (current_width - new_width) // 2
- unpadded_tensor = tensor[:, :, padding:current_width - padding]
-
- return unpadded_tensor
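+# Image inputs arrive either as raw pixel values or as precomputed image embeddings.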
+MiniMaxVL01ImageInputs = Union[MiniMaxVL01ImagePixelInputs,
+ MiniMaxVL01ImageEmbeddingInputs]
class MiniMaxVL01MultiModalProjector(nn.Module):
@@ -161,144 +83,29 @@ class MiniMaxVL01MultiModalProjector(nn.Module):
return hidden_states
-class MiniMaxVL01LikeConfig(Protocol):
- vision_config: Final[PretrainedConfig]
- image_token_index: Final[int]
- vision_feature_select_strategy: Final[str]
- vision_feature_layer: Final[Union[int, list[int]]]
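+# Dummy/profiling inputs are identical to LLaVA's, so the LLaVA builder is reused as-is.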
+class MiniMaxVL01DummyInputsBuilder(LlavaDummyInputsBuilder):
+ pass
-class MiniMaxVL01LikeProcessor(Protocol):
- image_token: Final[str]
-
-
-_I = TypeVar("_I", bound=BaseProcessingInfo)
-
-
-class MiniMaxVL01DummyInputsBuilder(BaseDummyInputsBuilder[_I]):
-
- def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
- num_images = mm_counts.get("image", 0)
- processor = self.info.get_hf_processor()
- image_token = processor.image_token
- return image_token * num_images
-
- def get_dummy_mm_data(
- self,
- seq_len: int,
- mm_counts: Mapping[str, int],
- ) -> MultiModalDataDict:
- num_images = mm_counts.get("image", 0)
-
- return {
- "image":
- self._get_dummy_images(width=MaxImageTokenMeta.width,
- height=MaxImageTokenMeta.height,
- num_images=num_images)
- }
-
-
-class MiniMaxVL01ProcessingInfo(BaseProcessingInfo):
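+# Image token counting and sizing follow the LLaVA-NeXT (anyres) implementation.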
+class MiniMaxVL01ProcessingInfo(LlavaNextProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(MiniMaxVL01Config)
+ def get_hf_processor(self, **kwargs: object):
+ hf_processor = self.ctx.get_hf_processor(**kwargs)
+ image_processor = hf_processor.image_processor
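+ # Route anyres preprocessing through the vLLM-specific variant exposed by the
+ # HF image processor; see the per-image padding note in _call_hf_processor below.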
+ image_processor.anyres_preprocess = (
+ image_processor.anyres_for_vllm_preprocess)
+
+ return hf_processor
+
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
- def get_vision_encoder_info(self):
- return get_vision_encoder_info(self.get_hf_config())
-
- def _apply_feature_select_strategy(
- self,
- strategy: str,
- encoder_num_image_tokens: int,
- ) -> int:
- if strategy == "default":
- return encoder_num_image_tokens - 1
- if strategy == "full":
- return encoder_num_image_tokens
-
- msg = f"Unexpected feature select strategy: {strategy!r}"
- raise NotImplementedError(msg)
-
- def get_num_image_tokens(
- self,
- *,
- image_width: int,
- image_height: int,
- ) -> int:
- hf_config = self.get_hf_config()
- vision_encoder_info = self.get_vision_encoder_info()
-
- return self._apply_feature_select_strategy(
- hf_config.vision_feature_select_strategy,
- vision_encoder_info.get_num_image_tokens(
- image_width=image_width,
- image_height=image_height,
- ),
- )
-
- def get_image_size_with_most_features(self) -> ImageSize:
- vision_encoder_info = self.get_vision_encoder_info()
- width = height = vision_encoder_info.get_image_size()
- return ImageSize(width=width, height=height)
-
- def get_max_image_tokens(self) -> int:
- target_width, target_height = self.get_image_size_with_most_features()
-
- return self.get_num_image_tokens(
- image_width=target_width,
- image_height=target_height,
- )
-
-
-class BaseMiniMaxVL01MultiModalProcessor(BaseMultiModalProcessor[_I]):
-
- # Copied from BaseMultiModalProcessor
- @abstractmethod
- def _get_mm_fields_config(
- self,
- hf_inputs: BatchFeature,
- hf_processor_mm_kwargs: Mapping[str, object],
- ) -> Mapping[str, MultiModalFieldConfig]:
- raise NotImplementedError
-
- def _get_prompt_updates(
- self,
- mm_items: MultiModalDataItems,
- hf_processor_mm_kwargs: Mapping[str, object],
- out_mm_kwargs: MultiModalKwargs,
- ) -> Sequence[PromptUpdate]:
- hf_config = self.info.get_hf_config()
- image_token_id = hf_config.image_token_index
-
- def get_replacement(item_idx: int):
- images = mm_items.get_items(
- "image", (ImageEmbeddingItems, ImageProcessorItems))
-
- if isinstance(images, ImageEmbeddingItems):
- num_image_tokens = images.get_feature_size(item_idx)
- else:
- image_size = images.get_image_size(item_idx)
- num_image_tokens = self.info.get_num_image_tokens(
- image_width=image_size.width,
- image_height=image_size.height,
- )
-
- return [image_token_id] * num_image_tokens
-
- return [
- PromptReplacement(
- modality="image",
- target=[image_token_id],
- replacement=get_replacement,
- ),
- ]
-
class MiniMaxVL01MultiModalProcessor(
- BaseMiniMaxVL01MultiModalProcessor[MiniMaxVL01ProcessingInfo]):
+ BaseLlavaMultiModalProcessor[MiniMaxVL01ProcessingInfo]):
def _call_hf_processor(
self,
@@ -314,10 +121,9 @@ class MiniMaxVL01MultiModalProcessor(
pixel_values = processed_outputs.get("pixel_values")
if pixel_values is not None:
+ # Avoid padding since we need the output for each image to be
+ # independent of other images for the cache to work correctly
image_sizes = processed_outputs["image_sizes"]
- min_len = min(len(pixel_values), len(image_sizes))
- pixel_values = pixel_values[:min_len]
- image_sizes = image_sizes[:min_len]
assert len(pixel_values) == len(image_sizes)
processed_outputs["pixel_values"] = [
@@ -337,65 +143,6 @@ class MiniMaxVL01MultiModalProcessor(
}
-def _get_num_hidden_layers(hf_config: MiniMaxVL01LikeConfig) -> int:
- """Determine the number of hidden layers to initialize up to in the
- visual encoder.
-
- Args:
- hf_config: Model config with vision feature layer(s).
- """
- feature_layers = hf_config.vision_feature_layer
- num_hidden_layers = hf_config.vision_config.num_hidden_layers
- # If we have one feature layer, initialize up to that layer
- if isinstance(feature_layers, int):
- return _get_layer_index(feature_layers, num_hidden_layers)
- # If we have multiple feature layers, initialize up to the deepest one
- elif isinstance(feature_layers, (list, tuple)):
- return max(
- _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
- raise TypeError(f"vision_layer_feature type: {type(feature_layers)}"
- " is not supported")
-
-
-def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
- """Given a signed vision feature layer, get the number of hidden layers
- needed to leverage it.
-
- Args:
- feature_layer_index: Index of a required layer in the visual encoder.
- num_hidden_layers: The total number of hidden layers in the visual
- encoder.
- """
- if feature_layer_index < 0:
- return num_hidden_layers + feature_layer_index + 1
- return feature_layer_index
-
-
-def init_vision_tower_for_MiniMaxVL01(
- hf_config: MiniMaxVL01LikeConfig,
- quant_config: Optional[QuantizationConfig],
- *,
- require_post_norm: Optional[bool] = None,
- prefix: str = "",
-) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
- vision_config = hf_config.vision_config
-
- # Initialize the vision tower only up to the deepest required feature layer
- num_hidden_layers = _get_num_hidden_layers(hf_config)
-
- if isinstance(vision_config, CLIPVisionConfig):
- return CLIPVisionModel(
- vision_config,
- quant_config=quant_config,
- num_hidden_layers_override=num_hidden_layers,
- require_post_norm=require_post_norm,
- prefix=prefix,
- )
-
- msg = f"Unsupported vision config: {type(vision_config)}"
- raise NotImplementedError(msg)
-
-
@MULTIMODAL_REGISTRY.register_processor(
MiniMaxVL01MultiModalProcessor,
info=MiniMaxVL01ProcessingInfo,
@@ -419,7 +166,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
self.multimodal_config = multimodal_config
# TODO: Optionally initializes this for supporting embeddings.
- self.vision_tower = init_vision_tower_for_MiniMaxVL01(
+ self.vision_tower = init_vision_tower_for_llava(
config,
quant_config,
require_post_norm=False,
@@ -476,7 +223,8 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
def _image_pixels_to_features(
self,
- vision_tower: Union[CLIPVisionModel],
+ vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
+ PixtralHFVisionModel],
pixel_values: Union[torch.Tensor, list[torch.Tensor]],
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
# NOTE: we skip the step to select the vision feature layer since
@@ -496,7 +244,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
def _process_image_pixels(
self,
- inputs: Union[MiniMaxVL01ImagePixelInputs],
+ inputs: MiniMaxVL01ImagePixelInputs,
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
assert self.vision_tower is not None
@@ -506,7 +254,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
def _process_image_input(
self,
- image_input: MiniMaxVL01ImagePixelInputs,
+ image_input: MiniMaxVL01ImageInputs,
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
if image_input["type"] == "image_embeds":
return image_input["data"]
@@ -539,7 +287,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
return data
def _parse_and_validate_image_input(
- self, **kwargs: object) -> Optional[MiniMaxVL01ImagePixelInputs]:
+ self, **kwargs: object) -> Optional[MiniMaxVL01ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)