[V1][VLM] V1 support for selected single-image models. (#11632)

Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Isotr0py <2037008807@qq.com>
Commit e7c7c5e822 (parent 8c3230d8c1)
Author: Roger Wang
Date: 2024-12-31 13:17:22 -08:00
Committed by: GitHub
19 changed files with 575 additions and 621 deletions


@@ -3,16 +3,15 @@ from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union)
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch import nn
from transformers import ChameleonConfig, ChameleonVQVAEConfig
from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor,
ChameleonVQVAEConfig)
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.inputs import InputContext
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -29,11 +28,13 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors
from vllm.utils import print_warning_once
from .interfaces import SupportsMultiModal, SupportsPP
@@ -45,10 +46,6 @@ from .utils import (is_pp_missing_parameter,
# and processor files, so we hardcode them in the model file for now.
CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512
CHAMELEON_IMAGE_SEQ_LENGTH = 1024
CHAMELEON_IMAGE_TOKEN_ID = 8711
CHAMELEON_IMAGE_START_TOKEN_ID = 8197
CHAMELEON_IMAGE_END_TOKEN_ID = 8196
CHAMELEON_SEP_TOKEN_ID = 8710
class ChameleonImagePixelInputs(TypedDict):
@@ -61,99 +58,75 @@ def get_max_chameleon_image_tokens(ctx: InputContext):
return CHAMELEON_IMAGE_SEQ_LENGTH
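Note: get_max_chameleon_image_tokens can stay constant because every image is cropped to 512x512 and encoded by the VQ-VAE into a fixed grid of discrete codes. A quick sanity check of the arithmetic; the 16x downsample factor is an assumption inferred from 512 -> 1024 tokens, not stated in this file:

CHAMELEON_CROP_SIZE = 512            # CHAMELEON_CROP_SIZE_HEIGHT == _WIDTH
ASSUMED_VQVAE_DOWNSAMPLE = 16        # assumption: yields a 32x32 code grid
grid_side = CHAMELEON_CROP_SIZE // ASSUMED_VQVAE_DOWNSAMPLE
assert grid_side * grid_side == 1024  # == CHAMELEON_IMAGE_SEQ_LENGTH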
def dummy_seq_data_for_chameleon(
seq_len: int,
num_images: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
if image_feature_size_override is None:
image_feature_size = CHAMELEON_IMAGE_SEQ_LENGTH
else:
image_feature_size = image_feature_size_override
class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images),
), {
"image":
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}
def _get_hf_processor(self) -> ChameleonProcessor:
return self.ctx.get_hf_processor(ChameleonProcessor)
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image"))
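Note: MultiModalFieldConfig.batched("image") declares that dim 0 of the HF processor's pixel_values output indexes images, so the processor can split the batch back into per-image items. A minimal standalone illustration of that layout, with shapes assumed from the fixed 512x512 crop (not vLLM code):

import torch

# Stand-in for the ChameleonProcessor output for two RGB images at the
# fixed 512x512 crop size; dim 0 is the per-image ("batched") dimension.
pixel_values = torch.zeros(2, 3, 512, 512)
per_image = [pixel_values[i] for i in range(pixel_values.shape[0])]
assert per_image[0].shape == (3, 512, 512)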
def dummy_image_for_chameleon(
num_images: int,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
):
width = CHAMELEON_CROP_SIZE_WIDTH
height = CHAMELEON_CROP_SIZE_HEIGHT
if image_width_override is not None:
width = image_width_override
if image_height_override is not None:
height = image_height_override
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
processor = self._get_hf_processor()
image = Image.new("RGB", (width, height), color=0)
return {"image": image if num_images == 1 else [image] * num_images}
return [
PromptReplacement(
modality="image",
target="<image>",
replacement="".join([
processor.image_start_token,
processor.image_token * CHAMELEON_IMAGE_SEQ_LENGTH,
processor.image_end_token,
]),
)
]
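Note: the PromptReplacement above expands each "<image>" placeholder in the prompt into the start token, 1024 image tokens, and the end token before tokenization. A worked example of the expansion; the concrete token strings are assumed values of the HF ChameleonProcessor attributes, not taken from this diff:

CHAMELEON_IMAGE_SEQ_LENGTH = 1024
image_start_token = "<racm3:break>"   # assumed processor.image_start_token
image_token = "<image>"               # assumed processor.image_token
image_end_token = "<eoss>"            # assumed processor.image_end_token

replacement = "".join([
    image_start_token,
    image_token * CHAMELEON_IMAGE_SEQ_LENGTH,
    image_end_token,
])
# Each "<image>" in the user prompt becomes 1 + 1024 + 1 = 1026 tokens
# once the replacement string is tokenized.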
def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
def dummy_data_for_chameleon(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
mm_data = {
"image":
self._get_dummy_images(width=CHAMELEON_CROP_SIZE_WIDTH,
height=CHAMELEON_CROP_SIZE_HEIGHT,
num_images=num_images)
}
seq_data, ranges = dummy_seq_data_for_chameleon(
seq_len,
num_images,
image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
)
return ProcessorInputs(
prompt_text="<image>" * num_images,
mm_data=mm_data,
)
mm_data = dummy_image_for_chameleon(num_images)
return DummyData(seq_data, mm_data, ranges)
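Note: the dummy inputs for memory profiling now come from _get_dummy_mm_inputs: a prompt with one "<image>" per image plus black 512x512 dummy images, matching what the removed dummy_image_for_chameleon built by hand. Roughly equivalent standalone code (a sketch; the real method delegates to BaseMultiModalProcessor helpers):

from PIL import Image

CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512
num_images = 2
prompt_text = "<image>" * num_images
mm_data = {
    "image": [
        Image.new("RGB",
                  (CHAMELEON_CROP_SIZE_WIDTH, CHAMELEON_CROP_SIZE_HEIGHT),
                  color=0) for _ in range(num_images)
    ]
}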
def apply(
self,
prompt_text: str,
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
# Only <image> tokens should be considered as placeholders,
# so we ignore the image_start_token and image_end_token
result["mm_placeholders"] = {
modality: [
PlaceholderRange(offset=p["offset"] + 1,
length=p["length"] - 2) for p in ps
]
for modality, ps in result["mm_placeholders"].items()
}
def input_processor_for_chameleon(ctx: InputContext,
inputs: DecoderOnlyInputs):
"""
Processing input prompt to insert required tokens for image placeholder.
See https://github.com/huggingface/transformers/blob/0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf/src/transformers/models/chameleon/processing_chameleon.py#L58
""" # noqa
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
if "multi_modal_placeholders" in inputs and "image" in inputs[
"multi_modal_placeholders"]:
# The inputs already have placeholders.
return inputs
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=CHAMELEON_IMAGE_TOKEN_ID,
repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH,
pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID,
pad_token_right=CHAMELEON_IMAGE_END_TOKEN_ID,
)
# Appending sep token for chat mode to follow default processor
# behavior
if new_prompt is not None:
new_prompt += tokenizer.sep_token
new_token_ids += [CHAMELEON_SEP_TOKEN_ID]
# NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
return result
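Note: the apply override above trims the start/end tokens out of each placeholder range, so only the 1024 <image> tokens are treated as the region that image embeddings fill. A worked example of the offset/length adjustment:

# One image block expanded at token offset 5 in the processed prompt:
placeholder = {"offset": 5, "length": 1 + 1024 + 1}  # start + image tokens + end

trimmed = {
    "offset": placeholder["offset"] + 1,  # skip the image_start_token
    "length": placeholder["length"] - 2,  # drop the start and end tokens
}
assert trimmed == {"offset": 6, "length": 1024}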
class ChameleonLayerNorm(nn.LayerNorm):
@@ -736,7 +709,7 @@ class ChameleonVQVAEEncoder(nn.Module):
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
hidden_state = self.down[i_level].block[i_block](
hidden_states[-1], )
hidden_states[-1])
if len(self.down[i_level].attn) > 0:
hidden_state = self.down[i_level].attn[i_block](
hidden_state)
@@ -925,10 +898,8 @@ class ChameleonModel(nn.Module):
return hidden_states
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon)
@INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon)
@MULTIMODAL_REGISTRY.register_processor(ChameleonMultiModalProcessor)
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
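Note: the single MULTIMODAL_REGISTRY.register_processor decorator replaces the old image-input-mapper, max-image-tokens, dummy-data, and input-processor registrations. A usage sketch, assuming the V1 engine is enabled (e.g. VLLM_USE_V1=1) and facebook/chameleon-7b is the served checkpoint:

from PIL import Image
from vllm import LLM

llm = LLM(model="facebook/chameleon-7b")
image = Image.new("RGB", (512, 512))  # any PIL image; a blank one for illustration
outputs = llm.generate({
    "prompt": "<image>What is shown in this image?",
    "multi_modal_data": {"image": image},
})
print(outputs[0].outputs[0].text)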