[Core][Multimodal] Allow passing multi_modal_uuids as multimodal identifiers. (#23394)
Signed-off-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
@@ -150,6 +150,49 @@ class Processor:
|
||||
self._validate_sampling_params(params, lora_request)
|
||||
self._validate_supported_sampling_params(params)
|
||||
|
||||
def _validate_multi_modal_uuids(self, prompt: PromptType) -> None:
    """
    Validate that user-provided multi_modal_uuids align with
    multi_modal_data in the incoming request prompt(s).

    Only checks lengths; `None` entries are allowed and will be
    auto-hashed downstream.

    Args:
        prompt: Either a singleton prompt (a plain value or a dict that
            may carry "multi_modal_data" / "multi_modal_uuids") or an
            explicit encoder/decoder prompt, i.e. a dict containing an
            "encoder_prompt" key and optionally a "decoder_prompt" key.

    Raises:
        ValueError: If a modality present in "multi_modal_data" has no
            entry in "multi_modal_uuids", or if the number of uuids for
            a modality differs from the number of data items.
    """

    def _item_count(entry: object) -> int:
        # A list carries one element per item; any other value (including
        # a bare string uuid or a single data object) counts as one item.
        return len(entry) if isinstance(entry, list) else 1

    def _validate_single_prompt(single_prompt: Union[dict, str]) -> None:
        # Non-dict prompts (e.g. plain text) cannot carry multimodal
        # fields, so there is nothing to check.
        if not isinstance(single_prompt, dict):
            return

        mm_data = single_prompt.get("multi_modal_data")
        mm_uuids = single_prompt.get("multi_modal_uuids")
        # Validation only applies when BOTH fields are present and
        # non-empty; uuids are optional as a whole.
        if not mm_data or not mm_uuids:
            return

        # Only modalities present in the data are checked; extra keys in
        # mm_uuids are not inspected here.
        for modality, items in mm_data.items():
            if modality not in mm_uuids:
                raise ValueError(
                    f"multi_modal_uuids for modality '{modality}' must "
                    "be provided if multi_modal_data is provided.")

            data_len = _item_count(items)
            uuid_len = _item_count(mm_uuids[modality])
            if uuid_len != data_len:
                raise ValueError(
                    f"multi_modal_uuids for modality '{modality}' "
                    "must have same length as data: got "
                    f"{uuid_len} uuids vs "
                    f"{data_len} items.")

    # Handle explicit encoder/decoder prompts or singleton prompt
    if isinstance(prompt, dict) and "encoder_prompt" in prompt:
        # Each half is validated independently; a missing/None half is
        # simply skipped.
        for part in (prompt.get("encoder_prompt"),
                     prompt.get("decoder_prompt")):
            if part is not None:
                _validate_single_prompt(part)
    else:
        _validate_single_prompt(prompt)  # type: ignore[arg-type]
|
||||
|
||||
def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
|
||||
if lora_request is not None and not self.lora_config:
|
||||
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
|
||||
@@ -289,17 +332,27 @@ class Processor:
|
||||
if arrival_time is None:
|
||||
arrival_time = time.time()
|
||||
|
||||
# Optionally generate multimodal hash overrides based on request id.
|
||||
# Optionally generate multimodal hash overrides to avoid hashing
|
||||
# multimodal data items by their content as their identifiers.
|
||||
|
||||
# NOTE: when users explicitly turn off BOTH prefix caching and input
|
||||
# processing caching, no multimodal features or embeddings will be
|
||||
# reused across requests, therefore hashing is no longer necessary.
|
||||
# reused across requests, therefore identifying multimodal data items
|
||||
# by their content is no longer necessary, and we create uuids with
|
||||
# request id-modality-index as multimodal hash overrides.
|
||||
if (self.model_config.multimodal_config and
|
||||
self.model_config.multimodal_config.mm_processor_cache_gb == 0
|
||||
and not self.cache_config.enable_prefix_caching):
|
||||
mm_hash_overrides = self._maybe_build_mm_hash_overrides(
|
||||
request_id, prompt)
|
||||
else:
|
||||
mm_hash_overrides = None
|
||||
# Otherwise, use user-provided uuids as multimodal hash overrides
|
||||
# if provided.
|
||||
self._validate_multi_modal_uuids(prompt)
|
||||
if isinstance(prompt, dict):
|
||||
mm_hash_overrides = prompt.get("multi_modal_uuids")
|
||||
else:
|
||||
mm_hash_overrides = None
|
||||
|
||||
# Process inputs, which includes:
|
||||
# 1. Tokenize text prompt, with LoRA request if one exists.
|
||||
@@ -317,6 +370,7 @@ class Processor:
|
||||
params=params,
|
||||
processed_inputs=processed_inputs,
|
||||
)
|
||||
|
||||
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
|
||||
|
||||
self._validate_model_inputs(processed_inputs, lora_request)
|
||||
|
||||
Reference in New Issue
Block a user