[VLM] Fully dynamic prompt replacement in merged input processor (#11199)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
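Note (not part of the diff): the point of this change is that a PromptReplacement can now carry a callable, so the placeholder expansion is computed per multimodal item at processing time instead of being fixed up front. A minimal sketch of the new API, with a hypothetical token id and token count:

from vllm.multimodal.processing import PromptReplacement

IMAGE_TOKEN_ID = 32000  # hypothetical image placeholder token id

def get_replacement(item_idx: int) -> list[int]:
    # In a real processor this depends on the item, e.g. its image size,
    # as the Pixtral branch in the diff below does; 576 is a made-up count.
    num_tokens = 576
    return [IMAGE_TOKEN_ID] * num_tokens

prompt_repls = [
    PromptReplacement(
        modality="image",
        target=[IMAGE_TOKEN_ID],      # token sequence to find in the prompt
        replacement=get_replacement,  # called once per image item
    ),
]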
@@ -5,10 +5,10 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
 
 import torch
 import torch.nn as nn
-from PIL.Image import Image
 from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                           PixtralVisionConfig, PretrainedConfig,
                           ProcessorMixin, SiglipVisionConfig)
+from transformers.models.llava import LlavaProcessor
 from transformers.models.pixtral import PixtralProcessor
 
 from vllm.attention import AttentionMetadata
@@ -21,11 +21,9 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
+from vllm.multimodal.inputs import NestedTensors
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        InputProcessingContext,
-                                        ModalityProcessingMetadata,
-                                        MultiModalProcessingMetadata,
+                                        MultiModalDataItems, ProcessorInputs,
                                         PromptReplacement)
 from vllm.sequence import IntermediateTensors
 
@@ -33,7 +31,8 @@ from .clip import (CLIPVisionModel, dummy_image_for_clip,
                    get_max_clip_image_tokens)
 from .interfaces import SupportsMultiModal, SupportsPP
 from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
-                      get_max_pixtral_hf_image_tokens)
+                      get_max_pixtral_hf_image_tokens,
+                      get_pixtral_hf_image_feature_size)
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      get_max_siglip_image_tokens)
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
@@ -115,62 +114,7 @@ def get_max_llava_image_tokens(ctx: InputContext):
     raise ValueError(f"Unexpected select feature strategy: {strategy}")
 
 
-def dummy_mm_kwargs_for_llava(ctx: InputProcessingContext,
-                              mm_counts: Mapping[str, int]):
-    hf_config = ctx.get_hf_config(LlavaConfig)
-    vision_config = hf_config.vision_config
-    num_images = mm_counts["image"]
-
-    if isinstance(vision_config, CLIPVisionConfig):
-        data = dummy_image_for_clip(vision_config, num_images)
-    elif isinstance(vision_config, SiglipVisionConfig):
-        data = dummy_image_for_siglip(vision_config, num_images)
-    elif isinstance(vision_config, PixtralVisionConfig):
-        data = dummy_image_for_pixtral_hf(vision_config, num_images)
-    else:
-        msg = f"Unsupported vision config: {type(vision_config)}"
-        raise NotImplementedError(msg)
-
-    hf_processor = ctx.get_hf_processor()
-    image_processor = hf_processor.image_processor  # type: ignore
-    hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt")
-    is_pixtral = isinstance(hf_processor, PixtralProcessor)
-
-    return MultiModalKwargs(
-        **hf_inputs,
-        is_pixtral=torch.tensor(is_pixtral),
-    )
-
-
-def create_metadata_for_llava(
-        ctx: InputProcessingContext) -> MultiModalProcessingMetadata:
-    hf_config = ctx.get_hf_config(LlavaConfig)
-    image_token_id = hf_config.image_token_index
-
-    def get_repl_count(
-        mm_items: list[Image],
-        hf_inputs: BatchFeature,
-        item_idx: int,
-    ) -> int:
-        return get_max_llava_image_tokens(ctx)
-
-    return {
-        "image":
-        ModalityProcessingMetadata(prompt_repls=[
-            PromptReplacement(target=[image_token_id],
-                              repl_unit=[image_token_id],
-                              repl_count=get_repl_count),
-        ]),
-    }
-
-
-class LlavaProcessor(BaseMultiModalProcessor):
-
-    def __init__(self, ctx: InputProcessingContext) -> None:
-        super().__init__(
-            ctx=ctx,
-            metadata=create_metadata_for_llava(ctx),
-        )
+class LlavaMultiModalProcessor(BaseMultiModalProcessor):
 
     def _patch_pixtral_processor(self, hf_processor: PixtralProcessor):
         if getattr(hf_processor, "__is_patched__", False):
@@ -188,18 +132,72 @@ class LlavaProcessor(BaseMultiModalProcessor):
 
         hf_processor.__is_patched__ = True  # type: ignore
 
-    def _get_hf_processor(self) -> ProcessorMixin:
+    def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
         hf_processor = self.ctx.get_hf_processor()
+        assert isinstance(hf_processor, (LlavaProcessor, PixtralProcessor))
 
         if isinstance(hf_processor, PixtralProcessor):
             self._patch_pixtral_processor(hf_processor)
 
         return hf_processor
 
-    def _get_dummy_mm_kwargs(
+    def _get_prompt_replacements(
         self,
+        mm_items: MultiModalDataItems,
+        hf_inputs: BatchFeature,
+        mm_processor_kwargs: Mapping[str, object],
+    ) -> list[PromptReplacement]:
+        hf_config = self.ctx.get_hf_config(LlavaConfig)
+        image_token_id = hf_config.image_token_index
+
+        processor = self._get_hf_processor()
+        if isinstance(processor, PixtralProcessor):
+            image_token = processor.image_token
+            image_break_token = processor.image_break_token
+            image_end_token = processor.image_end_token
+
+            vision_config = hf_config.vision_config
+            assert isinstance(vision_config, PixtralVisionConfig)
+
+            def get_replacement_pixtral(item_idx: int):
+                image_size = mm_items.get_image_size(item_idx)
+                (
+                    num_width_tokens,
+                    num_height_tokens,
+                ) = get_pixtral_hf_image_feature_size(
+                    vision_config,
+                    image_width=image_size.width,
+                    image_height=image_size.height,
+                )
+
+                tokens = ([image_token] * num_width_tokens +
+                          [image_break_token]) * num_height_tokens
+                tokens[-1] = image_end_token
+
+                return "".join(tokens)
+
+            return [
+                PromptReplacement(
+                    modality="image",
+                    target=[image_token_id],
+                    replacement=get_replacement_pixtral,
+                ),
+            ]
+
+        max_image_tokens = get_max_llava_image_tokens(self.ctx)
+
+        return [
+            PromptReplacement(
+                modality="image",
+                target=[image_token_id],
+                replacement=[image_token_id] * max_image_tokens,
+            )
+        ]
+
+    def _get_dummy_mm_inputs(
+        self,
         mm_counts: Mapping[str, int],
-    ) -> MultiModalKwargs:
+    ) -> ProcessorInputs:
         hf_config = self.ctx.get_hf_config(LlavaConfig)
         vision_config = hf_config.vision_config
         num_images = mm_counts["image"]
@@ -215,11 +213,13 @@ class LlavaProcessor(BaseMultiModalProcessor):
             raise NotImplementedError(msg)
 
         hf_processor = self._get_hf_processor()
-        image_processor = hf_processor.image_processor  # type: ignore
-        hf_inputs = image_processor.preprocess(data['image'],
-                                               return_tensors="pt")
+        image_token = hf_processor.image_token
 
-        return MultiModalKwargs(**hf_inputs)
+        return ProcessorInputs(
+            prompt_text=image_token * num_images,
+            mm_data=data,
+            mm_processor_kwargs={},
+        )
 
 
 class LlavaLikeConfig(Protocol):
@@ -303,7 +303,7 @@ def init_vision_tower_for_llava(
 
 
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
-@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor)
+@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
 class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     # BitandBytes specific attributes
     bitsandbytes_stacked_params_mapping = {
@@ -584,7 +584,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         return loader.load_weights(weights)
 
 
-class MantisProcessor(LlavaProcessor):
+class MantisMultiModalProcessor(LlavaMultiModalProcessor):
 
     def _get_hf_processor(self) -> ProcessorMixin:
         try:
@@ -604,6 +604,6 @@ class MantisProcessor(LlavaProcessor):
 # To use this model, please use
 # `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
-@MULTIMODAL_REGISTRY.register_processor(MantisProcessor)
+@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor)
 class MantisForConditionalGeneration(LlavaForConditionalGeneration):
     pass