[VLM] Merged multi-modal processor for LLaVA-NeXT (#11682)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -24,6 +24,8 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
resolve_visual_encoder_outputs)
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
from .vision import VisionEncoderInfo
|
||||
|
||||
|
||||
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
||||
assert image_size % patch_size == 0
|
||||
@@ -149,6 +151,29 @@ def input_processor_for_clip(
|
||||
multi_modal_placeholders={"image": ranges})
|
||||
|
||||
|
||||
class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
|
||||
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> int:
|
||||
return get_clip_image_feature_size(self.vision_config)
|
||||
|
||||
def get_max_image_tokens(self) -> int:
|
||||
return get_max_clip_image_tokens(self.vision_config)
|
||||
|
||||
def get_num_patches(self) -> int:
|
||||
return get_clip_patch_grid_length(
|
||||
image_size=self.vision_config.image_size,
|
||||
patch_size=self.vision_config.patch_size,
|
||||
)
|
||||
|
||||
def get_image_size(self) -> int:
|
||||
return self.vision_config.image_size
|
||||
|
||||
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
|
||||
class CLIPVisionEmbeddings(nn.Module):
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
|
||||
return ImageSize(width=target_size["width"],
|
||||
height=target_size["height"])
|
||||
|
||||
def _get_image_grid_size(
|
||||
def _get_image_feature_grid_size(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
@@ -99,7 +99,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
|
||||
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
|
||||
target_width, target_height = self._get_image_target_size()
|
||||
|
||||
max_ncols, max_nrows = self._get_image_grid_size(
|
||||
max_ncols, max_nrows = self._get_image_feature_grid_size(
|
||||
image_width=target_width,
|
||||
image_height=target_height,
|
||||
)
|
||||
@@ -172,7 +172,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
|
||||
images = mm_items.get_items("image", ImageProcessorItems)
|
||||
image_size = images.get_image_size(item_idx)
|
||||
|
||||
ncols, nrows = self._get_image_grid_size(
|
||||
ncols, nrows = self._get_image_feature_grid_size(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from abc import abstractmethod
|
||||
from functools import cached_property
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
|
||||
Tuple, TypedDict, Union)
|
||||
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
|
||||
Protocol, Set, Tuple, TypedDict, Union)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -12,7 +13,6 @@ from transformers.models.pixtral import PixtralProcessor
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import InputContext
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
@@ -23,23 +23,23 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalInputsV2, MultiModalKwargs,
|
||||
NestedTensors)
|
||||
from vllm.multimodal.parse import ImageProcessorItems
|
||||
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
||||
ImageSize)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
MultiModalDataItems, ProcessorInputs,
|
||||
PromptReplacement,
|
||||
InputProcessingContext,
|
||||
MultiModalDataItems, ProcessingCache,
|
||||
ProcessorInputs, PromptReplacement,
|
||||
full_groupby_modality)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .clip import (CLIPVisionModel, dummy_image_for_clip,
|
||||
get_max_clip_image_tokens)
|
||||
from .clip import CLIPVisionModel
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
|
||||
get_max_pixtral_hf_image_tokens,
|
||||
get_pixtral_hf_image_feature_size)
|
||||
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
|
||||
get_max_siglip_image_tokens)
|
||||
from .pixtral import (PixtralHFVisionModel,
|
||||
get_pixtral_hf_image_feature_grid_size)
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
from .vision import vision_encoder_info
|
||||
|
||||
|
||||
class LlavaImagePixelInputs(TypedDict):
|
||||
@@ -94,39 +94,167 @@ class LlavaMultiModalProjector(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
def get_max_llava_image_tokens(ctx: InputContext):
|
||||
hf_config = ctx.get_hf_config(LlavaConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
num_image_tokens = get_max_clip_image_tokens(vision_config)
|
||||
elif isinstance(vision_config, SiglipVisionConfig):
|
||||
num_image_tokens = get_max_siglip_image_tokens(vision_config)
|
||||
elif isinstance(vision_config, PixtralVisionConfig):
|
||||
num_image_tokens = get_max_pixtral_hf_image_tokens(vision_config)
|
||||
else:
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
strategy = hf_config.vision_feature_select_strategy
|
||||
if strategy == "default":
|
||||
return num_image_tokens - 1
|
||||
elif strategy == "full":
|
||||
return num_image_tokens
|
||||
else:
|
||||
raise ValueError(f"Unexpected select feature strategy: {strategy}")
|
||||
class LlavaLikeConfig(Protocol):
|
||||
vision_config: Final[PretrainedConfig]
|
||||
vision_feature_select_strategy: Final[str]
|
||||
vision_feature_layer: Final[Union[int, List[int]]]
|
||||
|
||||
|
||||
class LlavaMultiModalProcessor(BaseMultiModalProcessor):
|
||||
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
def __init__(self,
|
||||
ctx: InputProcessingContext,
|
||||
*,
|
||||
cache: Optional[ProcessingCache] = None,
|
||||
enable_sanity_checks: bool = True) -> None:
|
||||
super().__init__(ctx,
|
||||
cache=cache,
|
||||
enable_sanity_checks=enable_sanity_checks)
|
||||
|
||||
vision_config = self._get_hf_config().vision_config
|
||||
self._vision_encoder_info = vision_encoder_info(vision_config)
|
||||
|
||||
@abstractmethod
|
||||
def _get_hf_config(self) -> LlavaLikeConfig:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": None}
|
||||
|
||||
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
|
||||
return {"image": get_max_llava_image_tokens(self.ctx)}
|
||||
def _apply_feature_select_strategy(
|
||||
self,
|
||||
strategy: str,
|
||||
encoder_num_image_tokens: int,
|
||||
) -> int:
|
||||
if strategy == "default":
|
||||
return encoder_num_image_tokens - 1
|
||||
if strategy == "full":
|
||||
return encoder_num_image_tokens
|
||||
|
||||
def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
|
||||
return self.ctx.get_hf_processor((LlavaProcessor, PixtralProcessor))
|
||||
msg = f"Unexpected feature select strategy: {strategy!r}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
def _get_max_image_tokens(self) -> int:
|
||||
hf_config = self._get_hf_config()
|
||||
|
||||
return self._apply_feature_select_strategy(
|
||||
hf_config.vision_feature_select_strategy,
|
||||
self._vision_encoder_info.get_max_image_tokens(),
|
||||
)
|
||||
|
||||
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
|
||||
return {"image": self._get_max_image_tokens()}
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
def _get_dummy_image_size(self) -> ImageSize:
|
||||
image_size = self._vision_encoder_info.get_image_size()
|
||||
return ImageSize(image_size, image_size)
|
||||
|
||||
@abstractmethod
|
||||
def _get_image_token(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_dummy_mm_inputs(
|
||||
self,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
image_token = self._get_image_token()
|
||||
target_width, target_height = self._get_dummy_image_size()
|
||||
|
||||
mm_data = {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images)
|
||||
}
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt_text=image_token * num_images,
|
||||
mm_data=mm_data,
|
||||
)
|
||||
|
||||
|
||||
class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):
|
||||
|
||||
def _get_hf_config(self) -> LlavaConfig:
|
||||
return self.ctx.get_hf_config(LlavaConfig)
|
||||
|
||||
def _get_hf_processor(self) -> LlavaProcessor:
|
||||
return self.ctx.get_hf_processor(LlavaProcessor)
|
||||
|
||||
def _get_image_token(self) -> str:
|
||||
return self._get_hf_processor().image_token
|
||||
|
||||
def _get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> int:
|
||||
hf_config = self._get_hf_config()
|
||||
|
||||
return self._apply_feature_select_strategy(
|
||||
hf_config.vision_feature_select_strategy,
|
||||
self._vision_encoder_info.get_num_image_tokens(
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
),
|
||||
)
|
||||
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_config = self._get_hf_config()
|
||||
image_token_id = hf_config.image_token_index
|
||||
|
||||
def get_replacement(item_idx: int):
|
||||
images = mm_items.get_items(
|
||||
"image", (ImageEmbeddingItems, ImageProcessorItems))
|
||||
|
||||
if isinstance(images, ImageEmbeddingItems):
|
||||
num_image_tokens = images.get_feature_size(item_idx)
|
||||
else:
|
||||
image_size = images.get_image_size(item_idx)
|
||||
num_image_tokens = self._get_num_image_tokens(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
)
|
||||
|
||||
return [image_token_id] * num_image_tokens
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=[image_token_id],
|
||||
replacement=get_replacement,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class PixtralHFMultiModalProcessor(BaseLlavaMultiModalProcessor):
|
||||
|
||||
def _get_hf_config(self) -> LlavaConfig:
|
||||
return self.ctx.get_hf_config(LlavaConfig)
|
||||
|
||||
def _get_hf_processor(self) -> PixtralProcessor:
|
||||
return self.ctx.get_hf_processor(PixtralProcessor)
|
||||
|
||||
def _get_image_token(self) -> str:
|
||||
return self._get_hf_processor().image_token
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
@@ -140,119 +268,82 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
|
||||
mm_kwargs=mm_kwargs,
|
||||
)
|
||||
|
||||
# NOTE: pixel_values=None for MLlavaProcessor
|
||||
pixel_values = processed_outputs.get("pixel_values")
|
||||
if pixel_values is not None:
|
||||
images = mm_data["images"]
|
||||
assert isinstance(images, list)
|
||||
|
||||
if isinstance(self._get_hf_processor(), PixtralProcessor):
|
||||
# Original output: (1, num_images, C, H, W)
|
||||
# New output: (num_images, C, H, W)
|
||||
assert (isinstance(pixel_values, list)
|
||||
and len(pixel_values) == 1)
|
||||
assert (isinstance(pixel_values[0], list)
|
||||
and len(pixel_values[0]) == len(images))
|
||||
# Original output: (1, num_images, C, H, W)
|
||||
# New output: (num_images, C, H, W)
|
||||
assert (isinstance(pixel_values, list) and len(pixel_values) == 1)
|
||||
assert (isinstance(pixel_values[0], list)
|
||||
and len(pixel_values[0]) == len(images))
|
||||
|
||||
processed_outputs["pixel_values"] = pixel_values[0]
|
||||
processed_outputs["pixel_values"] = pixel_values[0]
|
||||
|
||||
return processed_outputs
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_config = self.ctx.get_hf_config(LlavaConfig)
|
||||
hf_config = self._get_hf_config()
|
||||
image_token_id = hf_config.image_token_index
|
||||
|
||||
processor = self._get_hf_processor()
|
||||
if isinstance(processor, PixtralProcessor):
|
||||
image_token = processor.image_token
|
||||
image_break_token = processor.image_break_token
|
||||
image_end_token = processor.image_end_token
|
||||
image_token = processor.image_token
|
||||
image_break_token = processor.image_break_token
|
||||
image_end_token = processor.image_end_token
|
||||
|
||||
vision_config = hf_config.vision_config
|
||||
assert isinstance(vision_config, PixtralVisionConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
assert isinstance(vision_config, PixtralVisionConfig)
|
||||
|
||||
def get_replacement_pixtral(item_idx: int):
|
||||
images = mm_items.get_items("image", ImageProcessorItems)
|
||||
image_size = images.get_image_size(item_idx)
|
||||
def get_replacement(item_idx: int):
|
||||
images = mm_items.get_items("image", ImageProcessorItems)
|
||||
image_size = images.get_image_size(item_idx)
|
||||
|
||||
(
|
||||
num_width_tokens,
|
||||
num_height_tokens,
|
||||
) = get_pixtral_hf_image_feature_size(
|
||||
vision_config,
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
)
|
||||
ncols, nrows = get_pixtral_hf_image_feature_grid_size(
|
||||
vision_config,
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
)
|
||||
|
||||
tokens = ([image_token] * num_width_tokens +
|
||||
[image_break_token]) * num_height_tokens
|
||||
tokens[-1] = image_end_token
|
||||
tokens = ([image_token] * ncols + [image_break_token]) * nrows
|
||||
tokens[-1] = image_end_token
|
||||
|
||||
return "".join(tokens)
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=[image_token_id],
|
||||
replacement=get_replacement_pixtral,
|
||||
),
|
||||
]
|
||||
|
||||
max_image_tokens = get_max_llava_image_tokens(self.ctx)
|
||||
return "".join(tokens)
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=[image_token_id],
|
||||
replacement=[image_token_id] * max_image_tokens,
|
||||
)
|
||||
replacement=get_replacement,
|
||||
),
|
||||
]
|
||||
|
||||
def _get_dummy_mm_inputs(
|
||||
self,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
hf_config = self.ctx.get_hf_config(LlavaConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
data = dummy_image_for_clip(vision_config, num_images)
|
||||
elif isinstance(vision_config, SiglipVisionConfig):
|
||||
data = dummy_image_for_siglip(vision_config, num_images)
|
||||
elif isinstance(vision_config, PixtralVisionConfig):
|
||||
data = dummy_image_for_pixtral_hf(vision_config, num_images)
|
||||
else:
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
def _build_llava_or_pixtral_hf_processor(
|
||||
ctx: InputProcessingContext,
|
||||
*,
|
||||
cache: Optional[ProcessingCache] = None,
|
||||
enable_sanity_checks: bool = True,
|
||||
) -> BaseLlavaMultiModalProcessor:
|
||||
hf_config = ctx.get_hf_config(LlavaConfig)
|
||||
|
||||
hf_processor = self._get_hf_processor()
|
||||
image_token = hf_processor.image_token
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt_text=image_token * num_images,
|
||||
mm_data=data,
|
||||
if isinstance(hf_config.vision_config, PixtralVisionConfig):
|
||||
return PixtralHFMultiModalProcessor(
|
||||
ctx,
|
||||
cache=cache,
|
||||
enable_sanity_checks=enable_sanity_checks,
|
||||
)
|
||||
|
||||
|
||||
class LlavaLikeConfig(Protocol):
|
||||
vision_config: PretrainedConfig
|
||||
vision_feature_layer: Union[int, List[int]]
|
||||
return LlavaMultiModalProcessor(
|
||||
ctx,
|
||||
cache=cache,
|
||||
enable_sanity_checks=enable_sanity_checks,
|
||||
)
|
||||
|
||||
|
||||
def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
|
||||
@@ -330,7 +421,7 @@ def init_vision_tower_for_llava(
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
|
||||
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor)
|
||||
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
# BitandBytes specific attributes
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
@@ -596,7 +687,12 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
|
||||
) -> MultiModalInputsV2:
|
||||
hf_config = self.ctx.get_hf_config(LlavaConfig)
|
||||
image_token_id = hf_config.image_token_index
|
||||
max_image_tokens = get_max_llava_image_tokens(self.ctx)
|
||||
|
||||
# Assume that it doesn't depend on the image size
|
||||
num_image_tokens = self._get_num_image_tokens(
|
||||
image_width=-1,
|
||||
image_height=-1,
|
||||
)
|
||||
|
||||
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
|
||||
|
||||
@@ -609,14 +705,14 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
|
||||
def get_replacement_mantis(item_idx: int):
|
||||
return "".join([
|
||||
f"(image {item_idx+1}: <Image>", # 7 tokens
|
||||
"<image>" * max_image_tokens,
|
||||
"<image>" * num_image_tokens,
|
||||
"</Image>)", # 3 tokens
|
||||
])
|
||||
|
||||
mantis_repls = self._bind_prompt_replacements([
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=[image_token_id] * max_image_tokens,
|
||||
target=[image_token_id] * num_image_tokens,
|
||||
replacement=get_replacement_mantis,
|
||||
)
|
||||
])
|
||||
|
||||
@@ -4,31 +4,25 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
from transformers import CLIPVisionConfig, LlavaNextConfig, SiglipVisionConfig
|
||||
from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor
|
||||
from transformers.models.llava_next.modeling_llava_next import (
|
||||
get_anyres_image_grid_shape, unpad_image)
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||
InputContext)
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import NestedTensors
|
||||
from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
|
||||
from vllm.multimodal.parse import ImageSize
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
from .clip import (CLIPVisionModel, dummy_image_for_clip,
|
||||
dummy_seq_data_for_clip, get_clip_image_feature_size,
|
||||
get_clip_patch_grid_length, input_processor_for_clip)
|
||||
from .clip import CLIPVisionModel
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .llava import LlavaMultiModalProjector, init_vision_tower_for_llava
|
||||
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
|
||||
dummy_seq_data_for_siglip, get_siglip_image_feature_size,
|
||||
get_siglip_patch_grid_length, input_processor_for_siglip)
|
||||
from .llava import (LlavaMultiModalProcessor, LlavaMultiModalProjector,
|
||||
init_vision_tower_for_llava)
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
|
||||
init_vllm_registered_model, maybe_prefix)
|
||||
|
||||
@@ -65,218 +59,127 @@ LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
|
||||
LlavaNextImageEmbeddingInputs]
|
||||
|
||||
|
||||
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
|
||||
def _get_llava_next_num_unpadded_features(
|
||||
original_height: int,
|
||||
original_width: int,
|
||||
npatches: int,
|
||||
num_patch_height: int,
|
||||
num_patch_width: int,
|
||||
) -> Tuple[int, int]:
|
||||
current_height = npatches * num_patch_height
|
||||
current_width = npatches * num_patch_width
|
||||
class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):
|
||||
|
||||
original_aspect_ratio = original_width / original_height
|
||||
current_aspect_ratio = current_width / current_height
|
||||
def _get_hf_config(self) -> LlavaNextConfig:
|
||||
return self.ctx.get_hf_config(LlavaNextConfig)
|
||||
|
||||
if original_aspect_ratio > current_aspect_ratio:
|
||||
scale_factor = current_width / original_width
|
||||
new_height = int(original_height * scale_factor)
|
||||
padding = (current_height - new_height) // 2
|
||||
current_height -= 2 * padding
|
||||
else:
|
||||
scale_factor = current_height / original_height
|
||||
new_width = int(original_width * scale_factor)
|
||||
padding = (current_width - new_width) // 2
|
||||
current_width -= 2 * padding
|
||||
def _get_hf_processor(self) -> LlavaNextProcessor:
|
||||
return self.ctx.get_hf_processor(LlavaNextProcessor)
|
||||
|
||||
unpadded_features = current_height * current_width
|
||||
newline_features = current_height
|
||||
return (unpadded_features, newline_features)
|
||||
def _get_image_token(self) -> str:
|
||||
return self._get_hf_processor().image_token
|
||||
|
||||
|
||||
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
|
||||
def get_llava_next_image_feature_size(
|
||||
hf_config: LlavaNextConfig,
|
||||
*,
|
||||
input_height: int,
|
||||
input_width: int,
|
||||
) -> int:
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
num_patches = get_clip_patch_grid_length(
|
||||
image_size=vision_config.image_size,
|
||||
patch_size=vision_config.patch_size,
|
||||
)
|
||||
base_feature_size = get_clip_image_feature_size(vision_config)
|
||||
elif isinstance(vision_config, SiglipVisionConfig):
|
||||
num_patches = get_siglip_patch_grid_length(
|
||||
image_size=vision_config.image_size,
|
||||
patch_size=vision_config.patch_size,
|
||||
)
|
||||
base_feature_size = get_siglip_image_feature_size(vision_config)
|
||||
else:
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
strategy = hf_config.vision_feature_select_strategy
|
||||
if strategy == "default":
|
||||
base_feature_size -= 1
|
||||
elif strategy == "full":
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"Unexpected select feature strategy: {strategy}")
|
||||
|
||||
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
|
||||
image_size=(input_height, input_width),
|
||||
grid_pinpoints=hf_config.image_grid_pinpoints,
|
||||
patch_size=vision_config.image_size,
|
||||
)
|
||||
|
||||
(
|
||||
unpadded_feature_size,
|
||||
newline_feature_size,
|
||||
) = _get_llava_next_num_unpadded_features(input_height, input_width,
|
||||
num_patches, num_patch_height,
|
||||
num_patch_width)
|
||||
|
||||
return unpadded_feature_size + newline_feature_size + base_feature_size
|
||||
|
||||
|
||||
def get_max_llava_next_image_tokens(ctx: InputContext):
|
||||
"""Compute the max feature size for all possible image grid pinpoints."""
|
||||
return _get_pinpoint_with_largest_features(ctx)[0]
|
||||
|
||||
|
||||
def _get_pinpoint_with_largest_features(
|
||||
ctx: InputContext) -> Tuple[int, Tuple[int, int]]:
|
||||
"""Get the grid pinpoint with the largest features & its feature size."""
|
||||
hf_config = ctx.get_hf_config(LlavaNextConfig)
|
||||
largest_feature_size = 0
|
||||
largest_feature_pinpoint = None
|
||||
for (height, width) in hf_config.image_grid_pinpoints:
|
||||
feat_size = get_llava_next_image_feature_size(
|
||||
hf_config,
|
||||
input_height=height,
|
||||
input_width=width,
|
||||
)
|
||||
if feat_size > largest_feature_size:
|
||||
largest_feature_size = feat_size
|
||||
largest_feature_pinpoint = (height, width)
|
||||
if not largest_feature_size or largest_feature_pinpoint is None:
|
||||
raise ValueError("Cannot have a largest feature size of 0!")
|
||||
return largest_feature_size, largest_feature_pinpoint
|
||||
|
||||
|
||||
def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
|
||||
mm_counts: Mapping[str, int]):
|
||||
hf_config = ctx.get_hf_config(LlavaNextConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
num_images = mm_counts["image"]
|
||||
|
||||
image_feature_size, pinpoint = _get_pinpoint_with_largest_features(ctx)
|
||||
max_feat_height, max_feat_width = pinpoint
|
||||
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
seq_data, ranges = dummy_seq_data_for_clip(
|
||||
vision_config,
|
||||
seq_len,
|
||||
num_images,
|
||||
image_token_id=hf_config.image_token_index,
|
||||
image_feature_size_override=image_feature_size,
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
image_sizes=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
mm_data = dummy_image_for_clip(
|
||||
vision_config,
|
||||
num_images,
|
||||
image_width_override=max_feat_width,
|
||||
image_height_override=max_feat_height,
|
||||
def _get_max_image_tokens(self) -> int:
|
||||
largest_feature_size, _ = self._get_pinpoint_with_most_features()
|
||||
return largest_feature_size
|
||||
|
||||
def _get_dummy_image_size(self) -> ImageSize:
|
||||
_, pinpoint = self._get_pinpoint_with_most_features()
|
||||
return pinpoint
|
||||
|
||||
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
|
||||
def _get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> int:
|
||||
hf_config = self._get_hf_config()
|
||||
|
||||
base_feature_size = self._apply_feature_select_strategy(
|
||||
hf_config.vision_feature_select_strategy,
|
||||
self._vision_encoder_info.get_num_image_tokens(
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
),
|
||||
)
|
||||
num_patches = self._vision_encoder_info.get_num_patches()
|
||||
|
||||
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
|
||||
image_size=(image_height, image_width),
|
||||
grid_pinpoints=hf_config.image_grid_pinpoints,
|
||||
patch_size=self._vision_encoder_info.get_image_size(),
|
||||
)
|
||||
|
||||
return DummyData(seq_data, mm_data, ranges)
|
||||
elif isinstance(vision_config, SiglipVisionConfig):
|
||||
seq_data, ranges = dummy_seq_data_for_siglip(
|
||||
vision_config,
|
||||
seq_len,
|
||||
num_images,
|
||||
image_token_id=hf_config.image_token_index,
|
||||
image_feature_size_override=image_feature_size,
|
||||
(
|
||||
unpadded_feature_size,
|
||||
newline_feature_size,
|
||||
) = self._get_num_unpadded_features(
|
||||
original_height=image_height,
|
||||
original_width=image_width,
|
||||
npatches=num_patches,
|
||||
num_patch_height=num_patch_height,
|
||||
num_patch_width=num_patch_width,
|
||||
)
|
||||
|
||||
mm_data = dummy_image_for_siglip(
|
||||
vision_config,
|
||||
num_images,
|
||||
image_width_override=max_feat_width,
|
||||
image_height_override=max_feat_height,
|
||||
)
|
||||
return unpadded_feature_size + newline_feature_size + base_feature_size
|
||||
|
||||
return DummyData(seq_data, mm_data, ranges)
|
||||
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
|
||||
def _get_num_unpadded_features(
|
||||
self,
|
||||
*,
|
||||
original_height: int,
|
||||
original_width: int,
|
||||
npatches: int,
|
||||
num_patch_height: int,
|
||||
num_patch_width: int,
|
||||
) -> tuple[int, int]:
|
||||
current_height = npatches * num_patch_height
|
||||
current_width = npatches * num_patch_width
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
original_aspect_ratio = original_width / original_height
|
||||
current_aspect_ratio = current_width / current_height
|
||||
|
||||
if original_aspect_ratio > current_aspect_ratio:
|
||||
scale_factor = current_width / original_width
|
||||
new_height = int(original_height * scale_factor)
|
||||
padding = (current_height - new_height) // 2
|
||||
current_height -= 2 * padding
|
||||
else:
|
||||
scale_factor = current_height / original_height
|
||||
new_width = int(original_width * scale_factor)
|
||||
padding = (current_width - new_width) // 2
|
||||
current_width -= 2 * padding
|
||||
|
||||
unpadded_features = current_height * current_width
|
||||
newline_features = current_height
|
||||
return (unpadded_features, newline_features)
|
||||
|
||||
def _get_pinpoint_with_most_features(self) -> tuple[int, ImageSize]:
|
||||
"""
|
||||
Get the grid pinpoint with the most features and
|
||||
the corresponding feature size.
|
||||
"""
|
||||
hf_config = self._get_hf_config()
|
||||
|
||||
largest_feature_size, largest_feature_pinpoint = 0, None
|
||||
for (height, width) in hf_config.image_grid_pinpoints:
|
||||
feat_size = self._get_num_image_tokens(image_width=width,
|
||||
image_height=height)
|
||||
if feat_size > largest_feature_size:
|
||||
largest_feature_size = feat_size
|
||||
largest_feature_pinpoint = ImageSize(width=width,
|
||||
height=height)
|
||||
|
||||
if largest_feature_size == 0 or largest_feature_pinpoint is None:
|
||||
raise ValueError("Cannot have a largest feature size of 0!")
|
||||
|
||||
return largest_feature_size, largest_feature_pinpoint
|
||||
|
||||
|
||||
def input_processor_for_llava_next(ctx: InputContext,
|
||||
inputs: DecoderOnlyInputs):
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return inputs
|
||||
|
||||
model_config = ctx.model_config
|
||||
hf_config = ctx.get_hf_config(LlavaNextConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
image_data = multi_modal_data["image"]
|
||||
if isinstance(image_data, Image.Image):
|
||||
width, height = image_data.size
|
||||
|
||||
image_feature_size = get_llava_next_image_feature_size(
|
||||
hf_config,
|
||||
input_height=height,
|
||||
input_width=width,
|
||||
)
|
||||
elif is_list_of(image_data, Image.Image):
|
||||
image_feature_size = [
|
||||
get_llava_next_image_feature_size(hf_config,
|
||||
input_height=img.height,
|
||||
input_width=img.width)
|
||||
for img in image_data
|
||||
]
|
||||
elif isinstance(image_data, torch.Tensor):
|
||||
num_images, image_feature_size, hidden_size = image_data.shape
|
||||
elif is_list_of(image_data, torch.Tensor):
|
||||
image_feature_size = [item.shape[1] for item in image_data]
|
||||
else:
|
||||
raise TypeError(f"Invalid image type: {type(image_data)}")
|
||||
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
return input_processor_for_clip(
|
||||
model_config,
|
||||
vision_config,
|
||||
inputs,
|
||||
image_token_id=hf_config.image_token_index,
|
||||
image_feature_size_override=image_feature_size,
|
||||
)
|
||||
elif isinstance(vision_config, SiglipVisionConfig):
|
||||
return input_processor_for_siglip(
|
||||
model_config,
|
||||
vision_config,
|
||||
inputs,
|
||||
image_token_id=hf_config.image_token_index,
|
||||
image_feature_size_override=image_feature_size,
|
||||
)
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper()
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
|
||||
@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor)
|
||||
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsPP):
|
||||
|
||||
@@ -507,7 +410,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
def _process_image_pixels(
|
||||
self,
|
||||
inputs: LlavaNextImagePixelInputs,
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
|
||||
assert self.vision_tower is not None
|
||||
|
||||
pixel_values = inputs["data"]
|
||||
|
||||
@@ -34,7 +34,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalInputsV2, MultiModalKwargs,
|
||||
NestedTensors, PlaceholderRange)
|
||||
from vllm.multimodal.parse import ImageProcessorItems
|
||||
from vllm.multimodal.parse import ImageEmbeddingItems, ImageProcessorItems
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
MultiModalDataItems, ProcessorInputs,
|
||||
PromptReplacement,
|
||||
@@ -388,15 +388,19 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
|
||||
assert isinstance(bos_token_id, int)
|
||||
|
||||
def get_replacement_phi3v(item_idx: int):
|
||||
images = mm_items.get_items("image", ImageProcessorItems)
|
||||
image_size = images.get_image_size(item_idx)
|
||||
images = mm_items.get_items(
|
||||
"image", (ImageEmbeddingItems, ImageProcessorItems))
|
||||
|
||||
num_tokens = self._get_num_image_tokens(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
)
|
||||
if isinstance(images, ImageEmbeddingItems):
|
||||
num_image_tokens = images.get_feature_size(item_idx)
|
||||
else:
|
||||
image_size = images.get_image_size(item_idx)
|
||||
num_image_tokens = self._get_num_image_tokens(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
)
|
||||
|
||||
return [_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id]
|
||||
return [_IMAGE_TOKEN_ID] * num_image_tokens + [bos_token_id]
|
||||
|
||||
num_images = mm_items.get_count("image", strict=False)
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ from vllm.sequence import IntermediateTensors, SequenceData
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .utils import (init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
from .vision import VisionEncoderInfo
|
||||
|
||||
try:
|
||||
from xformers import ops as xops
|
||||
@@ -697,10 +698,18 @@ def get_pixtral_hf_patch_grid_length(*, image_size: int,
|
||||
return image_size // patch_size
|
||||
|
||||
|
||||
def get_pixtral_hf_num_patches(*, image_size: int, patch_size: int) -> int:
|
||||
grid_length = get_pixtral_hf_patch_grid_length(image_size=image_size,
|
||||
patch_size=patch_size)
|
||||
return grid_length * grid_length
|
||||
def get_pixtral_hf_image_feature_size(
|
||||
*,
|
||||
image_size: int,
|
||||
patch_size: int,
|
||||
) -> int:
|
||||
grid_length = get_pixtral_hf_patch_grid_length(
|
||||
image_size=image_size,
|
||||
patch_size=patch_size,
|
||||
)
|
||||
|
||||
# Consider the image_break_token
|
||||
return (grid_length + 1) * grid_length
|
||||
|
||||
|
||||
def get_max_pixtral_hf_image_tokens(hf_config: PixtralVisionConfig) -> int:
|
||||
@@ -730,13 +739,16 @@ def dummy_image_for_pixtral_hf(
|
||||
return {"image": image if num_images == 1 else [image] * num_images}
|
||||
|
||||
|
||||
def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig,
|
||||
image_width: int,
|
||||
image_height: int) -> Tuple[int, int]:
|
||||
# Adapted from transformers.models.pixtral.image_processing_pixtral.get_resize_output_image_size # noqa: E501
|
||||
# https://github.com/huggingface/transformers/blob/2bd4d5897dc73e8b172832070a6f9e567a0df017/src/transformers/models/pixtral/image_processing_pixtral.py#L180 # noqa: E501
|
||||
max_width, max_height = hf_config.image_size, hf_config.image_size
|
||||
patch_width, patch_height = hf_config.patch_size, hf_config.patch_size
|
||||
# Adapted from transformers.models.pixtral.image_processing_pixtral.get_resize_output_image_size # noqa: E501
|
||||
# https://github.com/huggingface/transformers/blob/2bd4d5897dc73e8b172832070a6f9e567a0df017/src/transformers/models/pixtral/image_processing_pixtral.py#L180
|
||||
def get_pixtral_hf_image_feature_grid_size(
|
||||
hf_config: PixtralVisionConfig,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> tuple[int, int]:
|
||||
max_width = max_height = hf_config.image_size
|
||||
patch_width = patch_height = hf_config.patch_size
|
||||
|
||||
ratio = max(image_width / max_width, image_height / max_height)
|
||||
|
||||
@@ -744,12 +756,38 @@ def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig,
|
||||
image_width = int(math.ceil(image_width / ratio))
|
||||
image_height = int(math.ceil(image_height / ratio))
|
||||
|
||||
num_height_tokens, num_width_tokens = _get_pixtral_hf_num_image_tokens(
|
||||
nrows, ncols = _get_pixtral_hf_num_image_tokens(
|
||||
(image_height, image_width),
|
||||
(patch_height, patch_width),
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
return num_width_tokens, num_height_tokens
|
||||
return ncols, nrows
|
||||
|
||||
|
||||
class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
|
||||
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> int:
|
||||
return get_pixtral_hf_image_feature_size(
|
||||
image_size=self.vision_config.image_size,
|
||||
patch_size=self.get_image_size(),
|
||||
)
|
||||
|
||||
def get_max_image_tokens(self) -> int:
|
||||
return get_max_pixtral_hf_image_tokens(self.vision_config)
|
||||
|
||||
def get_num_patches(self) -> int:
|
||||
return get_pixtral_hf_patch_grid_length(
|
||||
image_size=self.vision_config.image_size,
|
||||
patch_size=self.vision_config.patch_size,
|
||||
)
|
||||
|
||||
def get_image_size(self) -> int:
|
||||
return self.vision_config.image_size
|
||||
|
||||
|
||||
class PixtralHFMLP(nn.Module):
|
||||
|
||||
@@ -28,6 +28,8 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
resolve_visual_encoder_outputs)
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
from .vision import VisionEncoderInfo
|
||||
|
||||
|
||||
def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
||||
# Since interpolation is applied, the image size need not be divisible
|
||||
@@ -156,6 +158,29 @@ def input_processor_for_siglip(
|
||||
multi_modal_placeholders={"image": ranges})
|
||||
|
||||
|
||||
class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
|
||||
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> int:
|
||||
return get_siglip_image_feature_size(self.vision_config)
|
||||
|
||||
def get_max_image_tokens(self) -> int:
|
||||
return get_max_siglip_image_tokens(self.vision_config)
|
||||
|
||||
def get_num_patches(self) -> int:
|
||||
return get_siglip_patch_grid_length(
|
||||
image_size=self.vision_config.image_size,
|
||||
patch_size=self.vision_config.patch_size,
|
||||
)
|
||||
|
||||
def get_image_size(self) -> int:
|
||||
return self.vision_config.image_size
|
||||
|
||||
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
|
||||
class SiglipVisionEmbeddings(nn.Module):
|
||||
|
||||
|
||||
@@ -373,7 +373,7 @@ def embed_multimodal(
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_token_id: int,
|
||||
get_text_embeds: Callable[[torch.Tensor], torch.Tensor],
|
||||
multimodal_embeds: Union[torch.Tensor, List[torch.Tensor]],
|
||||
multimodal_embeds: NestedTensors,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Embed token IDs and multimodal inputs and combine their embeddings.
|
||||
|
||||
52
vllm/model_executor/models/vision.py
Normal file
52
vllm/model_executor/models/vision.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
_C = TypeVar("_C", bound=PretrainedConfig)
|
||||
|
||||
|
||||
class VisionEncoderInfo(ABC, Generic[_C]):
|
||||
|
||||
def __init__(self, vision_config: _C) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.vision_config = vision_config
|
||||
|
||||
@abstractmethod
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_max_image_tokens(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_num_patches(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_image_size(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def vision_encoder_info(vision_config: PretrainedConfig) -> VisionEncoderInfo:
|
||||
# Avoid circular imports
|
||||
from .clip import CLIPEncoderInfo, CLIPVisionConfig
|
||||
from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig
|
||||
from .siglip import SiglipEncoderInfo, SiglipVisionConfig
|
||||
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
return CLIPEncoderInfo(vision_config)
|
||||
if isinstance(vision_config, PixtralVisionConfig):
|
||||
return PixtralHFEncoderInfo(vision_config)
|
||||
if isinstance(vision_config, SiglipVisionConfig):
|
||||
return SiglipEncoderInfo(vision_config)
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
Reference in New Issue
Block a user