[Core][VLM] Add precise multi-modal placeholder tracking (#8346)

Signed-off-by: Peter Salas <peter@fixie.ai>
This commit is contained in:
Peter Salas
2024-11-01 16:21:10 -07:00
committed by GitHub
parent d151fde834
commit 6c0b7f548d
53 changed files with 913 additions and 281 deletions

View File

@@ -10,7 +10,8 @@ from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
@@ -111,7 +112,7 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
image_feature_size = get_max_llava_image_tokens(ctx)
if isinstance(vision_config, CLIPVisionConfig):
seq_data = dummy_seq_data_for_clip(
seq_data, ranges = dummy_seq_data_for_clip(
vision_config,
seq_len,
num_images,
@@ -120,9 +121,9 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
)
mm_data = dummy_image_for_clip(vision_config, num_images)
return seq_data, mm_data
return DummyData(seq_data, mm_data, ranges)
elif isinstance(vision_config, SiglipVisionConfig):
seq_data = dummy_seq_data_for_siglip(
seq_data, ranges = dummy_seq_data_for_siglip(
vision_config,
seq_len,
num_images,
@@ -131,9 +132,9 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
)
mm_data = dummy_image_for_siglip(vision_config, num_images)
return seq_data, mm_data
return DummyData(seq_data, mm_data, ranges)
elif isinstance(vision_config, PixtralVisionConfig):
seq_data = dummy_seq_data_for_pixtral_hf(
seq_data, ranges = dummy_seq_data_for_pixtral_hf(
vision_config,
seq_len,
num_images,
@@ -142,7 +143,7 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
)
mm_data = dummy_image_for_pixtral_hf(vision_config, num_images)
return seq_data, mm_data
return DummyData(seq_data, mm_data, ranges)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)