[VLM] Merged multi-modal processor for LLaVA-NeXT (#11682)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from abc import abstractmethod
|
||||
from functools import cached_property
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
|
||||
Tuple, TypedDict, Union)
|
||||
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
|
||||
Protocol, Set, Tuple, TypedDict, Union)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -12,7 +13,6 @@ from transformers.models.pixtral import PixtralProcessor
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import InputContext
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
@@ -23,23 +23,23 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalInputsV2, MultiModalKwargs,
|
||||
NestedTensors)
|
||||
from vllm.multimodal.parse import ImageProcessorItems
|
||||
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
||||
ImageSize)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
MultiModalDataItems, ProcessorInputs,
|
||||
PromptReplacement,
|
||||
InputProcessingContext,
|
||||
MultiModalDataItems, ProcessingCache,
|
||||
ProcessorInputs, PromptReplacement,
|
||||
full_groupby_modality)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .clip import (CLIPVisionModel, dummy_image_for_clip,
|
||||
get_max_clip_image_tokens)
|
||||
from .clip import CLIPVisionModel
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
|
||||
get_max_pixtral_hf_image_tokens,
|
||||
get_pixtral_hf_image_feature_size)
|
||||
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
|
||||
get_max_siglip_image_tokens)
|
||||
from .pixtral import (PixtralHFVisionModel,
|
||||
get_pixtral_hf_image_feature_grid_size)
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
from .vision import vision_encoder_info
|
||||
|
||||
|
||||
class LlavaImagePixelInputs(TypedDict):
|
||||
@@ -94,39 +94,167 @@ class LlavaMultiModalProjector(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
def get_max_llava_image_tokens(ctx: InputContext):
|
||||
hf_config = ctx.get_hf_config(LlavaConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
num_image_tokens = get_max_clip_image_tokens(vision_config)
|
||||
elif isinstance(vision_config, SiglipVisionConfig):
|
||||
num_image_tokens = get_max_siglip_image_tokens(vision_config)
|
||||
elif isinstance(vision_config, PixtralVisionConfig):
|
||||
num_image_tokens = get_max_pixtral_hf_image_tokens(vision_config)
|
||||
else:
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
strategy = hf_config.vision_feature_select_strategy
|
||||
if strategy == "default":
|
||||
return num_image_tokens - 1
|
||||
elif strategy == "full":
|
||||
return num_image_tokens
|
||||
else:
|
||||
raise ValueError(f"Unexpected select feature strategy: {strategy}")
|
||||
class LlavaLikeConfig(Protocol):
|
||||
vision_config: Final[PretrainedConfig]
|
||||
vision_feature_select_strategy: Final[str]
|
||||
vision_feature_layer: Final[Union[int, List[int]]]
|
||||
|
||||
|
||||
class LlavaMultiModalProcessor(BaseMultiModalProcessor):
|
||||
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
def __init__(self,
|
||||
ctx: InputProcessingContext,
|
||||
*,
|
||||
cache: Optional[ProcessingCache] = None,
|
||||
enable_sanity_checks: bool = True) -> None:
|
||||
super().__init__(ctx,
|
||||
cache=cache,
|
||||
enable_sanity_checks=enable_sanity_checks)
|
||||
|
||||
vision_config = self._get_hf_config().vision_config
|
||||
self._vision_encoder_info = vision_encoder_info(vision_config)
|
||||
|
||||
@abstractmethod
|
||||
def _get_hf_config(self) -> LlavaLikeConfig:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": None}
|
||||
|
||||
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
|
||||
return {"image": get_max_llava_image_tokens(self.ctx)}
|
||||
def _apply_feature_select_strategy(
|
||||
self,
|
||||
strategy: str,
|
||||
encoder_num_image_tokens: int,
|
||||
) -> int:
|
||||
if strategy == "default":
|
||||
return encoder_num_image_tokens - 1
|
||||
if strategy == "full":
|
||||
return encoder_num_image_tokens
|
||||
|
||||
def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
|
||||
return self.ctx.get_hf_processor((LlavaProcessor, PixtralProcessor))
|
||||
msg = f"Unexpected feature select strategy: {strategy!r}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
def _get_max_image_tokens(self) -> int:
|
||||
hf_config = self._get_hf_config()
|
||||
|
||||
return self._apply_feature_select_strategy(
|
||||
hf_config.vision_feature_select_strategy,
|
||||
self._vision_encoder_info.get_max_image_tokens(),
|
||||
)
|
||||
|
||||
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
|
||||
return {"image": self._get_max_image_tokens()}
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
def _get_dummy_image_size(self) -> ImageSize:
|
||||
image_size = self._vision_encoder_info.get_image_size()
|
||||
return ImageSize(image_size, image_size)
|
||||
|
||||
@abstractmethod
|
||||
def _get_image_token(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_dummy_mm_inputs(
|
||||
self,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
image_token = self._get_image_token()
|
||||
target_width, target_height = self._get_dummy_image_size()
|
||||
|
||||
mm_data = {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images)
|
||||
}
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt_text=image_token * num_images,
|
||||
mm_data=mm_data,
|
||||
)
|
||||
|
||||
|
||||
class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):
|
||||
|
||||
def _get_hf_config(self) -> LlavaConfig:
|
||||
return self.ctx.get_hf_config(LlavaConfig)
|
||||
|
||||
def _get_hf_processor(self) -> LlavaProcessor:
|
||||
return self.ctx.get_hf_processor(LlavaProcessor)
|
||||
|
||||
def _get_image_token(self) -> str:
|
||||
return self._get_hf_processor().image_token
|
||||
|
||||
def _get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> int:
|
||||
hf_config = self._get_hf_config()
|
||||
|
||||
return self._apply_feature_select_strategy(
|
||||
hf_config.vision_feature_select_strategy,
|
||||
self._vision_encoder_info.get_num_image_tokens(
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
),
|
||||
)
|
||||
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_config = self._get_hf_config()
|
||||
image_token_id = hf_config.image_token_index
|
||||
|
||||
def get_replacement(item_idx: int):
|
||||
images = mm_items.get_items(
|
||||
"image", (ImageEmbeddingItems, ImageProcessorItems))
|
||||
|
||||
if isinstance(images, ImageEmbeddingItems):
|
||||
num_image_tokens = images.get_feature_size(item_idx)
|
||||
else:
|
||||
image_size = images.get_image_size(item_idx)
|
||||
num_image_tokens = self._get_num_image_tokens(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
)
|
||||
|
||||
return [image_token_id] * num_image_tokens
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=[image_token_id],
|
||||
replacement=get_replacement,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class PixtralHFMultiModalProcessor(BaseLlavaMultiModalProcessor):
|
||||
|
||||
def _get_hf_config(self) -> LlavaConfig:
|
||||
return self.ctx.get_hf_config(LlavaConfig)
|
||||
|
||||
def _get_hf_processor(self) -> PixtralProcessor:
|
||||
return self.ctx.get_hf_processor(PixtralProcessor)
|
||||
|
||||
def _get_image_token(self) -> str:
|
||||
return self._get_hf_processor().image_token
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
@@ -140,119 +268,82 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
|
||||
mm_kwargs=mm_kwargs,
|
||||
)
|
||||
|
||||
# NOTE: pixel_values=None for MLlavaProcessor
|
||||
pixel_values = processed_outputs.get("pixel_values")
|
||||
if pixel_values is not None:
|
||||
images = mm_data["images"]
|
||||
assert isinstance(images, list)
|
||||
|
||||
if isinstance(self._get_hf_processor(), PixtralProcessor):
|
||||
# Original output: (1, num_images, C, H, W)
|
||||
# New output: (num_images, C, H, W)
|
||||
assert (isinstance(pixel_values, list)
|
||||
and len(pixel_values) == 1)
|
||||
assert (isinstance(pixel_values[0], list)
|
||||
and len(pixel_values[0]) == len(images))
|
||||
# Original output: (1, num_images, C, H, W)
|
||||
# New output: (num_images, C, H, W)
|
||||
assert (isinstance(pixel_values, list) and len(pixel_values) == 1)
|
||||
assert (isinstance(pixel_values[0], list)
|
||||
and len(pixel_values[0]) == len(images))
|
||||
|
||||
processed_outputs["pixel_values"] = pixel_values[0]
|
||||
processed_outputs["pixel_values"] = pixel_values[0]
|
||||
|
||||
return processed_outputs
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_config = self.ctx.get_hf_config(LlavaConfig)
|
||||
hf_config = self._get_hf_config()
|
||||
image_token_id = hf_config.image_token_index
|
||||
|
||||
processor = self._get_hf_processor()
|
||||
if isinstance(processor, PixtralProcessor):
|
||||
image_token = processor.image_token
|
||||
image_break_token = processor.image_break_token
|
||||
image_end_token = processor.image_end_token
|
||||
image_token = processor.image_token
|
||||
image_break_token = processor.image_break_token
|
||||
image_end_token = processor.image_end_token
|
||||
|
||||
vision_config = hf_config.vision_config
|
||||
assert isinstance(vision_config, PixtralVisionConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
assert isinstance(vision_config, PixtralVisionConfig)
|
||||
|
||||
def get_replacement_pixtral(item_idx: int):
|
||||
images = mm_items.get_items("image", ImageProcessorItems)
|
||||
image_size = images.get_image_size(item_idx)
|
||||
def get_replacement(item_idx: int):
|
||||
images = mm_items.get_items("image", ImageProcessorItems)
|
||||
image_size = images.get_image_size(item_idx)
|
||||
|
||||
(
|
||||
num_width_tokens,
|
||||
num_height_tokens,
|
||||
) = get_pixtral_hf_image_feature_size(
|
||||
vision_config,
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
)
|
||||
ncols, nrows = get_pixtral_hf_image_feature_grid_size(
|
||||
vision_config,
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
)
|
||||
|
||||
tokens = ([image_token] * num_width_tokens +
|
||||
[image_break_token]) * num_height_tokens
|
||||
tokens[-1] = image_end_token
|
||||
tokens = ([image_token] * ncols + [image_break_token]) * nrows
|
||||
tokens[-1] = image_end_token
|
||||
|
||||
return "".join(tokens)
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=[image_token_id],
|
||||
replacement=get_replacement_pixtral,
|
||||
),
|
||||
]
|
||||
|
||||
max_image_tokens = get_max_llava_image_tokens(self.ctx)
|
||||
return "".join(tokens)
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=[image_token_id],
|
||||
replacement=[image_token_id] * max_image_tokens,
|
||||
)
|
||||
replacement=get_replacement,
|
||||
),
|
||||
]
|
||||
|
||||
def _get_dummy_mm_inputs(
|
||||
self,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
hf_config = self.ctx.get_hf_config(LlavaConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
data = dummy_image_for_clip(vision_config, num_images)
|
||||
elif isinstance(vision_config, SiglipVisionConfig):
|
||||
data = dummy_image_for_siglip(vision_config, num_images)
|
||||
elif isinstance(vision_config, PixtralVisionConfig):
|
||||
data = dummy_image_for_pixtral_hf(vision_config, num_images)
|
||||
else:
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
def _build_llava_or_pixtral_hf_processor(
|
||||
ctx: InputProcessingContext,
|
||||
*,
|
||||
cache: Optional[ProcessingCache] = None,
|
||||
enable_sanity_checks: bool = True,
|
||||
) -> BaseLlavaMultiModalProcessor:
|
||||
hf_config = ctx.get_hf_config(LlavaConfig)
|
||||
|
||||
hf_processor = self._get_hf_processor()
|
||||
image_token = hf_processor.image_token
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt_text=image_token * num_images,
|
||||
mm_data=data,
|
||||
if isinstance(hf_config.vision_config, PixtralVisionConfig):
|
||||
return PixtralHFMultiModalProcessor(
|
||||
ctx,
|
||||
cache=cache,
|
||||
enable_sanity_checks=enable_sanity_checks,
|
||||
)
|
||||
|
||||
|
||||
class LlavaLikeConfig(Protocol):
|
||||
vision_config: PretrainedConfig
|
||||
vision_feature_layer: Union[int, List[int]]
|
||||
return LlavaMultiModalProcessor(
|
||||
ctx,
|
||||
cache=cache,
|
||||
enable_sanity_checks=enable_sanity_checks,
|
||||
)
|
||||
|
||||
|
||||
def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
|
||||
@@ -330,7 +421,7 @@ def init_vision_tower_for_llava(
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
|
||||
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor)
|
||||
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
# BitandBytes specific attributes
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
@@ -596,7 +687,12 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
|
||||
) -> MultiModalInputsV2:
|
||||
hf_config = self.ctx.get_hf_config(LlavaConfig)
|
||||
image_token_id = hf_config.image_token_index
|
||||
max_image_tokens = get_max_llava_image_tokens(self.ctx)
|
||||
|
||||
# Assume that it doesn't depend on the image size
|
||||
num_image_tokens = self._get_num_image_tokens(
|
||||
image_width=-1,
|
||||
image_height=-1,
|
||||
)
|
||||
|
||||
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
|
||||
|
||||
@@ -609,14 +705,14 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
|
||||
def get_replacement_mantis(item_idx: int):
|
||||
return "".join([
|
||||
f"(image {item_idx+1}: <Image>", # 7 tokens
|
||||
"<image>" * max_image_tokens,
|
||||
"<image>" * num_image_tokens,
|
||||
"</Image>)", # 3 tokens
|
||||
])
|
||||
|
||||
mantis_repls = self._bind_prompt_replacements([
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=[image_token_id] * max_image_tokens,
|
||||
target=[image_token_id] * num_image_tokens,
|
||||
replacement=get_replacement_mantis,
|
||||
)
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user