[Core][VLM] Add precise multi-modal placeholder tracking (#8346)

Signed-off-by: Peter Salas <peter@fixie.ai>
Author: Peter Salas
Date: 2024-11-01 16:21:10 -07:00
Committed by: GitHub
Parent: d151fde834
Commit: 6c0b7f548d
53 changed files with 913 additions and 281 deletions
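
The change threads per-item placeholder ranges for multi-modal inputs (here, image tokens) through dummy-data generation and input processing, so downstream code knows exactly which token positions each image occupies. Below is a minimal sketch of the range structure this implies, assuming each range is an offset/length pair as suggested by the consecutive_placeholder_ranges(num_items=..., item_size=...) calls in the diff; the actual class name and fields in vllm may differ.

    # Hypothetical sketch: one placeholder range per multi-modal item.
    from typing import TypedDict

    class PlaceholderRange(TypedDict):
        offset: int  # index of the first placeholder token in the prompt
        length: int  # number of placeholder tokens occupied by the item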


@@ -27,8 +27,8 @@ from transformers import FuyuConfig, FuyuImageProcessor

 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
-                         token_inputs)
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
+                         InputContext, token_inputs)
 from vllm.model_executor.layers.linear import ColumnParallelLinear
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -37,9 +37,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
 from vllm.multimodal.image import cached_get_image_processor
-from vllm.multimodal.utils import cached_get_tokenizer
+from vllm.multimodal.utils import (cached_get_tokenizer,
+                                   consecutive_placeholder_ranges)
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                            SequenceData)
+from vllm.utils import is_list_of

 from .interfaces import SupportsMultiModal, SupportsPP
 from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings
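
The processing changes further down normalize a single image into a one-element list and then check element types with the newly imported is_list_of helper. A rough sketch of the behavior assumed here (the real vllm.utils.is_list_of may only spot-check the first element for speed):

    # Assumed behavior of is_list_of: value is a list whose elements are of type typ.
    def is_list_of(value, typ) -> bool:
        return isinstance(value, list) and all(isinstance(v, typ) for v in value)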
@@ -103,7 +105,11 @@ def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int):
     token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, image_token_ids) * num_images
     token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
                        [0]) * (seq_len - image_feature_size * num_images)
-    return SequenceData(token_ids)
+    return SequenceData(token_ids), {
+        "image":
+        consecutive_placeholder_ranges(num_items=num_images,
+                                       item_size=image_feature_size)
+    }


 def dummy_image_for_fuyu(
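
dummy_seq_data_for_fuyu now returns the placeholder ranges alongside the dummy SequenceData, keyed by modality. A sketch of what consecutive_placeholder_ranges is assumed to produce, given that the dummy prompt above places all image tokens before the trailing zero padding (the real helper presumably returns PlaceholderRange entries rather than plain dicts):

    # Hypothetical sketch: items laid out back-to-back starting at offset 0.
    def consecutive_placeholder_ranges(num_items: int, item_size: int):
        return [{"offset": i * item_size, "length": item_size}
                for i in range(num_items)]

For example, num_items=2 and item_size=3 would yield [{"offset": 0, "length": 3}, {"offset": 3, "length": 3}].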
@@ -119,15 +125,15 @@ def dummy_image_for_fuyu(
 def dummy_data_for_fuyu(ctx: InputContext, seq_len: int,
                         mm_counts: Mapping[str, int]):
     num_images = mm_counts["image"]
-    seq_data = dummy_seq_data_for_fuyu(ctx, seq_len, num_images)
+    seq_data, ranges = dummy_seq_data_for_fuyu(ctx, seq_len, num_images)
     mm_data = dummy_image_for_fuyu(num_images,
                                    image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
                                    image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT)
-    return seq_data, mm_data
+    return DummyData(seq_data, mm_data, ranges)


 def _fuyu_image_preprocess(image_processor: FuyuImageProcessor,
-                           data: Image.Image):
+                           data: List[Image.Image]):
     image_encoding = image_processor.preprocess(data, return_tensors="pt")
     batch_images = torch.stack([img[0] for img in image_encoding["images"]
                                 ]).unsqueeze(1)
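
dummy_data_for_fuyu now wraps its result in DummyData rather than returning a bare (seq_data, mm_data) tuple, with the placeholder ranges as the third field. A sketch of the assumed shape, inferred only from the positional call above (field names are illustrative, not taken from the vllm source):

    # Hypothetical sketch of DummyData: dummy sequence, dummy multi-modal
    # inputs, and per-modality placeholder ranges.
    from typing import Any, Dict, List, NamedTuple, Optional

    class DummyData(NamedTuple):
        seq_data: Any  # SequenceData
        multi_modal_data: Optional[Dict[str, Any]] = None
        multi_modal_placeholders: Optional[Dict[str, List[dict]]] = None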
@@ -158,8 +164,10 @@ def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs):
     model_config = ctx.model_config
     image_data = multi_modal_data["image"]
     new_multi_modal_data = {}
+    image_list = image_data if isinstance(image_data, list) else [image_data]
+
     # process image data
-    if isinstance(image_data, Image.Image):
+    if is_list_of(image_list, Image.Image):
         # Fuyu's image_processor can also finish token padding
         image_processor: FuyuImageProcessor = cached_get_image_processor(
             model_config.model)
@@ -171,7 +179,7 @@ def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs):
         ])
         new_multi_modal_data["image"] = image_patches

-    elif isinstance(image_data, torch.Tensor):
+    elif is_list_of(image_list, torch.Tensor):
         raise NotImplementedError("Embeddings input is not supported yet")
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")
@@ -198,12 +206,13 @@ def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs):

 def input_mapper_for_fuyu(ctx: InputContext, data: object):
     model_config = ctx.model_config
-    if isinstance(data, Image.Image):
+    data_list = data if isinstance(data, list) else [data]
+    if is_list_of(data_list, Image.Image):
         # Fuyu's image_processor can also finish token padding
         image_processor: FuyuImageProcessor = cached_get_image_processor(
             model_config.model)

-        model_image_input = _fuyu_image_preprocess(image_processor, data)
+        model_image_input = _fuyu_image_preprocess(image_processor, data_list)
         data = torch.stack([
             image_patch[0]
             for image_patch in model_image_input["image_patches"]
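
Both input_processor_for_fuyu and input_mapper_for_fuyu now use the same normalization pattern: wrap a single image into a list, then branch on the element type. A small standalone illustration of that pattern (the image sizes are arbitrary):

    # Single images and lists of images flow through the same code path once
    # wrapped into a list.
    from PIL import Image

    single = Image.new("RGB", (64, 64))
    batch = [Image.new("RGB", (64, 64)), Image.new("RGB", (64, 64))]

    for data in (single, batch):
        data_list = data if isinstance(data, list) else [data]
        assert all(isinstance(d, Image.Image) for d in data_list)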