[V1] Extend beyond image modality and support mixed-modality inference with Llava-OneVision (#11685)

Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Roger Wang
Date: 2025-01-06 11:58:16 -08:00 (committed by GitHub)
Commit: 91b361ae89, parent e20c92bb61
17 changed files with 633 additions and 279 deletions
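In short: the V1 memory-profiling path no longer assumes the image modality. It now asks the multimodal registry for every modality's per-item token maximum, profiles with whichever modality is most expensive, and sizes the dummy batch against the encoder cache budget, as the hunks below show. A minimal standalone sketch of the selection step, with made-up token counts rather than values from any real model:

    # Illustrative stand-in for what get_max_tokens_per_item_by_modality() returns.
    max_tokens_by_modality = {"image": 576, "video": 4096, "audio": 188}

    # Profile with the modality whose single item needs the most encoder tokens.
    dummy_data_modality, max_tokens_per_mm_item = max(
        max_tokens_by_modality.items(), key=lambda item: item[1])

    print(dummy_data_modality, max_tokens_per_mm_item)  # video 4096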

@@ -19,7 +19,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         LayerBlockType, cdiv, is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                     FlashAttentionMetadata)
-from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient
+from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -82,12 +82,10 @@ class GPUModelRunner:
         self.input_registry = INPUT_REGISTRY
         self.mm_registry = MULTIMODAL_REGISTRY
 
-        # NOTE: mm_input_mapper_client and mm_hasher are only used for memory
-        # profiling.
-        self.mm_input_mapper_client = MMInputMapperClient(self.model_config)
-        self.mm_hasher = MMHasher()
-        self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
-            cache_config.enable_prefix_caching
+        # NOTE: Initialized input mapper is only used for processing dummy
+        # multimodal data into multimodal kwargs for GPU memory profiling.
+        self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
+        self.mm_input_mapper_profiling.use_cache = False
 
         self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens  # noqa: E501
         self.encoder_cache_size = self.scheduler_config.encoder_cache_size
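A note on the hunk above: the runner-level MMHasher and use_hash plumbing is removed, and the profiling-only mapper simply runs with its cache switched off, presumably so one-off dummy inputs are not retained after profiling. The toy class below is not vLLM code; it only illustrates the "client with an optional cache, used cache-free for a single profiling pass" pattern:

    # Toy stand-in (hypothetical), not the real MMInputMapperClient.
    class TinyMapper:
        def __init__(self):
            self.use_cache = True
            self._cache = {}

        def process(self, item):
            if self.use_cache and item in self._cache:
                return self._cache[item]
            result = item.upper()  # stand-in for real multimodal preprocessing
            if self.use_cache:
                self._cache[item] = result
            return result

    profiling_mapper = TinyMapper()
    profiling_mapper.use_cache = False   # mirrors mm_input_mapper_profiling.use_cache = False
    profiling_mapper.process("dummy image")
    assert not profiling_mapper._cache   # nothing is kept around after the profiling pass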
@@ -722,8 +720,6 @@
         ]
 
         # Profile with multimodal encoder & encoder cache.
-        # TODO (ywang96): generalize this beyond image modality since
-        # mm_input_mapper only supports image inputs.
         if self.is_multimodal_model:
 
             # Create dummy batch of multimodal inputs.
@@ -735,15 +731,30 @@
             dummy_mm_data = dummy_request_data.multi_modal_data
 
             # NOTE: Currently model is profiled with a single non-text
-            # modality even when it supports multiple.
-            max_tokens_per_mm_item = max(
-                self.mm_registry.get_max_tokens_per_item_by_modality(
-                    self.model_config).values())
+            # modality with the max possible input tokens even when
+            # it supports multiple.
+            max_tokens_by_modality_dict = self.mm_registry.get_max_tokens_per_item_by_modality(  # noqa: E501
+                self.model_config)
 
-            max_num_mm_items_encoder_budget = min(
-                self.max_num_encoder_input_tokens,
-                self.encoder_cache_size) // max_tokens_per_mm_item
+            dummy_data_modality, max_tokens_per_mm_item = max(
+                max_tokens_by_modality_dict.items(), key=lambda item: item[1])
+
+            # Check how many items of this modality can be supported by
+            # the encoder cache budget.
+            encoder_cache_budget = min(self.max_num_encoder_input_tokens,
+                                       self.encoder_cache_size)
+            max_num_mm_items_encoder_budget = encoder_cache_budget // \
+                max_tokens_per_mm_item
+
+            # TODO: Allow users to set encoder_cache_budget in case this
+            # happens.
+            assert max_num_mm_items_encoder_budget > 0, (
+                f"Encoder cache budget={encoder_cache_budget} is too small to "
+                f"support the maximum possible size of multimodal embeddings"
+                f"={max_tokens_per_mm_item}.")
 
+            # Check how many items of this modality can be supported by
+            # the decoder budget.
             max_mm_items_per_req = max(
                 self.mm_registry.get_mm_limits_per_prompt(
                     self.model_config).values())
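The encoder-budget check in the hunk above is plain integer arithmetic. With hypothetical numbers (an 8192-token encoder input limit, a 16384-token encoder cache, and a modality whose largest single item needs 4096 tokens) it works out as follows; the new assert covers the case where even one item would not fit:

    # Hypothetical numbers, chosen only to exercise the arithmetic above.
    max_num_encoder_input_tokens = 8192
    encoder_cache_size = 16384
    max_tokens_per_mm_item = 4096

    encoder_cache_budget = min(max_num_encoder_input_tokens, encoder_cache_size)      # 8192
    max_num_mm_items_encoder_budget = encoder_cache_budget // max_tokens_per_mm_item  # 2

    # With a budget below one item (e.g. 2048 // 4096 == 0) the assert in the
    # diff would abort profiling instead of silently profiling zero items.
    assert max_num_mm_items_encoder_budget > 0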
@@ -763,33 +774,24 @@
             # they are scheduled to be processed separately.
 
             # Case when models have a merged processor, their dummy data is
-            # already batched `MultiModalKwargs`, therefore we need to "unbatch"
-            # and take the first item in each batched tensor.
-            # TODO (ywang96): This is somewhat hacky. Refactor this to be
-            # consistent with the other case.
+            # already batched `MultiModalKwargs`, therefore we take the first
+            # `MultiModalKwargsItem` from the desired modality to profile on.
             if isinstance(dummy_mm_data, MultiModalKwargs):
-                dummy_mm_kwargs = {
-                    k: v[0].unsqueeze(0)
-                    for k, v in dummy_mm_data.items()
-                }
+                dummy_mm_item = dummy_mm_data.get_item(
+                    modality=dummy_data_modality, item_index=0)
+                dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
             # Case when models have dummy data explicitly defined as
             # `MultiModalDataDict`, so they need to be processed through input
             # mapper.
             # TODO (ywang96): deprecate this path once merged processor is
             # supported on all models.
             else:
-                # Compute MM hashes (if enabled)
-                mm_hashes = None
-                if self.use_hash:
-                    mm_hashes = self.mm_hasher.hash_dummy_mm_data(
-                        dummy_mm_data)
-
-                mm_kwargs_list = self.mm_input_mapper_client.process_inputs(
+                mm_kwargs_list = self.mm_input_mapper_profiling.process_inputs(
                     mm_data=dummy_mm_data,
-                    mm_hashes=mm_hashes,
+                    mm_hashes=None,
                     mm_processor_kwargs=None,
                     precomputed_mm_inputs=None)
                 # Take the first `MultiModalKwargs`
                 dummy_mm_kwargs = mm_kwargs_list[0]
 
             batched_dummy_mm_inputs = MultiModalKwargs.batch(