[Model] Update Paligemma multimodal processing with PromptUpdate (#14015)
Signed-off-by: Kyle Huang <kylhuang@nvidia.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -5,22 +5,26 @@ from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple,
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PaliGemmaConfig
|
||||
from transformers import BatchFeature, PaliGemmaConfig
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||
InputContext, token_inputs)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import NestedTensors
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalInputs, MultiModalKwargs,
|
||||
NestedTensors)
|
||||
from vllm.multimodal.parse import MultiModalDataItems
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptIndexTargets,
|
||||
PromptInsertion, PromptReplacement,
|
||||
PromptUpdateDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
||||
|
||||
from .interfaces import SupportsMultiModal, SupportsPP, SupportsV0Only
|
||||
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
|
||||
dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .siglip import SiglipVisionModel, get_max_siglip_image_tokens
|
||||
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
|
||||
@@ -46,79 +50,6 @@ PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs,
|
||||
PaliGemmaImageEmbeddingInputs]
|
||||
|
||||
|
||||
def get_max_paligemma_image_tokens(ctx: InputContext):
|
||||
hf_config = ctx.get_hf_config(PaliGemmaConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
return get_max_siglip_image_tokens(vision_config)
|
||||
|
||||
|
||||
def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
|
||||
mm_counts: Mapping[str, int]):
|
||||
hf_config = ctx.get_hf_config(PaliGemmaConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
num_images = mm_counts["image"]
|
||||
|
||||
seq_data, ranges = dummy_seq_data_for_siglip(
|
||||
vision_config,
|
||||
seq_len,
|
||||
num_images,
|
||||
image_token_id=hf_config.image_token_index,
|
||||
)
|
||||
|
||||
mm_data = dummy_image_for_siglip(vision_config, num_images)
|
||||
return DummyData(seq_data, mm_data, ranges)
|
||||
|
||||
|
||||
def input_processor_for_paligemma(ctx: InputContext,
|
||||
inputs: DecoderOnlyInputs):
|
||||
|
||||
"""
|
||||
The correct prompt format needs to be:
|
||||
'<image>' * image_feature_size + '<bos>' + prompt + '\n'
|
||||
|
||||
See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55
|
||||
""" # noqa
|
||||
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return inputs
|
||||
|
||||
model_config = ctx.model_config
|
||||
hf_config = ctx.get_hf_config(PaliGemmaConfig)
|
||||
|
||||
tokenizer = cached_tokenizer_from_config(model_config)
|
||||
image_feature_size = hf_config.text_config.num_image_tokens
|
||||
image_token_str = tokenizer.decode(hf_config.image_token_index)
|
||||
bos_token = tokenizer.decode(hf_config.bos_token_id)
|
||||
image_token_str_pad = image_token_str * image_feature_size
|
||||
image_token_ids_pad = [hf_config.image_token_index] * image_feature_size
|
||||
|
||||
orig_prompt = inputs.get("prompt")
|
||||
orig_prompt_ids = inputs.get("prompt_token_ids")
|
||||
|
||||
if orig_prompt is not None and image_token_str in orig_prompt:
|
||||
logger.warning(
|
||||
"The image token '%s' was detected in the prompt and "
|
||||
"will be removed. Please follow the proper prompt format"
|
||||
" documented on HuggingFace.", image_token_str)
|
||||
orig_prompt = orig_prompt.replace(image_token_str, "")
|
||||
orig_prompt_ids.remove(hf_config.image_token_index)
|
||||
|
||||
new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n"
|
||||
|
||||
# The PaliGemma 2 tokenizer does not include a starting BOS token
|
||||
if orig_prompt_ids[0] != hf_config.bos_token_id:
|
||||
orig_prompt_ids = [hf_config.bos_token_id] + orig_prompt_ids
|
||||
|
||||
new_token_ids = image_token_ids_pad + orig_prompt_ids + [108] #newline
|
||||
|
||||
# NOTE: Create a defensive copy of the original inputs
|
||||
return token_inputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data)
|
||||
|
||||
|
||||
class PaliGemmaMultiModalProjector(nn.Module):
|
||||
|
||||
def __init__(self, vision_hidden_size: int, projection_dim: int):
|
||||
@@ -131,12 +62,140 @@ class PaliGemmaMultiModalProjector(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper()
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
|
||||
class PaliGemmaProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
def get_hf_config(self):
|
||||
return self.ctx.get_hf_config(PaliGemmaConfig)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": 1}
|
||||
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
return {"image": self.get_num_image_tokens()}
|
||||
|
||||
def get_num_image_tokens(self) -> int:
|
||||
hf_config = self.get_hf_config()
|
||||
vision_config = hf_config.vision_config
|
||||
return get_max_siglip_image_tokens(vision_config)
|
||||
|
||||
|
||||
class PaliGemmaDummyInputsBuilder(
|
||||
BaseDummyInputsBuilder[PaliGemmaProcessingInfo]):
|
||||
|
||||
def get_dummy_processor_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
hf_config = self.info.get_hf_config()
|
||||
vision_config = hf_config.vision_config
|
||||
max_image_size = vision_config.image_size
|
||||
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
mm_data = {
|
||||
"image":
|
||||
self._get_dummy_images(width=max_image_size,
|
||||
height=max_image_size,
|
||||
num_images=num_images)
|
||||
}
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt_text="",
|
||||
mm_data=mm_data,
|
||||
)
|
||||
|
||||
|
||||
class PaliGemmaMultiModalProcessor(
|
||||
BaseMultiModalProcessor[PaliGemmaProcessingInfo]):
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
if not mm_data:
|
||||
prompt_ids = tokenizer.encode(prompt)
|
||||
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
|
||||
|
||||
return super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
mm_kwargs=mm_kwargs,
|
||||
)
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(pixel_values=MultiModalFieldConfig.batched("image"))
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_config = self.info.get_hf_config()
|
||||
image_token_id = hf_config.image_token_index
|
||||
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
num_image_tokens = self.info.get_num_image_tokens()
|
||||
image_tokens = [image_token_id] * num_image_tokens
|
||||
|
||||
bos_token_id = tokenizer.bos_token_id
|
||||
assert isinstance(bos_token_id, int)
|
||||
|
||||
# Paligemma 1 and 2 have different tokenizer.add_bos_token
|
||||
# Insert <image>*n + <bos> after <bos> for Paligemma 1
|
||||
# Insert <image>*n + <bos> for Paligemma 2
|
||||
return [
|
||||
PromptInsertion(
|
||||
modality="image",
|
||||
target=PromptIndexTargets.prefix(
|
||||
[bos_token_id] if tokenizer.add_bos_token else []),
|
||||
insertion=PromptUpdateDetails(
|
||||
full=image_tokens + [bos_token_id],
|
||||
features=image_tokens,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
def apply(
|
||||
self,
|
||||
prompt: Union[str, list[int]],
|
||||
mm_data: MultiModalDataDict,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> MultiModalInputs:
|
||||
mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
|
||||
prompt_token_ids = mm_inputs["prompt_token_ids"]
|
||||
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
newline_prompt = "\n"
|
||||
newline_token_id = tokenizer.encode(newline_prompt)[-1] # 108
|
||||
# Force to add newline at the end of prompt for paligemma's format
|
||||
# This step can NOT be replacemented by current PromptUpdate methods
|
||||
if len(prompt_token_ids) and prompt_token_ids[-1] != newline_token_id:
|
||||
prompt_token_ids.append(newline_token_id)
|
||||
mm_inputs["prompt_token_ids"] = prompt_token_ids
|
||||
mm_inputs["prompt"] += newline_prompt
|
||||
|
||||
return mm_inputs
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
PaliGemmaMultiModalProcessor,
|
||||
info=PaliGemmaProcessingInfo,
|
||||
dummy_inputs=PaliGemmaDummyInputsBuilder)
|
||||
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsPP, SupportsV0Only):
|
||||
SupportsPP):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
|
||||
Reference in New Issue
Block a user