[Core][VLM] Add precise multi-modal placeholder tracking (#8346)
Signed-off-by: Peter Salas <peter@fixie.ai>
This commit is contained in:
@@ -10,7 +10,8 @@ from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, MultiModalConfig
|
||||
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||
InputContext)
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
@@ -111,7 +112,7 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
|
||||
image_feature_size = get_max_llava_image_tokens(ctx)
|
||||
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
seq_data = dummy_seq_data_for_clip(
|
||||
seq_data, ranges = dummy_seq_data_for_clip(
|
||||
vision_config,
|
||||
seq_len,
|
||||
num_images,
|
||||
@@ -120,9 +121,9 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
|
||||
)
|
||||
|
||||
mm_data = dummy_image_for_clip(vision_config, num_images)
|
||||
return seq_data, mm_data
|
||||
return DummyData(seq_data, mm_data, ranges)
|
||||
elif isinstance(vision_config, SiglipVisionConfig):
|
||||
seq_data = dummy_seq_data_for_siglip(
|
||||
seq_data, ranges = dummy_seq_data_for_siglip(
|
||||
vision_config,
|
||||
seq_len,
|
||||
num_images,
|
||||
@@ -131,9 +132,9 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
|
||||
)
|
||||
|
||||
mm_data = dummy_image_for_siglip(vision_config, num_images)
|
||||
return seq_data, mm_data
|
||||
return DummyData(seq_data, mm_data, ranges)
|
||||
elif isinstance(vision_config, PixtralVisionConfig):
|
||||
seq_data = dummy_seq_data_for_pixtral_hf(
|
||||
seq_data, ranges = dummy_seq_data_for_pixtral_hf(
|
||||
vision_config,
|
||||
seq_len,
|
||||
num_images,
|
||||
@@ -142,7 +143,7 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
|
||||
)
|
||||
|
||||
mm_data = dummy_image_for_pixtral_hf(vision_config, num_images)
|
||||
return seq_data, mm_data
|
||||
return DummyData(seq_data, mm_data, ranges)
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
Reference in New Issue
Block a user