[VLM][Core] Support profiling with multiple multi-modal inputs per prompt (#7126)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import itertools
|
||||
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
|
||||
TypedDict, Union)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -9,8 +10,7 @@ from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, MultiModalConfig
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
@@ -88,9 +88,11 @@ def get_max_llava_image_tokens(ctx: InputContext):
|
||||
raise ValueError(f"Unexpected select feature strategy: {strategy}")
|
||||
|
||||
|
||||
def dummy_data_for_llava(ctx: InputContext, seq_len: int):
|
||||
def dummy_data_for_llava(ctx: InputContext, seq_len: int,
|
||||
mm_counts: Mapping[str, int]):
|
||||
hf_config = ctx.get_hf_config(LlavaConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
num_images = mm_counts["image"]
|
||||
|
||||
image_feature_size = get_max_llava_image_tokens(ctx)
|
||||
|
||||
@@ -98,21 +100,23 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int):
|
||||
seq_data = dummy_seq_data_for_clip(
|
||||
vision_config,
|
||||
seq_len,
|
||||
num_images,
|
||||
image_token_id=hf_config.image_token_index,
|
||||
image_feature_size_override=image_feature_size,
|
||||
)
|
||||
|
||||
mm_data = dummy_image_for_clip(vision_config)
|
||||
mm_data = dummy_image_for_clip(vision_config, num_images)
|
||||
return seq_data, mm_data
|
||||
elif isinstance(vision_config, SiglipVisionConfig):
|
||||
seq_data = dummy_seq_data_for_siglip(
|
||||
vision_config,
|
||||
seq_len,
|
||||
num_images,
|
||||
image_token_id=hf_config.image_token_index,
|
||||
image_feature_size_override=image_feature_size,
|
||||
)
|
||||
|
||||
mm_data = dummy_image_for_siglip(vision_config)
|
||||
mm_data = dummy_image_for_siglip(vision_config, num_images)
|
||||
return seq_data, mm_data
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
|
||||
Reference in New Issue
Block a user