[Core][VLM] Add precise multi-modal placeholder tracking (#8346)

Signed-off-by: Peter Salas <peter@fixie.ai>
This commit is contained in:
Peter Salas
2024-11-01 16:21:10 -07:00
committed by GitHub
parent d151fde834
commit 6c0b7f548d
53 changed files with 913 additions and 281 deletions

View File

@@ -9,13 +9,14 @@ from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import consecutive_placeholder_ranges
from vllm.sequence import IntermediateTensors, SequenceData
from .blip import (BlipVisionModel, dummy_image_for_blip,
@@ -425,7 +426,11 @@ def dummy_seq_data_for_blip2(
return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images),
)
), {
"image":
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}
def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
@@ -434,7 +439,7 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
seq_data = dummy_seq_data_for_blip2(
seq_data, ranges = dummy_seq_data_for_blip2(
hf_config,
seq_len,
num_images,
@@ -444,7 +449,7 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
if isinstance(vision_config, Blip2VisionConfig):
mm_data = dummy_image_for_blip(vision_config, num_images)
return seq_data, mm_data
return DummyData(seq_data, mm_data, ranges)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)