[Core][Multimodal] Allow passing multi_modal_uuids as multimodal identifiers. (#23394)

Signed-off-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Roger Wang
2025-08-30 18:01:22 -07:00
committed by GitHub
parent 5b8077b8ac
commit 749be00a98
10 changed files with 455 additions and 54 deletions

View File

@@ -150,6 +150,49 @@ class Processor:
self._validate_sampling_params(params, lora_request)
self._validate_supported_sampling_params(params)
def _validate_multi_modal_uuids(self, prompt: PromptType) -> None:
"""
Validate that user-provided multi_modal_uuids align with
multi_modal_data in the incoming request prompt(s).
Only checks lengths; `None` entries are allowed and will be
auto-hashed downstream.
"""
def _validate_single_prompt(single_prompt: Union[dict, str]) -> None:
if not isinstance(single_prompt, dict):
return
mm_data = single_prompt.get("multi_modal_data")
mm_uuids = single_prompt.get("multi_modal_uuids")
if not mm_data or not mm_uuids:
return
for modality, items in mm_data.items():
if modality in mm_uuids:
data_len = len(items) if isinstance(items, list) else 1
uuid_len = len(mm_uuids[modality]) if isinstance(
mm_uuids[modality], list) else 1
if uuid_len != data_len:
raise ValueError(
f"multi_modal_uuids for modality '{modality}' "
"must have same length as data: got "
f"{uuid_len} uuids vs "
f"{data_len} items.")
else:
raise ValueError(
f"multi_modal_uuids for modality '{modality}' must "
"be provided if multi_modal_data is provided.")
# Handle explicit encoder/decoder prompts or singleton prompt
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
enc = prompt.get("encoder_prompt")
dec = prompt.get("decoder_prompt")
if enc is not None:
_validate_single_prompt(enc)
if dec is not None:
_validate_single_prompt(dec)
else:
_validate_single_prompt(prompt) # type: ignore[arg-type]
def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
@@ -289,17 +332,27 @@ class Processor:
if arrival_time is None:
arrival_time = time.time()
# Optionally generate multimodal hash overrides based on request id.
# Optionally generate multimodal hash overrides to avoid hashing
# multimodal data items by their content as their identifiers.
# NOTE: when users explicitly turn off BOTH prefix caching and input
# processing caching, no multimodal features or embeddings will be
# reused across requests, therefore hashing is no longer necessary.
# reused across requests, therefore identifying multimodal data items
# by their content is no longer necessary, and we create uuids with
# request id-modality-index as multimodal hash overrides.
if (self.model_config.multimodal_config and
self.model_config.multimodal_config.mm_processor_cache_gb == 0
and not self.cache_config.enable_prefix_caching):
mm_hash_overrides = self._maybe_build_mm_hash_overrides(
request_id, prompt)
else:
mm_hash_overrides = None
# Otherwise, use user-provided uuids as multimodal hash overrides
# if provided.
self._validate_multi_modal_uuids(prompt)
if isinstance(prompt, dict):
mm_hash_overrides = prompt.get("multi_modal_uuids")
else:
mm_hash_overrides = None
# Process inputs, which includes:
# 1. Tokenize text prompt, with LoRA request if one exists.
@@ -317,6 +370,7 @@ class Processor:
params=params,
processed_inputs=processed_inputs,
)
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
self._validate_model_inputs(processed_inputs, lora_request)