[VLM] Merged multi-modal processor for LLaVA-NeXT (#11682)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

Author: Cyrus Leung
Date: 2025-01-03 00:39:27 +08:00
Committed by: GitHub
Parent: b6087a6bee
Commit: 8c38ee7007

14 changed files with 609 additions and 555 deletions
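Before this change, a single LlavaMultiModalProcessor served both LlavaProcessor- and PixtralProcessor-backed checkpoints through isinstance checks. The hunks below (from llava.py) split that into a shared BaseLlavaMultiModalProcessor plus per-processor subclasses, and register a factory that picks the subclass from the model's vision config. A minimal, self-contained sketch of that dispatch pattern, using hypothetical stand-in classes rather than the real vLLM types:

    # Sketch of the config-based factory dispatch this commit adopts.
    # VisionCfg, BaseProc, etc. are illustrative stand-ins, not vLLM APIs.
    from dataclasses import dataclass


    @dataclass
    class VisionCfg:
        kind: str  # e.g. "clip", "siglip", or "pixtral"


    class BaseProc:

        def __init__(self, cfg: VisionCfg) -> None:
            self.cfg = cfg


    class LlavaProc(BaseProc):
        pass


    class PixtralHFProc(BaseProc):
        pass


    def build_proc(cfg: VisionCfg) -> BaseProc:
        # Mirrors _build_llava_or_pixtral_hf_processor below: the registry
        # stores this callable, and the concrete class is chosen per model.
        if cfg.kind == "pixtral":
            return PixtralHFProc(cfg)
        return LlavaProc(cfg)


    assert isinstance(build_proc(VisionCfg("pixtral")), PixtralHFProc)
    assert isinstance(build_proc(VisionCfg("clip")), LlavaProc)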


@@ -1,6 +1,7 @@
+from abc import abstractmethod
 from functools import cached_property
-from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
-                    Tuple, TypedDict, Union)
+from typing import (Final, Iterable, List, Literal, Mapping, Optional,
+                    Protocol, Set, Tuple, TypedDict, Union)
 
 import torch
 import torch.nn as nn
@@ -12,7 +13,6 @@ from transformers.models.pixtral import PixtralProcessor
 
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.inputs import InputContext
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
@@ -23,23 +23,23 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalInputsV2, MultiModalKwargs,
                                     NestedTensors)
-from vllm.multimodal.parse import ImageProcessorItems
+from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
+                                   ImageSize)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessorInputs,
-                                        PromptReplacement,
+                                        InputProcessingContext,
+                                        MultiModalDataItems, ProcessingCache,
+                                        ProcessorInputs, PromptReplacement,
                                         full_groupby_modality)
 from vllm.sequence import IntermediateTensors
 
-from .clip import (CLIPVisionModel, dummy_image_for_clip,
-                   get_max_clip_image_tokens)
+from .clip import CLIPVisionModel
 from .interfaces import SupportsMultiModal, SupportsPP
-from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
-                      get_max_pixtral_hf_image_tokens,
-                      get_pixtral_hf_image_feature_size)
-from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
-                     get_max_siglip_image_tokens)
+from .pixtral import (PixtralHFVisionModel,
+                      get_pixtral_hf_image_feature_grid_size)
+from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
+from .vision import vision_encoder_info
 
 
 class LlavaImagePixelInputs(TypedDict):
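The import block above captures the main cleanup: the per-encoder helpers (get_max_clip_image_tokens, dummy_image_for_siglip, get_max_pixtral_hf_image_tokens, and friends) disappear in favor of a single vision_encoder_info(vision_config) entry point. Judging from how the diff uses the returned object, its interface looks roughly like this (a sketch inferred from the call sites, not the exact definition in the .vision module):

    # Interface inferred from the three call sites in this diff; the real
    # definition lives in the .vision module imported above.
    from typing import Protocol


    class VisionEncoderInfoLike(Protocol):

        def get_image_size(self) -> int:
            """Edge length of the square input resolution."""

        def get_max_image_tokens(self) -> int:
            """Upper bound on encoder tokens for a single image."""

        def get_num_image_tokens(self, *, image_width: int,
                                 image_height: int) -> int:
            """Encoder tokens for an image of the given size."""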
@@ -94,39 +94,167 @@ class LlavaMultiModalProjector(nn.Module):
         return hidden_states
 
 
-def get_max_llava_image_tokens(ctx: InputContext):
-    hf_config = ctx.get_hf_config(LlavaConfig)
-    vision_config = hf_config.vision_config
-
-    if isinstance(vision_config, CLIPVisionConfig):
-        num_image_tokens = get_max_clip_image_tokens(vision_config)
-    elif isinstance(vision_config, SiglipVisionConfig):
-        num_image_tokens = get_max_siglip_image_tokens(vision_config)
-    elif isinstance(vision_config, PixtralVisionConfig):
-        num_image_tokens = get_max_pixtral_hf_image_tokens(vision_config)
-    else:
-        msg = f"Unsupported vision config: {type(vision_config)}"
-        raise NotImplementedError(msg)
-
-    strategy = hf_config.vision_feature_select_strategy
-    if strategy == "default":
-        return num_image_tokens - 1
-    elif strategy == "full":
-        return num_image_tokens
-    else:
-        raise ValueError(f"Unexpected select feature strategy: {strategy}")
+class LlavaLikeConfig(Protocol):
+    vision_config: Final[PretrainedConfig]
+    vision_feature_select_strategy: Final[str]
+    vision_feature_layer: Final[Union[int, List[int]]]
 
 
-class LlavaMultiModalProcessor(BaseMultiModalProcessor):
+class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor):
+
+    def __init__(self,
+                 ctx: InputProcessingContext,
+                 *,
+                 cache: Optional[ProcessingCache] = None,
+                 enable_sanity_checks: bool = True) -> None:
+        super().__init__(ctx,
+                         cache=cache,
+                         enable_sanity_checks=enable_sanity_checks)
+
+        vision_config = self._get_hf_config().vision_config
+        self._vision_encoder_info = vision_encoder_info(vision_config)
+
+    @abstractmethod
+    def _get_hf_config(self) -> LlavaLikeConfig:
+        raise NotImplementedError
 
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None}
 
-    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
-        return {"image": get_max_llava_image_tokens(self.ctx)}
+    def _apply_feature_select_strategy(
+        self,
+        strategy: str,
+        encoder_num_image_tokens: int,
+    ) -> int:
+        if strategy == "default":
+            return encoder_num_image_tokens - 1
+        if strategy == "full":
+            return encoder_num_image_tokens
 
-    def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
-        return self.ctx.get_hf_processor((LlavaProcessor, PixtralProcessor))
+        msg = f"Unexpected feature select strategy: {strategy!r}"
+        raise NotImplementedError(msg)
+
+    def _get_max_image_tokens(self) -> int:
+        hf_config = self._get_hf_config()
+
+        return self._apply_feature_select_strategy(
+            hf_config.vision_feature_select_strategy,
+            self._vision_encoder_info.get_max_image_tokens(),
+        )
+
+    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
+        return {"image": self._get_max_image_tokens()}
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            image_embeds=MultiModalFieldConfig.batched("image"),
+        )
+
+    def _get_dummy_image_size(self) -> ImageSize:
+        image_size = self._vision_encoder_info.get_image_size()
+        return ImageSize(image_size, image_size)
+
+    @abstractmethod
+    def _get_image_token(self) -> str:
+        raise NotImplementedError
+
+    def _get_dummy_mm_inputs(
+        self,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        num_images = mm_counts.get("image", 0)
+
+        image_token = self._get_image_token()
+        target_width, target_height = self._get_dummy_image_size()
+
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=target_width,
+                                   height=target_height,
+                                   num_images=num_images)
+        }
+
+        return ProcessorInputs(
+            prompt_text=image_token * num_images,
+            mm_data=mm_data,
+        )
+
+
+class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):
+
+    def _get_hf_config(self) -> LlavaConfig:
+        return self.ctx.get_hf_config(LlavaConfig)
+
+    def _get_hf_processor(self) -> LlavaProcessor:
+        return self.ctx.get_hf_processor(LlavaProcessor)
+
+    def _get_image_token(self) -> str:
+        return self._get_hf_processor().image_token
+
+    def _get_num_image_tokens(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+    ) -> int:
+        hf_config = self._get_hf_config()
+
+        return self._apply_feature_select_strategy(
+            hf_config.vision_feature_select_strategy,
+            self._vision_encoder_info.get_num_image_tokens(
+                image_width=image_width,
+                image_height=image_height,
+            ),
+        )
+
+    def _get_prompt_replacements(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
+    ) -> list[PromptReplacement]:
+        hf_config = self._get_hf_config()
+        image_token_id = hf_config.image_token_index
+
+        def get_replacement(item_idx: int):
+            images = mm_items.get_items(
+                "image", (ImageEmbeddingItems, ImageProcessorItems))
+
+            if isinstance(images, ImageEmbeddingItems):
+                num_image_tokens = images.get_feature_size(item_idx)
+            else:
+                image_size = images.get_image_size(item_idx)
+                num_image_tokens = self._get_num_image_tokens(
+                    image_width=image_size.width,
+                    image_height=image_size.height,
+                )
+
+            return [image_token_id] * num_image_tokens
+
+        return [
+            PromptReplacement(
+                modality="image",
+                target=[image_token_id],
+                replacement=get_replacement,
+            ),
+        ]
+
+
+class PixtralHFMultiModalProcessor(BaseLlavaMultiModalProcessor):
+
+    def _get_hf_config(self) -> LlavaConfig:
+        return self.ctx.get_hf_config(LlavaConfig)
+
+    def _get_hf_processor(self) -> PixtralProcessor:
+        return self.ctx.get_hf_processor(PixtralProcessor)
+
+    def _get_image_token(self) -> str:
+        return self._get_hf_processor().image_token
 
     def _call_hf_processor(
         self,
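One detail worth making concrete in the new _apply_feature_select_strategy above: the encoder's raw count includes the CLS token, and the "default" strategy drops it. Worked numbers for the CLIP ViT-L/14 tower at 336x336 that standard LLaVA-1.5 checkpoints use:

    # CLIP ViT-L/14 @ 336px: a 24x24 patch grid plus one CLS token.
    image_size, patch_size = 336, 14
    encoder_num_image_tokens = (image_size // patch_size) ** 2 + 1  # 577

    # "default" drops CLS, "full" keeps everything (same as the diff).
    assert encoder_num_image_tokens - 1 == 576  # tokens per image in LLaVA-1.5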
@@ -140,119 +268,82 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
             mm_kwargs=mm_kwargs,
         )
 
         # NOTE: pixel_values=None for MLlavaProcessor
         pixel_values = processed_outputs.get("pixel_values")
         if pixel_values is not None:
             images = mm_data["images"]
             assert isinstance(images, list)
 
-            if isinstance(self._get_hf_processor(), PixtralProcessor):
-                # Original output: (1, num_images, C, H, W)
-                # New output: (num_images, C, H, W)
-                assert (isinstance(pixel_values, list)
-                        and len(pixel_values) == 1)
-                assert (isinstance(pixel_values[0], list)
-                        and len(pixel_values[0]) == len(images))
-
-                processed_outputs["pixel_values"] = pixel_values[0]
+            # Original output: (1, num_images, C, H, W)
+            # New output: (num_images, C, H, W)
+            assert (isinstance(pixel_values, list) and len(pixel_values) == 1)
+            assert (isinstance(pixel_values[0], list)
+                    and len(pixel_values[0]) == len(images))
+
+            processed_outputs["pixel_values"] = pixel_values[0]
 
         return processed_outputs
 
-    def _get_mm_fields_config(
-        self,
-        hf_inputs: BatchFeature,
-        hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> Mapping[str, MultiModalFieldConfig]:
-        return dict(
-            pixel_values=MultiModalFieldConfig.batched("image"),
-            image_embeds=MultiModalFieldConfig.batched("image"),
-        )
-
     def _get_prompt_replacements(
         self,
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
-        hf_config = self.ctx.get_hf_config(LlavaConfig)
+        hf_config = self._get_hf_config()
         image_token_id = hf_config.image_token_index
 
         processor = self._get_hf_processor()
-        if isinstance(processor, PixtralProcessor):
-            image_token = processor.image_token
-            image_break_token = processor.image_break_token
-            image_end_token = processor.image_end_token
+        image_token = processor.image_token
+        image_break_token = processor.image_break_token
+        image_end_token = processor.image_end_token
 
-            vision_config = hf_config.vision_config
-            assert isinstance(vision_config, PixtralVisionConfig)
+        vision_config = hf_config.vision_config
+        assert isinstance(vision_config, PixtralVisionConfig)
 
-            def get_replacement_pixtral(item_idx: int):
-                images = mm_items.get_items("image", ImageProcessorItems)
-                image_size = images.get_image_size(item_idx)
+        def get_replacement(item_idx: int):
+            images = mm_items.get_items("image", ImageProcessorItems)
+            image_size = images.get_image_size(item_idx)
 
-                (
-                    num_width_tokens,
-                    num_height_tokens,
-                ) = get_pixtral_hf_image_feature_size(
-                    vision_config,
-                    image_width=image_size.width,
-                    image_height=image_size.height,
-                )
+            ncols, nrows = get_pixtral_hf_image_feature_grid_size(
+                vision_config,
+                image_width=image_size.width,
+                image_height=image_size.height,
+            )
 
-                tokens = ([image_token] * num_width_tokens +
-                          [image_break_token]) * num_height_tokens
-                tokens[-1] = image_end_token
+            tokens = ([image_token] * ncols + [image_break_token]) * nrows
+            tokens[-1] = image_end_token
 
-                return "".join(tokens)
+            return "".join(tokens)
 
-            return [
-                PromptReplacement(
-                    modality="image",
-                    target=[image_token_id],
-                    replacement=get_replacement_pixtral,
-                ),
-            ]
-
-        max_image_tokens = get_max_llava_image_tokens(self.ctx)
-
         return [
             PromptReplacement(
                 modality="image",
                 target=[image_token_id],
-                replacement=[image_token_id] * max_image_tokens,
-            )
+                replacement=get_replacement,
+            ),
         ]
 
-    def _get_dummy_mm_inputs(
-        self,
-        mm_counts: Mapping[str, int],
-    ) -> ProcessorInputs:
-        hf_config = self.ctx.get_hf_config(LlavaConfig)
-        vision_config = hf_config.vision_config
-        num_images = mm_counts.get("image", 0)
-
-        if isinstance(vision_config, CLIPVisionConfig):
-            data = dummy_image_for_clip(vision_config, num_images)
-        elif isinstance(vision_config, SiglipVisionConfig):
-            data = dummy_image_for_siglip(vision_config, num_images)
-        elif isinstance(vision_config, PixtralVisionConfig):
-            data = dummy_image_for_pixtral_hf(vision_config, num_images)
-        else:
-            msg = f"Unsupported vision config: {type(vision_config)}"
-            raise NotImplementedError(msg)
-
-        hf_processor = self._get_hf_processor()
-        image_token = hf_processor.image_token
-
-        return ProcessorInputs(
-            prompt_text=image_token * num_images,
-            mm_data=data,
-        )
-
-
-class LlavaLikeConfig(Protocol):
-    vision_config: PretrainedConfig
-    vision_feature_layer: Union[int, List[int]]
+
+def _build_llava_or_pixtral_hf_processor(
+    ctx: InputProcessingContext,
+    *,
+    cache: Optional[ProcessingCache] = None,
+    enable_sanity_checks: bool = True,
+) -> BaseLlavaMultiModalProcessor:
+    hf_config = ctx.get_hf_config(LlavaConfig)
+
+    if isinstance(hf_config.vision_config, PixtralVisionConfig):
+        return PixtralHFMultiModalProcessor(
+            ctx,
+            cache=cache,
+            enable_sanity_checks=enable_sanity_checks,
+        )
+
+    return LlavaMultiModalProcessor(
+        ctx,
+        cache=cache,
+        enable_sanity_checks=enable_sanity_checks,
+    )
 
 
 def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
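The Pixtral replacement above builds the image region row by row: ncols image tokens, then a break token after each row, with the final break swapped for the end token. A tiny worked example of that layout (grid size chosen for readability; the token strings are Pixtral's special tokens):

    # 3x2 patch grid, laid out row-major with break/end markers.
    image_token, image_break_token, image_end_token = (
        "[IMG]", "[IMG_BREAK]", "[IMG_END]")
    ncols, nrows = 3, 2

    tokens = ([image_token] * ncols + [image_break_token]) * nrows
    tokens[-1] = image_end_token

    assert "".join(tokens) == ("[IMG][IMG][IMG][IMG_BREAK]"
                               "[IMG][IMG][IMG][IMG_END]")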
@@ -330,7 +421,7 @@ def init_vision_tower_for_llava(
     raise NotImplementedError(msg)
 
 
-@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
+@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor)
 class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     # BitandBytes specific attributes
     bitsandbytes_stacked_params_mapping = {
@@ -596,7 +687,12 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
     ) -> MultiModalInputsV2:
         hf_config = self.ctx.get_hf_config(LlavaConfig)
         image_token_id = hf_config.image_token_index
-        max_image_tokens = get_max_llava_image_tokens(self.ctx)
+
+        # Assume that it doesn't depend on the image size
+        num_image_tokens = self._get_num_image_tokens(
+            image_width=-1,
+            image_height=-1,
+        )
 
         result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
@@ -609,14 +705,14 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
         def get_replacement_mantis(item_idx: int):
             return "".join([
                 f"(image {item_idx+1}: <Image>",  # 7 tokens
-                "<image>" * max_image_tokens,
+                "<image>" * num_image_tokens,
                 "</Image>)",  # 3 tokens
             ])
 
         mantis_repls = self._bind_prompt_replacements([
             PromptReplacement(
                 modality="image",
-                target=[image_token_id] * max_image_tokens,
+                target=[image_token_id] * num_image_tokens,
                 replacement=get_replacement_mantis,
             )
         ])
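For reference, here is what the Mantis replacement produces for the first image, using an illustrative num_image_tokens of 4 (the real value comes from _get_num_image_tokens above):

    num_image_tokens = 4  # illustrative; Mantis uses the processor's count


    def get_replacement_mantis(item_idx: int) -> str:
        return "".join([
            f"(image {item_idx+1}: <Image>",  # 7 tokens
            "<image>" * num_image_tokens,
            "</Image>)",  # 3 tokens
        ])


    assert get_replacement_mantis(0) == (
        "(image 1: <Image><image><image><image><image></Image>)")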