[Core] Simplify and unify mm uuid handling & auto-generated mm hash overrides processing. (#24271)
Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
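
Every InputPreprocessor entry point below previously accepted a keyword-only
mm_hash_overrides argument typed Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]]. This change collapses it into a single
mm_uuids: Optional[MultiModalUUIDDict] parameter that is threaded through
every synchronous and asynchronous preprocessing path.

Illustrative call-site migration (a minimal sketch: the caller object and the
example UUID value are hypothetical; only the parameter names come from this
diff):

    # before: UUIDs passed as per-modality hash overrides
    preprocessor.preprocess(prompt, mm_hash_overrides={"image": ["img-0"]})

    # after: the same mapping passed as multimodal UUIDs
    preprocessor.preprocess(prompt, mm_uuids={"image": ["img-0"]})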
@@ -258,8 +258,7 @@ class InputPreprocessor:
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         *,
-        mm_hash_overrides: Optional[Union[dict[str, list[str]],
-                                          MultiModalUUIDDict]] = None,
+        mm_uuids: Optional[MultiModalUUIDDict] = None,
     ) -> MultiModalInputs:
         """
         Apply the model's multi-modal processor to a multi-modal prompt,
@@ -281,7 +280,7 @@ class InputPreprocessor:
             mm_data,
             hf_processor_mm_kwargs=mm_processor_kwargs,
             tokenization_kwargs=tokenization_kwargs,
-            mm_hash_overrides=mm_hash_overrides,
+            mm_uuids=mm_uuids,
         )
         mm_hashes = mm_input["mm_hashes"]
 
@@ -302,8 +301,7 @@ class InputPreprocessor:
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         *,
-        mm_hash_overrides: Optional[Union[dict[str, list[str]],
-                                          MultiModalUUIDDict]] = None,
+        mm_uuids: Optional[MultiModalUUIDDict] = None,
     ) -> MultiModalInputs:
         """
         Async version of
@@ -325,7 +323,7 @@ class InputPreprocessor:
             mm_data,
             hf_processor_mm_kwargs=mm_processor_kwargs,
             tokenization_kwargs=tokenization_kwargs,
-            mm_hash_overrides=mm_hash_overrides,
+            mm_uuids=mm_uuids,
         )
         mm_hashes = mm_input["mm_hashes"]
 
@@ -390,8 +388,7 @@ class InputPreprocessor:
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         *,
-        mm_hash_overrides: Optional[Union[dict[str, list[str]],
-                                          MultiModalUUIDDict]] = None,
+        mm_uuids: Optional[MultiModalUUIDDict] = None,
     ) -> Union[TokenInputs, MultiModalInputs]:
         prompt_token_ids = self._truncate_inputs(
             parsed_content["prompt_token_ids"], tokenization_kwargs)
@@ -404,7 +401,7 @@ class InputPreprocessor:
                 parsed_content.get("mm_processor_kwargs"),
                 tokenization_kwargs=tokenization_kwargs,
                 lora_request=lora_request,
-                mm_hash_overrides=mm_hash_overrides,
+                mm_uuids=mm_uuids,
             )
         else:
             inputs = token_inputs(prompt_token_ids=prompt_token_ids)
@@ -420,8 +417,7 @@ class InputPreprocessor:
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         *,
-        mm_hash_overrides: Optional[Union[dict[str, list[str]],
-                                          MultiModalUUIDDict]] = None,
+        mm_uuids: Optional[MultiModalUUIDDict] = None,
     ) -> Union[TokenInputs, MultiModalInputs]:
         prompt_token_ids = self._truncate_inputs(
             parsed_content["prompt_token_ids"], tokenization_kwargs)
@@ -434,7 +430,7 @@ class InputPreprocessor:
                 parsed_content.get("mm_processor_kwargs"),
                 tokenization_kwargs=tokenization_kwargs,
                 lora_request=lora_request,
-                mm_hash_overrides=mm_hash_overrides,
+                mm_uuids=mm_uuids,
             )
         else:
             inputs = token_inputs(prompt_token_ids=prompt_token_ids, )
@@ -450,8 +446,7 @@ class InputPreprocessor:
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         *,
-        mm_hash_overrides: Optional[Union[dict[str, list[str]],
-                                          MultiModalUUIDDict]] = None,
+        mm_uuids: Optional[MultiModalUUIDDict] = None,
     ) -> Union[TokenInputs, MultiModalInputs]:
         prompt_text = parsed_content["prompt"]
 
@@ -463,7 +458,7 @@ class InputPreprocessor:
                 parsed_content.get("mm_processor_kwargs"),
                 tokenization_kwargs=tokenization_kwargs,
                 lora_request=lora_request,
-                mm_hash_overrides=mm_hash_overrides,
+                mm_uuids=mm_uuids,
             )
         else:
             prompt_token_ids = self._tokenize_prompt(
@@ -487,8 +482,7 @@ class InputPreprocessor:
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         *,
-        mm_hash_overrides: Optional[Union[dict[str, list[str]],
-                                          MultiModalUUIDDict]] = None,
+        mm_uuids: Optional[MultiModalUUIDDict] = None,
     ) -> Union[TokenInputs, MultiModalInputs]:
         prompt_text = parsed_content["prompt"]
 
@@ -500,7 +494,7 @@ class InputPreprocessor:
                 parsed_content.get("mm_processor_kwargs"),
                 tokenization_kwargs=tokenization_kwargs,
                 lora_request=lora_request,
-                mm_hash_overrides=mm_hash_overrides,
+                mm_uuids=mm_uuids,
             )
         else:
             prompt_token_ids = await self._tokenize_prompt_async(
@@ -524,8 +518,7 @@ class InputPreprocessor:
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         *,
-        mm_hash_overrides: Optional[Union[dict[str, list[str]],
-                                          MultiModalUUIDDict]] = None,
+        mm_uuids: Optional[MultiModalUUIDDict] = None,
     ) -> SingletonInputs:
         """
         Extract the singleton inputs from a prompt.
@@ -547,21 +540,21 @@ class InputPreprocessor:
             return self._process_tokens(
                 parsed["content"],
                 lora_request=lora_request,
-                mm_hash_overrides=mm_hash_overrides,
+                mm_uuids=mm_uuids,
             )
         if parsed["type"] == "text":
             return self._process_text(
                 parsed["content"],
                 tokenization_kwargs=tokenization_kwargs,
                 lora_request=lora_request,
-                mm_hash_overrides=mm_hash_overrides,
+                mm_uuids=mm_uuids,
             )
         if parsed["type"] == "str":
             return self._process_text(
                 TextPrompt(prompt=parsed["content"]),
                 tokenization_kwargs=tokenization_kwargs,
                 lora_request=lora_request,
-                mm_hash_overrides=mm_hash_overrides,
+                mm_uuids=mm_uuids,
             )
 
         assert_never(parsed)
@@ -572,8 +565,7 @@ class InputPreprocessor:
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         *,
-        mm_hash_overrides: Optional[Union[dict[str, list[str]],
-                                          MultiModalUUIDDict]] = None,
+        mm_uuids: Optional[MultiModalUUIDDict] = None,
     ) -> SingletonInputs:
         """
         Async version of
@@ -587,21 +579,21 @@ class InputPreprocessor:
             return await self._process_tokens_async(
                 parsed["content"],
                 lora_request=lora_request,
-                mm_hash_overrides=mm_hash_overrides,
+                mm_uuids=mm_uuids,
             )
         if parsed["type"] == "text":
             return await self._process_text_async(
                 parsed["content"],
                 tokenization_kwargs=tokenization_kwargs,
                 lora_request=lora_request,
-                mm_hash_overrides=mm_hash_overrides,
+                mm_uuids=mm_uuids,
             )
         if parsed["type"] == "str":
             return await self._process_text_async(
                 TextPrompt(prompt=parsed["content"]),
                 tokenization_kwargs=tokenization_kwargs,
                 lora_request=lora_request,
-                mm_hash_overrides=mm_hash_overrides,
+                mm_uuids=mm_uuids,
             )
 
         assert_never(parsed)
@@ -712,8 +704,7 @@ class InputPreprocessor:
         prompt: PromptType,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         *,
-        mm_hash_overrides: Optional[Union[dict[str, list[str]],
-                                          MultiModalUUIDDict]] = None,
+        mm_uuids: Optional[MultiModalUUIDDict] = None,
     ) -> EncoderDecoderInputs:
         """
         For encoder/decoder models only:
@@ -755,7 +746,7 @@ class InputPreprocessor:
             encoder_inputs = self._prompt_to_llm_inputs(
                 prompt["encoder_prompt"],
                 tokenization_kwargs=tokenization_kwargs,
-                mm_hash_overrides=mm_hash_overrides,
+                mm_uuids=mm_uuids,
             )
             if (decoder_input := prompt["decoder_prompt"]) is None:
                 decoder_inputs = None
@@ -771,7 +762,7 @@ class InputPreprocessor:
             inputs = self._prompt_to_llm_inputs(
                 prompt,
                 tokenization_kwargs=tokenization_kwargs,
-                mm_hash_overrides=mm_hash_overrides,
+                mm_uuids=mm_uuids,
             )
             if self.model_config.is_multimodal_model:
                 # Encoder-Decoder Multimodal model
@@ -788,8 +779,7 @@ class InputPreprocessor:
         prompt: PromptType,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         *,
-        mm_hash_overrides: Optional[Union[dict[str, list[str]],
-                                          MultiModalUUIDDict]] = None,
+        mm_uuids: Optional[MultiModalUUIDDict] = None,
     ) -> EncoderDecoderInputs:
         """
         Async version of
@@ -802,7 +792,7 @@ class InputPreprocessor:
             encoder_task = self._prompt_to_llm_inputs_async(
                 prompt["encoder_prompt"],
                 tokenization_kwargs=tokenization_kwargs,
-                mm_hash_overrides=mm_hash_overrides,
+                mm_uuids=mm_uuids,
             )
 
             if (decoder_input := prompt["decoder_prompt"]) is None:
@@ -812,7 +802,7 @@ class InputPreprocessor:
                 decoder_task = self._prompt_to_llm_inputs_async(
                     decoder_input,
                     tokenization_kwargs=tokenization_kwargs,
-                    mm_hash_overrides=mm_hash_overrides,
+                    mm_uuids=mm_uuids,
                 )
 
             encoder_inputs, decoder_inputs = await asyncio.gather(
@@ -828,7 +818,7 @@ class InputPreprocessor:
             inputs = await self._prompt_to_llm_inputs_async(
                 prompt,
                 tokenization_kwargs=tokenization_kwargs,
-                mm_hash_overrides=mm_hash_overrides,
+                mm_uuids=mm_uuids,
             )
             if self.model_config.is_multimodal_model:
                 # Encoder-Decoder Multimodal model
@@ -856,8 +846,7 @@ class InputPreprocessor:
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         *,
-        mm_hash_overrides: Optional[Union[dict[str, list[str]],
-                                          MultiModalUUIDDict]] = None,
+        mm_uuids: Optional[MultiModalUUIDDict] = None,
     ) -> DecoderOnlyInputs:
         """
         For decoder-only models:
@@ -878,7 +867,7 @@ class InputPreprocessor:
             prompt,
             tokenization_kwargs=tokenization_kwargs,
             lora_request=lora_request,
-            mm_hash_overrides=mm_hash_overrides,
+            mm_uuids=mm_uuids,
         )
 
         return self._build_decoder_only_llm_inputs(prompt_comps)
@@ -889,8 +878,7 @@ class InputPreprocessor:
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         *,
-        mm_hash_overrides: Optional[Union[dict[str, list[str]],
-                                          MultiModalUUIDDict]] = None,
+        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> DecoderOnlyInputs:
         """
         Async version of
@@ -900,7 +888,7 @@ class InputPreprocessor:
             prompt,
             tokenization_kwargs=tokenization_kwargs,
             lora_request=lora_request,
-            mm_hash_overrides=mm_hash_overrides,
+            mm_uuids=mm_uuids,
         )
 
         return self._build_decoder_only_llm_inputs(prompt_comps)
@@ -911,8 +899,7 @@ class InputPreprocessor:
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         *,
-        mm_hash_overrides: Optional[Union[dict[str, list[str]],
-                                          MultiModalUUIDDict]] = None,
+        mm_uuids: Optional[MultiModalUUIDDict] = None,
     ) -> ProcessorInputs:
         """Preprocess the input prompt."""
         if self.model_config.is_encoder_decoder:
@@ -921,7 +908,7 @@ class InputPreprocessor:
             return self._process_encoder_decoder_prompt(
                 prompt,
                 tokenization_kwargs,
-                mm_hash_overrides=mm_hash_overrides,
+                mm_uuids=mm_uuids,
             )
 
         if is_explicit_encoder_decoder_prompt(prompt):
@@ -933,7 +920,7 @@ class InputPreprocessor:
             prompt,
             tokenization_kwargs=tokenization_kwargs,
             lora_request=lora_request,
-            mm_hash_overrides=mm_hash_overrides,
+            mm_uuids=mm_uuids,
         )
 
     async def preprocess_async(
@@ -942,8 +929,7 @@ class InputPreprocessor:
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         *,
-        mm_hash_overrides: Optional[Union[dict[str, list[str]],
-                                          MultiModalUUIDDict]] = None,
+        mm_uuids: Optional[MultiModalUUIDDict] = None,
     ) -> ProcessorInputs:
         """
         Async version of
@@ -955,7 +941,7 @@ class InputPreprocessor:
             return await self._process_encoder_decoder_prompt_async(
                 prompt,
                 tokenization_kwargs,
-                mm_hash_overrides=mm_hash_overrides,
+                mm_uuids=mm_uuids,
             )
 
         if is_explicit_encoder_decoder_prompt(prompt):
@@ -967,7 +953,7 @@ class InputPreprocessor:
             prompt,
             tokenization_kwargs=tokenization_kwargs,
             lora_request=lora_request,
-            mm_hash_overrides=mm_hash_overrides,
+            mm_uuids=mm_uuids,
         )
 
     def clear_cache(self) -> None: