[V1][VLM] V1 support for selected single-image models. (#11632)
Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Isotr0py <2037008807@qq.com>
@@ -4,32 +4,33 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
 import torch
 import torch.nn as nn
-from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
-                          apply_chunking_to_forward)
+from transformers import (BatchFeature, Blip2Config, Blip2Processor,
+                          Blip2QFormerConfig, apply_chunking_to_forward)
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, VllmConfig
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
-                         InputContext, token_inputs)
+from vllm.inputs import InputContext
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import NestedTensors
-from vllm.multimodal.utils import consecutive_placeholder_ranges
-from vllm.sequence import IntermediateTensors, SequenceData
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+                                    MultiModalInputsV2, MultiModalKwargs,
+                                    NestedTensors, PlaceholderRange)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        MultiModalDataItems, ProcessorInputs,
+                                        PromptReplacement)
+from vllm.sequence import IntermediateTensors
 
-from .blip import (BlipVisionModel, dummy_image_for_blip,
-                   get_max_blip_image_tokens)
+from .blip import BlipVisionModel
 from .interfaces import SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
 
 # We use this internally as placeholders since there is no image token
 # defined on the HuggingFace repo
-BLIP2_IMAGE_TOKEN = "<image>"
-BLIP2_IMAGE_TOKEN_ID = 50265
+_IMAGE_TOKEN_ID = 50265
 
 
 class Blip2ImagePixelInputs(TypedDict):
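Why an id of 50265 works as an internal placeholder: BLIP-2's OPT tokenizer only defines ids below this value, so encoding text can never produce it. A quick sanity-check sketch, not part of this commit (the model id is the standard BLIP-2 checkpoint, assumed here for illustration):

    from transformers import AutoTokenizer

    # Sketch only: the placeholder id should sit just past the vocabulary,
    # so it can never collide with a real text token.
    _IMAGE_TOKEN_ID = 50265

    tok = AutoTokenizer.from_pretrained("Salesforce/blip2-opt-2.7b")
    print(len(tok))  # expected: 50265, i.e. real token ids are 0..50264
    assert _IMAGE_TOKEN_ID not in tok("some arbitrary text")["input_ids"]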
@@ -396,92 +397,87 @@ class Blip2QFormerModel(nn.Module):
         return sequence_output
 
 
-def get_blip2_image_feature_size(hf_config: Blip2Config) -> int:
+def get_max_blip2_image_tokens(ctx: InputContext):
+    hf_config = ctx.get_hf_config(Blip2Config)
     return hf_config.num_query_tokens
 
 
-def get_max_blip2_image_tokens(ctx: InputContext):
-    hf_config = ctx.get_hf_config(Blip2Config)
-    vision_config = hf_config.vision_config
-
-    if isinstance(vision_config, Blip2VisionConfig):
-        return get_max_blip_image_tokens(vision_config)
-
-    msg = f"Unsupported vision config: {type(vision_config)}"
-    raise NotImplementedError(msg)
-
-
-def dummy_seq_data_for_blip2(
-    hf_config: Blip2Config,
-    seq_len: int,
-    num_images: int,
-    *,
-    image_token_id: int,
-    image_feature_size_override: Optional[int] = None,
-):
-    if image_feature_size_override is None:
-        image_feature_size = get_blip2_image_feature_size(hf_config)
-    else:
-        image_feature_size = image_feature_size_override
-
-    return SequenceData.from_prompt_token_counts(
-        (image_token_id, image_feature_size * num_images),
-        (0, seq_len - image_feature_size * num_images),
-    ), {
-        "image":
-        consecutive_placeholder_ranges(num_items=num_images,
-                                       item_size=image_feature_size)
-    }
-
-
-def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
-                         mm_counts: Mapping[str, int]):
-    hf_config = ctx.get_hf_config(Blip2Config)
-    vision_config = hf_config.vision_config
-    num_images = mm_counts["image"]
-
-    seq_data, ranges = dummy_seq_data_for_blip2(
-        hf_config,
-        seq_len,
-        num_images,
-        image_token_id=BLIP2_IMAGE_TOKEN_ID,
-    )
-
-    if isinstance(vision_config, Blip2VisionConfig):
-        mm_data = dummy_image_for_blip(vision_config, num_images)
-
-        return DummyData(seq_data, mm_data, ranges)
-
-    msg = f"Unsupported vision config: {type(vision_config)}"
-    raise NotImplementedError(msg)
-
-
-def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs):
-    multi_modal_data = inputs.get("multi_modal_data")
-    if multi_modal_data is None or "image" not in multi_modal_data:
-        return inputs
-
-    hf_config = ctx.get_hf_config(Blip2Config)
-    image_feature_size = get_blip2_image_feature_size(hf_config)
-
-    # The original model places image tokens at the front
-    # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514
-    new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
-    new_token_ids += inputs["prompt_token_ids"]
-
-    new_prompt = inputs.get("prompt")
-    if new_prompt is not None:
-        new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt
-
-    return token_inputs(prompt_token_ids=new_token_ids,
-                        prompt=new_prompt,
-                        multi_modal_data=multi_modal_data)
+class Blip2MultiModalProcessor(BaseMultiModalProcessor):
+
+    def _get_hf_processor(self) -> Blip2Processor:
+        return self.ctx.get_hf_processor(Blip2Processor)
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            image_embeds=MultiModalFieldConfig.batched("image"),
+        )
+
+    def _get_prompt_replacements(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
+    ) -> list[PromptReplacement]:
+        max_image_tokens = get_max_blip2_image_tokens(self.ctx)
+
+        return [
+            PromptReplacement(
+                modality="image",
+                target="</s>",
+                replacement="<image>" * max_image_tokens + "</s>",
+            )
+        ]
+
+    def apply(
+        self,
+        prompt_text: str,
+        mm_data: MultiModalDataDict,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> MultiModalInputsV2:
+        result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
+
+        # Only <image> tokens should be considered as placeholders,
+        # so we ignore the trailing bos_token
+        result["mm_placeholders"] = {
+            modality: [
+                PlaceholderRange(offset=p["offset"], length=p["length"] - 1)
+                for p in ps
+            ]
+            for modality, ps in result["mm_placeholders"].items()
+        }
+
+        return result
+
+    def _get_dummy_mm_inputs(
+        self,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        hf_config = self.ctx.get_hf_config(Blip2Config)
+        vision_config = hf_config.vision_config
+
+        max_image_size = vision_config.image_size
+        num_images = mm_counts.get("image", 0)
+
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=max_image_size,
+                                   height=max_image_size,
+                                   num_images=num_images)
+        }
+
+        return ProcessorInputs(
+            prompt_text="",
+            mm_data=mm_data,
+        )
 
 
-@MULTIMODAL_REGISTRY.register_image_input_mapper()
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
+@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor)
 class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
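For context, the processor registered above is exercised through the usual vLLM entry points. A minimal usage sketch (the model id, prompt format, and blank image are illustrative assumptions, not part of this diff): the request flows through Blip2MultiModalProcessor, whose PromptReplacement expands the leading "</s>" into num_query_tokens (32 on the standard checkpoints) "<image>" placeholders.

    from PIL import Image

    from vllm import LLM

    llm = LLM(model="Salesforce/blip2-opt-2.7b")
    outputs = llm.generate({
        "prompt": "Question: What is shown in the image? Answer:",
        "multi_modal_data": {"image": Image.new("RGB", (224, 224))},
    })
    print(outputs[0].outputs[0].text)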
@@ -627,7 +623,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         if multimodal_embeddings is not None:
             inputs_embeds = merge_multimodal_embeddings(
                 input_ids, inputs_embeds, multimodal_embeddings,
-                BLIP2_IMAGE_TOKEN_ID)
+                _IMAGE_TOKEN_ID)
         return inputs_embeds
 
     def forward(
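The rename itself is mechanical, but the merge it feeds is worth spelling out. A simplified schematic of what happens at the placeholder positions (a sketch with dummy tensors, not the actual merge_multimodal_embeddings helper from .utils):

    import torch

    _IMAGE_TOKEN_ID = 50265  # placeholder id from the module above

    # Dummy shapes: 3 image placeholders followed by 3 text tokens, dim 8.
    input_ids = torch.tensor([_IMAGE_TOKEN_ID] * 3 + [2, 1000, 1001])
    inputs_embeds = torch.zeros(6, 8)  # text embeddings from the LM
    image_embeds = torch.ones(3, 8)    # Q-Former outputs

    # Overwrite the placeholder positions with image embeddings, in order.
    is_image = input_ids == _IMAGE_TOKEN_ID
    inputs_embeds[is_image] = image_embeds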