[V1][VLM] V1 support for selected single-image models. (#11632)
Signed-off-by: Roger Wang <ywang@roblox.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Isotr0py <2037008807@qq.com> Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -1,15 +1,15 @@
|
||||
import math
|
||||
from typing import Iterable, List, Optional, Set, Tuple, TypedDict, Union
|
||||
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
|
||||
Union)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.init import trunc_normal_
|
||||
from transformers import LlamaConfig
|
||||
from transformers import BatchFeature, PretrainedConfig
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm.inputs import INPUT_REGISTRY, token_inputs
|
||||
from vllm.inputs import InputContext
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
@@ -17,30 +17,27 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
get_compressed_tensors_cache_scale)
|
||||
from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput,
|
||||
SamplingMetadata)
|
||||
from vllm.model_executor.layers.sampler import (SamplerOutput,
|
||||
SamplingMetadata, get_sampler)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.models.idefics2_vision_model import (
|
||||
Idefics2VisionTransformer)
|
||||
from vllm.model_executor.models.interfaces import SupportsMultiModal
|
||||
from vllm.model_executor.models.llama import (LlamaDecoderLayer, LlamaMLP,
|
||||
LlamaModel)
|
||||
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
|
||||
is_pp_missing_parameter,
|
||||
maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.image import cached_get_image_processor
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
|
||||
from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
repeat_and_pad_placeholder_tokens)
|
||||
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
|
||||
NestedTensors)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
MultiModalDataItems, ProcessorInputs,
|
||||
PromptReplacement)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs.aria import (AriaMoELMConfig,
|
||||
AriaVisionConfig)
|
||||
|
||||
from .utils import flatten_bn
|
||||
from .idefics2_vision_model import Idefics2VisionTransformer
|
||||
from .interfaces import SupportsMultiModal
|
||||
from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
is_pp_missing_parameter, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
|
||||
|
||||
class AriaImagePixelInputs(TypedDict):
|
||||
@@ -251,7 +248,7 @@ class AriaProjector(nn.Module):
|
||||
class AriaFusedMoE(FusedMoE):
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
|
||||
shard_id: str) -> Set[str]:
|
||||
shard_id: str) -> None:
|
||||
# Override the weight_loader to handle the expert weights in the Aria
|
||||
# model, which are already packed with experts, and merge the gate and
|
||||
# up weights for each expert.
|
||||
@@ -346,7 +343,7 @@ class MoEDecoderLayer(LlamaDecoderLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
config: AriaMoELMConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
@@ -434,7 +431,7 @@ class AriaMoELMModel(LlamaModel):
|
||||
return loaded_params
|
||||
|
||||
|
||||
def build_mm_projector(config):
|
||||
def build_mm_projector(config: PretrainedConfig):
|
||||
return AriaProjector(
|
||||
patch_to_query_dict=config.projector_patch_to_query_dict,
|
||||
embed_dim=config.vision_config.hidden_size,
|
||||
@@ -445,75 +442,70 @@ def build_mm_projector(config):
|
||||
)
|
||||
|
||||
|
||||
def get_max_multimodal_tokens(ctx):
|
||||
return max(ctx.model_config.hf_config.image_size2tokens.values())
|
||||
def get_max_aria_image_tokens(ctx: InputContext):
|
||||
hf_config = ctx.get_hf_config()
|
||||
return max(hf_config.projector_patch_to_query_dict.values())
|
||||
|
||||
|
||||
def input_mapper_for_aria(ctx, data):
|
||||
return MultiModalKwargs(data)
|
||||
class AriaMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
|
||||
def input_processor(ctx, llm_inputs):
|
||||
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||
# if it is pure text input, use it as is
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return llm_inputs
|
||||
|
||||
model_config = ctx.model_config
|
||||
|
||||
tokenizer = cached_get_tokenizer(model_config.tokenizer)
|
||||
image_processor = cached_get_image_processor(
|
||||
model_config.model, trust_remote_code=model_config.trust_remote_code)
|
||||
hf_config = model_config.hf_config
|
||||
|
||||
# prepare image tokens, the max_image_size is used to determine the number
|
||||
# of patch_size for every image
|
||||
max_image_size = multi_modal_data.pop("max_image_size", 980)
|
||||
_split_image = multi_modal_data.pop("split_image", False)
|
||||
|
||||
assert isinstance(max_image_size,
|
||||
(int, float)), "max_image_size should be float or int"
|
||||
images = (multi_modal_data["image"] if isinstance(
|
||||
multi_modal_data["image"], list) else [multi_modal_data["image"]])
|
||||
|
||||
image_inputs = image_processor.preprocess(images,
|
||||
max_image_size=max_image_size,
|
||||
split_image=_split_image,
|
||||
return_tensors="pt").data
|
||||
image_inputs['pixel_values'] = image_inputs['pixel_values'].to(
|
||||
ctx.model_config.dtype)
|
||||
num_crops = image_inputs.pop("num_crops")
|
||||
|
||||
prompt_token_ids = llm_inputs["prompt_token_ids"]
|
||||
if num_crops.sum().item() > 0:
|
||||
_, prompt_token_ids, _ = repeat_and_pad_placeholder_tokens(
|
||||
tokenizer,
|
||||
None,
|
||||
prompt_token_ids,
|
||||
placeholder_token_id=hf_config.image_token_index,
|
||||
repeat_count=num_crops,
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
pixel_mask=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
repeat_count = [hf_config.image_size2tokens[max_image_size]
|
||||
] * sum(num_crops).item()
|
||||
new_prompt, new_token_ids, _ = repeat_and_pad_placeholder_tokens(
|
||||
tokenizer,
|
||||
None,
|
||||
prompt_token_ids,
|
||||
placeholder_token_id=hf_config.image_token_index,
|
||||
repeat_count=repeat_count,
|
||||
)
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_config = self.ctx.get_hf_config()
|
||||
image_token_id = hf_config.image_token_index
|
||||
|
||||
return token_inputs(
|
||||
prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data={"image": image_inputs},
|
||||
)
|
||||
max_image_tokens = get_max_aria_image_tokens(self.ctx)
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=[image_token_id],
|
||||
replacement=[image_token_id] * max_image_tokens,
|
||||
)
|
||||
]
|
||||
|
||||
def _get_dummy_mm_inputs(
|
||||
self,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
hf_config = self.ctx.get_hf_config()
|
||||
vision_config: AriaVisionConfig = hf_config.vision_config
|
||||
|
||||
max_image_size = vision_config.image_size
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
mm_data = {
|
||||
"image":
|
||||
self._get_dummy_images(width=max_image_size,
|
||||
height=max_image_size,
|
||||
num_images=num_images)
|
||||
}
|
||||
|
||||
hf_processor = self._get_hf_processor()
|
||||
image_token: str = hf_processor.image_token # type: ignore
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt_text=image_token * num_images,
|
||||
mm_data=mm_data,
|
||||
)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_multimodal_tokens)
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_aria)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor)
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_aria_image_tokens)
|
||||
@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor)
|
||||
class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
"""
|
||||
Aria model for conditional generation tasks.
|
||||
@@ -540,12 +532,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
# prepare the image_size to tokens mapping for the image preprocess, see
|
||||
# input_processor
|
||||
config.image_size2tokens = {
|
||||
int(math.sqrt(k) * config.vision_config.patch_size): v
|
||||
for k, v in config.projector_patch_to_query_dict.items()
|
||||
}
|
||||
self.config = config
|
||||
self.vision_tower = AriaVisionModel(config.vision_config)
|
||||
self.multi_modal_projector = build_mm_projector(config)
|
||||
@@ -566,7 +552,7 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
self.vocab_size, logit_scale)
|
||||
self.sampler = Sampler()
|
||||
self.sampler = get_sampler()
|
||||
|
||||
def _validate_image_sizes(
|
||||
self, images: List[torch.Tensor]) -> List[torch.Tensor]:
|
||||
@@ -588,7 +574,12 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
|
||||
pixel_values = self._validate_image_sizes(pixel_values)
|
||||
pixel_values = flatten_bn(pixel_values, concat=True)
|
||||
|
||||
if pixel_mask is not None:
|
||||
if not isinstance(pixel_mask, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel mask. "
|
||||
f"Got type: {type(pixel_mask)}")
|
||||
|
||||
pixel_mask = flatten_bn(pixel_mask, concat=True)
|
||||
|
||||
return AriaImagePixelInputs(
|
||||
|
||||
@@ -4,22 +4,16 @@ from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
from transformers import Blip2VisionConfig, BlipVisionConfig
|
||||
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.inputs import DecoderOnlyInputs, token_inputs
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
repeat_and_pad_placeholder_tokens)
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
|
||||
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
||||
@@ -33,92 +27,6 @@ def get_blip_num_patches(*, image_size: int, patch_size: int) -> int:
|
||||
return grid_length * grid_length
|
||||
|
||||
|
||||
def get_blip_image_feature_size(
|
||||
hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int:
|
||||
return get_blip_num_patches(image_size=hf_config.image_size,
|
||||
patch_size=hf_config.patch_size)
|
||||
|
||||
|
||||
def get_max_blip_image_tokens(
|
||||
hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int:
|
||||
return get_blip_image_feature_size(hf_config)
|
||||
|
||||
|
||||
def dummy_seq_data_for_blip(
|
||||
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
|
||||
seq_len: int,
|
||||
num_images: int,
|
||||
*,
|
||||
image_token_id: int,
|
||||
image_feature_size_override: Optional[int] = None,
|
||||
):
|
||||
if image_feature_size_override is None:
|
||||
image_feature_size = get_blip_image_feature_size(hf_config)
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
return SequenceData.from_prompt_token_counts(
|
||||
(image_token_id, image_feature_size * num_images),
|
||||
(0, seq_len - image_feature_size * num_images),
|
||||
)
|
||||
|
||||
|
||||
def dummy_image_for_blip(
|
||||
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
|
||||
num_images: int,
|
||||
*,
|
||||
image_width_override: Optional[int] = None,
|
||||
image_height_override: Optional[int] = None,
|
||||
):
|
||||
width = height = hf_config.image_size
|
||||
if image_width_override is not None:
|
||||
width = image_width_override
|
||||
if image_height_override is not None:
|
||||
height = image_height_override
|
||||
|
||||
image = Image.new("RGB", (width, height), color=0)
|
||||
return {"image": image if num_images == 1 else [image] * num_images}
|
||||
|
||||
|
||||
def input_processor_for_blip(
|
||||
model_config: ModelConfig,
|
||||
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
|
||||
inputs: DecoderOnlyInputs,
|
||||
*,
|
||||
image_token_id: int,
|
||||
image_feature_size_override: Optional[int] = None,
|
||||
):
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return inputs
|
||||
|
||||
if "multi_modal_placeholders" in inputs and "image" in inputs[
|
||||
"multi_modal_placeholders"]:
|
||||
# The inputs already have placeholders.
|
||||
return inputs
|
||||
|
||||
tokenizer = cached_get_tokenizer(model_config.tokenizer)
|
||||
|
||||
if image_feature_size_override is None:
|
||||
image_feature_size = get_blip_image_feature_size(hf_config)
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
|
||||
tokenizer,
|
||||
inputs.get("prompt"),
|
||||
inputs["prompt_token_ids"],
|
||||
placeholder_token_id=image_token_id,
|
||||
repeat_count=image_feature_size,
|
||||
)
|
||||
|
||||
# NOTE: Create a defensive copy of the original inputs
|
||||
return token_inputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data,
|
||||
multi_modal_placeholders={"image": ranges})
|
||||
|
||||
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
|
||||
class BlipVisionEmbeddings(nn.Module):
|
||||
|
||||
|
||||
@@ -4,32 +4,33 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
|
||||
apply_chunking_to_forward)
|
||||
from transformers import (BatchFeature, Blip2Config, Blip2Processor,
|
||||
Blip2QFormerConfig, apply_chunking_to_forward)
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||
InputContext, token_inputs)
|
||||
from vllm.inputs import InputContext
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import NestedTensors
|
||||
from vllm.multimodal.utils import consecutive_placeholder_ranges
|
||||
from vllm.sequence import IntermediateTensors, SequenceData
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalInputsV2, MultiModalKwargs,
|
||||
NestedTensors, PlaceholderRange)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
MultiModalDataItems, ProcessorInputs,
|
||||
PromptReplacement)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .blip import (BlipVisionModel, dummy_image_for_blip,
|
||||
get_max_blip_image_tokens)
|
||||
from .blip import BlipVisionModel
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
|
||||
# We use this internally as placeholders since there is no image token
|
||||
# defined on the HuggingFace repo
|
||||
BLIP2_IMAGE_TOKEN = "<image>"
|
||||
BLIP2_IMAGE_TOKEN_ID = 50265
|
||||
_IMAGE_TOKEN_ID = 50265
|
||||
|
||||
|
||||
class Blip2ImagePixelInputs(TypedDict):
|
||||
@@ -396,92 +397,87 @@ class Blip2QFormerModel(nn.Module):
|
||||
return sequence_output
|
||||
|
||||
|
||||
def get_blip2_image_feature_size(hf_config: Blip2Config) -> int:
|
||||
def get_max_blip2_image_tokens(ctx: InputContext):
|
||||
hf_config = ctx.get_hf_config(Blip2Config)
|
||||
return hf_config.num_query_tokens
|
||||
|
||||
|
||||
def get_max_blip2_image_tokens(ctx: InputContext):
|
||||
hf_config = ctx.get_hf_config(Blip2Config)
|
||||
vision_config = hf_config.vision_config
|
||||
class Blip2MultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
if isinstance(vision_config, Blip2VisionConfig):
|
||||
return get_max_blip_image_tokens(vision_config)
|
||||
def _get_hf_processor(self) -> Blip2Processor:
|
||||
return self.ctx.get_hf_processor(Blip2Processor)
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
max_image_tokens = get_max_blip2_image_tokens(self.ctx)
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target="</s>",
|
||||
replacement="<image>" * max_image_tokens + "</s>",
|
||||
)
|
||||
]
|
||||
|
||||
def apply(
|
||||
self,
|
||||
prompt_text: str,
|
||||
mm_data: MultiModalDataDict,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> MultiModalInputsV2:
|
||||
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
|
||||
|
||||
# Only <image> tokens should be considered as placeholders,
|
||||
# so we ignore the trailing bos_token
|
||||
result["mm_placeholders"] = {
|
||||
modality: [
|
||||
PlaceholderRange(offset=p["offset"], length=p["length"] - 1)
|
||||
for p in ps
|
||||
]
|
||||
for modality, ps in result["mm_placeholders"].items()
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def _get_dummy_mm_inputs(
|
||||
self,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
hf_config = self.ctx.get_hf_config(Blip2Config)
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
max_image_size = vision_config.image_size
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
mm_data = {
|
||||
"image":
|
||||
self._get_dummy_images(width=max_image_size,
|
||||
height=max_image_size,
|
||||
num_images=num_images)
|
||||
}
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt_text="",
|
||||
mm_data=mm_data,
|
||||
)
|
||||
|
||||
|
||||
def dummy_seq_data_for_blip2(
|
||||
hf_config: Blip2Config,
|
||||
seq_len: int,
|
||||
num_images: int,
|
||||
*,
|
||||
image_token_id: int,
|
||||
image_feature_size_override: Optional[int] = None,
|
||||
):
|
||||
if image_feature_size_override is None:
|
||||
image_feature_size = get_blip2_image_feature_size(hf_config)
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
return SequenceData.from_prompt_token_counts(
|
||||
(image_token_id, image_feature_size * num_images),
|
||||
(0, seq_len - image_feature_size * num_images),
|
||||
), {
|
||||
"image":
|
||||
consecutive_placeholder_ranges(num_items=num_images,
|
||||
item_size=image_feature_size)
|
||||
}
|
||||
|
||||
|
||||
def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
|
||||
mm_counts: Mapping[str, int]):
|
||||
hf_config = ctx.get_hf_config(Blip2Config)
|
||||
vision_config = hf_config.vision_config
|
||||
num_images = mm_counts["image"]
|
||||
|
||||
seq_data, ranges = dummy_seq_data_for_blip2(
|
||||
hf_config,
|
||||
seq_len,
|
||||
num_images,
|
||||
image_token_id=BLIP2_IMAGE_TOKEN_ID,
|
||||
)
|
||||
|
||||
if isinstance(vision_config, Blip2VisionConfig):
|
||||
mm_data = dummy_image_for_blip(vision_config, num_images)
|
||||
|
||||
return DummyData(seq_data, mm_data, ranges)
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs):
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return inputs
|
||||
|
||||
hf_config = ctx.get_hf_config(Blip2Config)
|
||||
image_feature_size = get_blip2_image_feature_size(hf_config)
|
||||
|
||||
# The original model places image tokens at the front
|
||||
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514
|
||||
new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
|
||||
new_token_ids += inputs["prompt_token_ids"]
|
||||
|
||||
new_prompt = inputs.get("prompt")
|
||||
if new_prompt is not None:
|
||||
new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt
|
||||
|
||||
return token_inputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper()
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
|
||||
@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor)
|
||||
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
@@ -627,7 +623,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
if multimodal_embeddings is not None:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, multimodal_embeddings,
|
||||
BLIP2_IMAGE_TOKEN_ID)
|
||||
_IMAGE_TOKEN_ID)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -3,16 +3,15 @@ from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
|
||||
Tuple, TypedDict, Union)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
from torch import nn
|
||||
from transformers import ChameleonConfig, ChameleonVQVAEConfig
|
||||
from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor,
|
||||
ChameleonVQVAEConfig)
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||
InputContext, token_inputs)
|
||||
from vllm.inputs import InputContext
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
@@ -29,11 +28,13 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import NestedTensors
|
||||
from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
consecutive_placeholder_ranges,
|
||||
repeat_and_pad_placeholder_tokens)
|
||||
from vllm.sequence import IntermediateTensors, SequenceData
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalInputsV2, MultiModalKwargs,
|
||||
NestedTensors, PlaceholderRange)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
MultiModalDataItems, ProcessorInputs,
|
||||
PromptReplacement)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import print_warning_once
|
||||
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
@@ -45,10 +46,6 @@ from .utils import (is_pp_missing_parameter,
|
||||
# and processor files, so we hardcode them in the model file for now.
|
||||
CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512
|
||||
CHAMELEON_IMAGE_SEQ_LENGTH = 1024
|
||||
CHAMELEON_IMAGE_TOKEN_ID = 8711
|
||||
CHAMELEON_IMAGE_START_TOKEN_ID = 8197
|
||||
CHAMELEON_IMAGE_END_TOKEN_ID = 8196
|
||||
CHAMELEON_SEP_TOKEN_ID = 8710
|
||||
|
||||
|
||||
class ChameleonImagePixelInputs(TypedDict):
|
||||
@@ -61,99 +58,75 @@ def get_max_chameleon_image_tokens(ctx: InputContext):
|
||||
return CHAMELEON_IMAGE_SEQ_LENGTH
|
||||
|
||||
|
||||
def dummy_seq_data_for_chameleon(
|
||||
seq_len: int,
|
||||
num_images: int,
|
||||
*,
|
||||
image_token_id: int,
|
||||
image_feature_size_override: Optional[int] = None,
|
||||
):
|
||||
if image_feature_size_override is None:
|
||||
image_feature_size = CHAMELEON_IMAGE_SEQ_LENGTH
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
return SequenceData.from_prompt_token_counts(
|
||||
(image_token_id, image_feature_size * num_images),
|
||||
(0, seq_len - image_feature_size * num_images),
|
||||
), {
|
||||
"image":
|
||||
consecutive_placeholder_ranges(num_items=num_images,
|
||||
item_size=image_feature_size)
|
||||
}
|
||||
def _get_hf_processor(self) -> ChameleonProcessor:
|
||||
return self.ctx.get_hf_processor(ChameleonProcessor)
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(pixel_values=MultiModalFieldConfig.batched("image"))
|
||||
|
||||
def dummy_image_for_chameleon(
|
||||
num_images: int,
|
||||
*,
|
||||
image_width_override: Optional[int] = None,
|
||||
image_height_override: Optional[int] = None,
|
||||
):
|
||||
width = CHAMELEON_CROP_SIZE_WIDTH
|
||||
height = CHAMELEON_CROP_SIZE_HEIGHT
|
||||
if image_width_override is not None:
|
||||
width = image_width_override
|
||||
if image_height_override is not None:
|
||||
height = image_height_override
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
processor = self._get_hf_processor()
|
||||
|
||||
image = Image.new("RGB", (width, height), color=0)
|
||||
return {"image": image if num_images == 1 else [image] * num_images}
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target="<image>",
|
||||
replacement="".join([
|
||||
processor.image_start_token,
|
||||
processor.image_token * CHAMELEON_IMAGE_SEQ_LENGTH,
|
||||
processor.image_end_token,
|
||||
]),
|
||||
)
|
||||
]
|
||||
|
||||
def _get_dummy_mm_inputs(
|
||||
self,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
def dummy_data_for_chameleon(ctx: InputContext, seq_len: int,
|
||||
mm_counts: Mapping[str, int]):
|
||||
num_images = mm_counts["image"]
|
||||
mm_data = {
|
||||
"image":
|
||||
self._get_dummy_images(width=CHAMELEON_CROP_SIZE_WIDTH,
|
||||
height=CHAMELEON_CROP_SIZE_HEIGHT,
|
||||
num_images=num_images)
|
||||
}
|
||||
|
||||
seq_data, ranges = dummy_seq_data_for_chameleon(
|
||||
seq_len,
|
||||
num_images,
|
||||
image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
|
||||
)
|
||||
return ProcessorInputs(
|
||||
prompt_text="<image>" * num_images,
|
||||
mm_data=mm_data,
|
||||
)
|
||||
|
||||
mm_data = dummy_image_for_chameleon(num_images)
|
||||
return DummyData(seq_data, mm_data, ranges)
|
||||
def apply(
|
||||
self,
|
||||
prompt_text: str,
|
||||
mm_data: MultiModalDataDict,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> MultiModalInputsV2:
|
||||
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
|
||||
|
||||
# Only <image> tokens should be considered as placeholders,
|
||||
# so we ignore the image_start_token and image_end_token
|
||||
result["mm_placeholders"] = {
|
||||
modality: [
|
||||
PlaceholderRange(offset=p["offset"] + 1,
|
||||
length=p["length"] - 2) for p in ps
|
||||
]
|
||||
for modality, ps in result["mm_placeholders"].items()
|
||||
}
|
||||
|
||||
def input_processor_for_chameleon(ctx: InputContext,
|
||||
inputs: DecoderOnlyInputs):
|
||||
|
||||
"""
|
||||
Processing input prompt to insert required tokens for image placeholder.
|
||||
|
||||
See https://github.com/huggingface/transformers/blob/0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf/src/transformers/models/chameleon/processing_chameleon.py#L58
|
||||
""" # noqa
|
||||
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return inputs
|
||||
|
||||
if "multi_modal_placeholders" in inputs and "image" in inputs[
|
||||
"multi_modal_placeholders"]:
|
||||
# The inputs already have placeholders.
|
||||
return inputs
|
||||
|
||||
model_config = ctx.model_config
|
||||
tokenizer = cached_get_tokenizer(model_config.tokenizer)
|
||||
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
|
||||
tokenizer,
|
||||
inputs.get("prompt"),
|
||||
inputs["prompt_token_ids"],
|
||||
placeholder_token_id=CHAMELEON_IMAGE_TOKEN_ID,
|
||||
repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH,
|
||||
pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID,
|
||||
pad_token_right=CHAMELEON_IMAGE_END_TOKEN_ID,
|
||||
)
|
||||
|
||||
# Appending sep token for chat mode to follow default processor
|
||||
# behavior
|
||||
if new_prompt is not None:
|
||||
new_prompt += tokenizer.sep_token
|
||||
new_token_ids += [CHAMELEON_SEP_TOKEN_ID]
|
||||
|
||||
# NOTE: Create a defensive copy of the original inputs
|
||||
return token_inputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data)
|
||||
return result
|
||||
|
||||
|
||||
class ChameleonLayerNorm(nn.LayerNorm):
|
||||
@@ -736,7 +709,7 @@ class ChameleonVQVAEEncoder(nn.Module):
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
hidden_state = self.down[i_level].block[i_block](
|
||||
hidden_states[-1], )
|
||||
hidden_states[-1])
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
hidden_state = self.down[i_level].attn[i_block](
|
||||
hidden_state)
|
||||
@@ -925,10 +898,8 @@ class ChameleonModel(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper()
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon)
|
||||
@MULTIMODAL_REGISTRY.register_processor(ChameleonMultiModalProcessor)
|
||||
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsPP):
|
||||
|
||||
|
||||
@@ -15,32 +15,30 @@
|
||||
# limitations under the License.
|
||||
""" PyTorch Fuyu model."""
|
||||
import math
|
||||
from array import array
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
||||
TypedDict)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
from PIL import Image
|
||||
from transformers import FuyuImageProcessor
|
||||
from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,
|
||||
FuyuProcessor)
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||
InputContext, token_inputs)
|
||||
from vllm.inputs import InputContext
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||
from vllm.multimodal.image import cached_get_image_processor
|
||||
from vllm.multimodal.inputs import NestedTensors
|
||||
from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
consecutive_placeholder_ranges)
|
||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
||||
SequenceData)
|
||||
from vllm.utils import is_list_of
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalInputsV2, MultiModalKwargs,
|
||||
NestedTensors, PlaceholderRange)
|
||||
from vllm.multimodal.parse import ImageProcessorItems
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
MultiModalDataItems, ProcessorInputs,
|
||||
PromptReplacement)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
|
||||
@@ -54,178 +52,193 @@ MAX_IMAGE_FEATURE_SIZE_HEIGHT = 1080
|
||||
MAX_IMAGE_FEATURE_SIZE_WIDTH = 1920
|
||||
|
||||
|
||||
class FuyuImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
class FuyuImagePatchInputs(TypedDict):
|
||||
type: Literal["image_patches"]
|
||||
data: torch.Tensor
|
||||
"""
|
||||
Shape:
|
||||
(batch_size, num_patches, patch_size_x * patch_size_y * num_channels)
|
||||
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
|
||||
"""
|
||||
|
||||
patches_per_image: List[int]
|
||||
"""
|
||||
List of number of total patches for each image in the batch.
|
||||
This is used to restore the first two dimensions of `data`.
|
||||
"""
|
||||
|
||||
|
||||
def _calculate_num_image_tokens(
|
||||
height: int,
|
||||
width: int,
|
||||
def _get_fuyu_num_image_tokens(
|
||||
image_height: int,
|
||||
image_width: int,
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
calculate number of image tokens needed for a given image size
|
||||
The expected Fuyu image prompts is in format:
|
||||
Calculate the number of image tokens needed for a given image size.
|
||||
|
||||
The expected Fuyu image prompts can be expressed as:
|
||||
|
||||
.. code-block::
|
||||
(image_token * ncols + newline_token) * nrows
|
||||
args:
|
||||
image_size: Tuple[int, int] - (width, height) of the image
|
||||
returns:
|
||||
ncols: int - number of image tokens in x direction
|
||||
nrows: int - number of image tokens in y direction
|
||||
|
||||
Args:
|
||||
image_size: Tuple[int, int] - `(width, height)` of the image
|
||||
|
||||
Returns:
|
||||
ncols: int - number of image tokens in `x` direction
|
||||
nrows: int - number of image tokens in `y` direction
|
||||
"""
|
||||
ncol = math.ceil(width / 30)
|
||||
nrow = math.ceil(height / 30)
|
||||
return ncol, nrow
|
||||
|
||||
|
||||
def get_max_fuyu_image_feature_size():
|
||||
|
||||
return _calculate_num_image_tokens(
|
||||
height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
|
||||
width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
|
||||
)
|
||||
ncols = math.ceil(image_width / 30)
|
||||
nrows = math.ceil(image_height / 30)
|
||||
return ncols, nrows
|
||||
|
||||
|
||||
def get_max_fuyu_image_tokens(ctx: InputContext):
|
||||
ncol, nrow = get_max_fuyu_image_feature_size()
|
||||
return (ncol + 1) * nrow
|
||||
|
||||
|
||||
def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int):
|
||||
ncol, nrow = get_max_fuyu_image_feature_size()
|
||||
image_feature_size = get_max_fuyu_image_tokens(ctx)
|
||||
|
||||
image_token_ids = (
|
||||
array(VLLM_TOKEN_ID_ARRAY_TYPE, [_IMAGE_TOKEN_ID]) * ncol +
|
||||
array(VLLM_TOKEN_ID_ARRAY_TYPE, [_NEWLINE_TOKEN_ID])) * nrow
|
||||
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, image_token_ids) * num_images
|
||||
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[0]) * (seq_len - image_feature_size * num_images)
|
||||
return SequenceData(token_ids), {
|
||||
"image":
|
||||
consecutive_placeholder_ranges(num_items=num_images,
|
||||
item_size=image_feature_size)
|
||||
}
|
||||
|
||||
|
||||
def dummy_image_for_fuyu(
|
||||
num_images: int,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
):
|
||||
image = Image.new("RGB", (image_width, image_height), color=0)
|
||||
return {"image": image if num_images == 1 else [image] * num_images}
|
||||
|
||||
|
||||
def dummy_data_for_fuyu(ctx: InputContext, seq_len: int,
|
||||
mm_counts: Mapping[str, int]):
|
||||
num_images = mm_counts["image"]
|
||||
seq_data, ranges = dummy_seq_data_for_fuyu(ctx, seq_len, num_images)
|
||||
mm_data = dummy_image_for_fuyu(num_images,
|
||||
image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
|
||||
image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT)
|
||||
return DummyData(seq_data, mm_data, ranges)
|
||||
|
||||
|
||||
def _fuyu_image_preprocess(image_processor: FuyuImageProcessor,
|
||||
data: List[Image.Image]):
|
||||
image_encoding = image_processor.preprocess(data, return_tensors="pt")
|
||||
batch_images = torch.stack([img[0] for img in image_encoding["images"]
|
||||
]).unsqueeze(1)
|
||||
image_unpadded_heights = torch.tensor(
|
||||
image_encoding["image_unpadded_heights"])
|
||||
image_unpadded_widths = torch.tensor(
|
||||
image_encoding["image_unpadded_widths"])
|
||||
|
||||
batch_size = len(image_encoding["images"])
|
||||
image_present = torch.ones(batch_size, 1, 1)
|
||||
model_image_input = image_processor.preprocess_with_tokenizer_info(
|
||||
image_input=batch_images,
|
||||
image_present=image_present,
|
||||
image_unpadded_h=image_unpadded_heights,
|
||||
image_unpadded_w=image_unpadded_widths,
|
||||
image_placeholder_id=_IMAGE_TOKEN_ID,
|
||||
image_newline_id=_NEWLINE_TOKEN_ID,
|
||||
variable_sized=True,
|
||||
ncols, nrows = _get_fuyu_num_image_tokens(
|
||||
image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
|
||||
image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
|
||||
)
|
||||
return model_image_input
|
||||
|
||||
return (ncols + 1) * nrows
|
||||
|
||||
|
||||
def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs):
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return inputs
|
||||
class FuyuMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
model_config = ctx.model_config
|
||||
image_data = multi_modal_data["image"]
|
||||
new_multi_modal_data = {}
|
||||
image_list = image_data if isinstance(image_data, list) else [image_data]
|
||||
def _get_hf_processor(self) -> FuyuProcessor:
|
||||
return self.ctx.get_hf_processor(FuyuProcessor)
|
||||
|
||||
# process image data
|
||||
if is_list_of(image_list, Image.Image):
|
||||
# Fuyu's image_processor can also finish token padding
|
||||
image_processor: FuyuImageProcessor = cached_get_image_processor(
|
||||
model_config.model)
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
|
||||
model_image_input = _fuyu_image_preprocess(image_processor, image_data)
|
||||
image_patches = torch.cat([
|
||||
image_patch[0]
|
||||
for image_patch in model_image_input["image_patches"]
|
||||
])
|
||||
new_multi_modal_data["image"] = image_patches
|
||||
if not mm_data:
|
||||
# Avoid warning from HF logger for text-only input
|
||||
# Input_ids format: bos_token_id + prompt_token_ids + boa_token_id
|
||||
# Tokenizer won't add boa_token_id by default, we add it manually.
|
||||
tokenizer = self._get_tokenizer()
|
||||
boa_token_id: int = tokenizer.vocab["<0x04>"] # type: ignore
|
||||
prompt_ids = tokenizer.encode(prompt) + [boa_token_id]
|
||||
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
|
||||
|
||||
elif is_list_of(image_list, torch.Tensor):
|
||||
raise NotImplementedError("Embeddings input is not supported yet")
|
||||
else:
|
||||
raise TypeError(f"Invalid image type: {type(image_data)}")
|
||||
processed_outputs = super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
mm_kwargs=mm_kwargs,
|
||||
)
|
||||
|
||||
# process prompts
|
||||
prompt = inputs.get("prompt")
|
||||
prompt_token_ids = inputs["prompt_token_ids"]
|
||||
tokenizer = cached_get_tokenizer(model_config.model)
|
||||
# dim0 is batch_size, dim1 is subseq_size which will always be 1
|
||||
image_input_ids: List[List[
|
||||
torch.Tensor]] = model_image_input["image_input_ids"]
|
||||
image_input_ids = image_input_ids[0][0].tolist()
|
||||
bos_token = tokenizer.encode("<s>", add_special_tokens=False)[1:]
|
||||
boa_token = tokenizer.encode("\x04", add_special_tokens=False)[1:]
|
||||
image_patches = processed_outputs.get("image_patches")
|
||||
if image_patches is not None:
|
||||
images = mm_data["images"]
|
||||
assert isinstance(images, list)
|
||||
|
||||
new_prompt = prompt + "\x04"
|
||||
new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[
|
||||
1:] + boa_token
|
||||
# Original output: (1, num_images, Pn, Px * Py * C)
|
||||
# New output: (num_images, Pn, Px * Py * C)
|
||||
assert (isinstance(image_patches, list)
|
||||
and len(image_patches) == 1)
|
||||
assert (isinstance(image_patches[0], torch.Tensor)
|
||||
and len(image_patches[0]) == len(images))
|
||||
|
||||
return token_inputs(prompt=new_prompt,
|
||||
prompt_token_ids=new_prompt_token_ids,
|
||||
multi_modal_data=new_multi_modal_data)
|
||||
processed_outputs["image_patches"] = image_patches[0]
|
||||
|
||||
return processed_outputs
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(image_patches=MultiModalFieldConfig.batched("image"))
|
||||
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_config = self.ctx.get_hf_config(FuyuConfig)
|
||||
bos_token_id = hf_config.bos_token_id
|
||||
|
||||
tokenizer = self._get_tokenizer()
|
||||
eot_token_id = tokenizer.bos_token_id
|
||||
assert isinstance(eot_token_id, int)
|
||||
|
||||
hf_processor = self._get_hf_processor()
|
||||
image_processor: FuyuImageProcessor = hf_processor.image_processor
|
||||
target_size = image_processor.size
|
||||
target_height, target_width = (target_size["height"],
|
||||
target_size["width"])
|
||||
|
||||
def get_replacement_fuyu(item_idx: int):
|
||||
images = mm_items.get_items("image", ImageProcessorItems)
|
||||
image_size = images.get_image_size(item_idx)
|
||||
width, height = image_size.width, image_size.height
|
||||
if not (width <= target_width and height <= target_height):
|
||||
height_scale_factor = target_height / height
|
||||
width_scale_factor = target_width / width
|
||||
optimal_scale_factor = min(height_scale_factor,
|
||||
width_scale_factor)
|
||||
|
||||
height = int(height * optimal_scale_factor)
|
||||
width = int(width * optimal_scale_factor)
|
||||
|
||||
ncols, nrows = _get_fuyu_num_image_tokens(
|
||||
image_width=width,
|
||||
image_height=height,
|
||||
)
|
||||
|
||||
return (([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows +
|
||||
[bos_token_id])
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=[eot_token_id],
|
||||
replacement=get_replacement_fuyu,
|
||||
)
|
||||
]
|
||||
|
||||
def apply(
|
||||
self,
|
||||
prompt_text: str,
|
||||
mm_data: MultiModalDataDict,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> MultiModalInputsV2:
|
||||
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
|
||||
|
||||
# Only |SPEAKER| (image) tokens should be considered as placeholders,
|
||||
# so we ignore the trailing bos_token_id
|
||||
result["mm_placeholders"] = {
|
||||
modality: [
|
||||
PlaceholderRange(offset=p["offset"], length=p["length"] - 1)
|
||||
for p in ps
|
||||
]
|
||||
for modality, ps in result["mm_placeholders"].items()
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def _get_dummy_mm_inputs(
|
||||
self,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
mm_data = {
|
||||
"image":
|
||||
self._get_dummy_images(width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
|
||||
height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
|
||||
num_images=num_images)
|
||||
}
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt_text="",
|
||||
mm_data=mm_data,
|
||||
)
|
||||
|
||||
|
||||
def input_mapper_for_fuyu(ctx: InputContext, data: object):
|
||||
model_config = ctx.model_config
|
||||
data_list = data if isinstance(data, list) else [data]
|
||||
if is_list_of(data_list, Image.Image):
|
||||
# Fuyu's image_processor can also finish token padding
|
||||
image_processor: FuyuImageProcessor = cached_get_image_processor(
|
||||
model_config.model)
|
||||
|
||||
model_image_input = _fuyu_image_preprocess(image_processor, data_list)
|
||||
data = torch.stack([
|
||||
image_patch[0]
|
||||
for image_patch in model_image_input["image_patches"]
|
||||
])
|
||||
|
||||
# image has been processed with prompt in input processor
|
||||
return MultiModalKwargs({"pixel_values": data})
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu)
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
|
||||
@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor)
|
||||
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
@@ -280,28 +293,32 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
return data.to(self.vision_embed_tokens.weight.dtype)
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[FuyuImagePixelInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
|
||||
if pixel_values is not None:
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
|
||||
image_patches = kwargs.pop("image_patches", None)
|
||||
if image_patches is not None:
|
||||
if not isinstance(image_patches, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image patches. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
f"Got type: {type(image_patches)}")
|
||||
|
||||
return FuyuImagePixelInputs(
|
||||
type="pixel_values",
|
||||
image_patches_flat = flatten_bn(image_patches)
|
||||
|
||||
return FuyuImagePatchInputs(
|
||||
type="image_patches",
|
||||
data=self._validate_pixel_values(
|
||||
flatten_bn(pixel_values, concat=True)),
|
||||
flatten_bn(image_patches_flat, concat=True)),
|
||||
patches_per_image=[x.size(0) for x in image_patches_flat],
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def _process_image_input(
|
||||
self, image_input: FuyuImagePixelInputs) -> torch.Tensor:
|
||||
self, image_input: FuyuImagePatchInputs) -> NestedTensors:
|
||||
image_patches = image_input["data"]
|
||||
patches_per_image = image_input["patches_per_image"]
|
||||
|
||||
assert self.vision_embed_tokens is not None
|
||||
vision_embeddings, _ = self.vision_embed_tokens(image_input["data"])
|
||||
return vision_embeddings
|
||||
vision_embeddings, _ = self.vision_embed_tokens(image_patches)
|
||||
return vision_embeddings.split(patches_per_image, dim=0)
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
|
||||
@@ -69,7 +69,8 @@ class Idefics2VisionEmbeddings(nn.Module):
|
||||
patch_attention_mask: torch.BoolTensor,
|
||||
tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor:
|
||||
batch_size, _, max_im_h, max_im_w = pixel_values.shape
|
||||
patch_embeds = self.patch_embedding(pixel_values)
|
||||
target_dtype = self.patch_embedding.weight.dtype
|
||||
patch_embeds = self.patch_embedding(pixel_values.to(target_dtype))
|
||||
embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
||||
max_nb_patches_h, max_nb_patches_w = (
|
||||
max_im_h // self.patch_size,
|
||||
@@ -309,7 +310,8 @@ class Idefics2VisionTransformer(nn.Module):
|
||||
hidden_states = self.embeddings(
|
||||
pixel_values=pixel_values,
|
||||
patch_attention_mask=patch_attention_mask,
|
||||
tgt_sizes=tgt_sizes)
|
||||
tgt_sizes=tgt_sizes,
|
||||
)
|
||||
encoder_outputs = self.encoder(hidden_states)
|
||||
last_hidden_state = self.post_layernorm(encoder_outputs)
|
||||
return last_hidden_state
|
||||
|
||||
@@ -144,8 +144,8 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
|
||||
# Original output: (1, num_images, C, H, W)
|
||||
# New output: (num_images, C, H, W)
|
||||
assert (isinstance(pixel_values, list)
|
||||
and len(pixel_values) == 1
|
||||
and isinstance(pixel_values[0], list)
|
||||
and len(pixel_values) == 1)
|
||||
assert (isinstance(pixel_values[0], list)
|
||||
and len(pixel_values[0]) == len(images))
|
||||
|
||||
processed_outputs["pixel_values"] = pixel_values[0]
|
||||
|
||||
@@ -528,10 +528,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
stacked_image_features = self._image_pixels_to_features(
|
||||
self.vision_tower, stacked_pixel_values)
|
||||
|
||||
return [
|
||||
self.multi_modal_projector(image_features) for image_features in
|
||||
torch.split(stacked_image_features, num_patches_per_batch)
|
||||
]
|
||||
return torch.split(self.multi_modal_projector(stacked_image_features),
|
||||
num_patches_per_batch)
|
||||
|
||||
def _process_image_input(
|
||||
self,
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import math
|
||||
from dataclasses import dataclass, fields
|
||||
from functools import cached_property
|
||||
from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@@ -306,7 +306,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
images: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor],
|
||||
torch.Tensor]] = None,
|
||||
image_tokens: Optional[torch.Tensor] = None,
|
||||
) -> Optional[List[torch.Tensor]]:
|
||||
) -> Tuple[Optional[List[torch.Tensor]], Optional[torch.Tensor]]:
|
||||
if images is None:
|
||||
return None, None
|
||||
|
||||
@@ -604,11 +604,11 @@ class VisionTransformer(nn.Module):
|
||||
return self.args.image_size // self.args.patch_size
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
def device(self) -> torch.types.Device:
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.device:
|
||||
def dtype(self) -> torch.dtype:
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
@property
|
||||
@@ -741,8 +741,8 @@ def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig,
|
||||
ratio = max(image_width / max_width, image_height / max_height)
|
||||
|
||||
if ratio > 1:
|
||||
image_width = int(numpy.ceil(image_width / ratio))
|
||||
image_height = int(numpy.ceil(image_height / ratio))
|
||||
image_width = int(math.ceil(image_width / ratio))
|
||||
image_height = int(math.ceil(image_height / ratio))
|
||||
|
||||
num_height_tokens, num_width_tokens = _get_pixtral_hf_num_image_tokens(
|
||||
(image_height, image_width),
|
||||
|
||||
@@ -23,7 +23,6 @@ from functools import cached_property
|
||||
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
|
||||
Union)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import BatchFeature
|
||||
@@ -177,16 +176,19 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
feature_extractor = self._get_feature_extractor()
|
||||
|
||||
sampling_rate = feature_extractor.sampling_rate
|
||||
audio_len = feature_extractor.chunk_length * sampling_rate
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
|
||||
audio_count = mm_counts.get("audio", 0)
|
||||
audio = np.zeros(audio_len)
|
||||
data = {"audio": [audio] * audio_count}
|
||||
mm_data = {
|
||||
"audio":
|
||||
self._get_dummy_audios(length=audio_len, num_audios=num_audios)
|
||||
}
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt_text="<|AUDIO|>" * audio_count,
|
||||
mm_data=data,
|
||||
prompt_text="<|AUDIO|>" * num_audios,
|
||||
mm_data=mm_data,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -29,7 +29,6 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from PIL import Image
|
||||
from transformers import BatchFeature
|
||||
from transformers.models.qwen2_vl import (Qwen2VLImageProcessor,
|
||||
Qwen2VLProcessor)
|
||||
@@ -882,12 +881,10 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
|
||||
self,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
hf_processor = self._get_hf_processor()
|
||||
image_token: str = hf_processor.image_token
|
||||
image_processor = _get_image_processor(hf_processor)
|
||||
|
||||
data = {}
|
||||
image_token: str = hf_processor.image_token
|
||||
resized_height, resized_width = smart_resize(
|
||||
height=9999999,
|
||||
width=9999999,
|
||||
@@ -895,14 +892,18 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
|
||||
min_pixels=image_processor.min_pixels,
|
||||
max_pixels=image_processor.max_pixels,
|
||||
)
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
dummy_image = Image.new("RGB", (resized_width, resized_height),
|
||||
color=0)
|
||||
data["image"] = [dummy_image] * num_images
|
||||
mm_data = {
|
||||
"image":
|
||||
self._get_dummy_images(width=resized_width,
|
||||
height=resized_height,
|
||||
num_images=num_images)
|
||||
}
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt_text=image_token * num_images,
|
||||
mm_data=data,
|
||||
mm_data=mm_data,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -188,16 +188,19 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
feature_extractor = self._get_feature_extractor()
|
||||
|
||||
sampling_rate = feature_extractor.sampling_rate
|
||||
audio_len = feature_extractor.chunk_length * sampling_rate
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
|
||||
audio_count = mm_counts.get("audio", 0)
|
||||
audio = np.zeros(audio_len)
|
||||
data = {"audio": [audio] * audio_count}
|
||||
mm_data = {
|
||||
"audio":
|
||||
self._get_dummy_audios(length=audio_len, num_audios=num_audios)
|
||||
}
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt_text="<|audio|>" * audio_count,
|
||||
mm_data=data,
|
||||
prompt_text="<|audio|>" * num_audios,
|
||||
mm_data=mm_data,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user