[Multimodal] Generate mm_hash based on request metadata when caching is turned off (#23690)

Signed-off-by: Roger Wang <hey@rogerw.io>
Author: Roger Wang
Date: 2025-08-27 13:24:31 -07:00
Committed by: GitHub
Parent: 0585a9e73c
Commit: 8bf6266a17

12 changed files with 179 additions and 24 deletions
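
As a quick orientation before the diff: the commit threads a new keyword-only parameter, mm_hash_overrides: Optional[dict[str, list[str]]], from InputPreprocessor.preprocess() / preprocess_async() down to mm_processor.apply(), so that multimodal hashes can come from request metadata instead of being computed from content when the processor cache is turned off. The sketch below is illustrative only; the helper name build_mm_hash_overrides and the identifier format are assumptions, not code from this commit.

# Hypothetical caller-side sketch: derive one stable identifier per
# multimodal item from request metadata alone (no content hashing),
# in the dict[str, list[str]] shape the new parameter expects.
def build_mm_hash_overrides(
    request_id: str,
    mm_data: dict[str, list[object]],
) -> dict[str, list[str]]:
    return {
        modality: [f"{request_id}-{modality}-{i}" for i in range(len(items))]
        for modality, items in mm_data.items()
    }

# Usage sketch (input_preprocessor and prompt are assumed to exist):
#
#   overrides = build_mm_hash_overrides("req-0", {"image": [image]})
#   inputs = input_preprocessor.preprocess(
#       prompt,
#       mm_hash_overrides=overrides,  # forwarded down to mm_processor.apply()
#   )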


@@ -257,6 +257,8 @@ class InputPreprocessor:
         mm_processor_kwargs: Optional[Mapping[str, object]],
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
+        *,
+        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
     ) -> MultiModalInputs:
         """
         Apply the model's multi-modal processor to a multi-modal prompt,
@@ -273,10 +275,13 @@ class InputPreprocessor:
         if mm_processor_kwargs is None:
             mm_processor_kwargs = {}
 
-        return mm_processor.apply(prompt,
-                                  mm_data,
-                                  hf_processor_mm_kwargs=mm_processor_kwargs,
-                                  tokenization_kwargs=tokenization_kwargs)
+        return mm_processor.apply(
+            prompt,
+            mm_data,
+            hf_processor_mm_kwargs=mm_processor_kwargs,
+            tokenization_kwargs=tokenization_kwargs,
+            mm_hash_overrides=mm_hash_overrides,
+        )
 
     async def _process_multimodal_async(
         self,
@@ -285,6 +290,8 @@ class InputPreprocessor:
         mm_processor_kwargs: Optional[Mapping[str, object]],
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
+        *,
+        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
     ) -> MultiModalInputs:
         """
         Async version of
@@ -301,10 +308,13 @@ class InputPreprocessor:
         if mm_processor_kwargs is None:
             mm_processor_kwargs = {}
 
-        return mm_processor.apply(prompt,
-                                  mm_data,
-                                  hf_processor_mm_kwargs=mm_processor_kwargs,
-                                  tokenization_kwargs=tokenization_kwargs)
+        return mm_processor.apply(
+            prompt,
+            mm_data,
+            hf_processor_mm_kwargs=mm_processor_kwargs,
+            tokenization_kwargs=tokenization_kwargs,
+            mm_hash_overrides=mm_hash_overrides,
+        )
 
     def _process_embeds(
         self,
@@ -341,6 +351,8 @@ class InputPreprocessor:
         parsed_content: TokensPrompt,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
+        *,
+        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
     ) -> Union[TokenInputs, MultiModalInputs]:
         prompt_token_ids = parsed_content["prompt_token_ids"]
         token_type_ids = parsed_content.get("token_type_ids")
@@ -353,6 +365,7 @@ class InputPreprocessor:
                 parsed_content.get("mm_processor_kwargs"),
                 tokenization_kwargs=tokenization_kwargs,
                 lora_request=lora_request,
+                mm_hash_overrides=mm_hash_overrides,
             )
         else:
             inputs = token_inputs(
@@ -370,6 +383,8 @@ class InputPreprocessor:
         parsed_content: TokensPrompt,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
+        *,
+        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
     ) -> Union[TokenInputs, MultiModalInputs]:
         prompt_token_ids = parsed_content["prompt_token_ids"]
         token_type_ids = parsed_content.get("token_type_ids")
@@ -382,6 +397,7 @@ class InputPreprocessor:
                 parsed_content.get("mm_processor_kwargs"),
                 tokenization_kwargs=tokenization_kwargs,
                 lora_request=lora_request,
+                mm_hash_overrides=mm_hash_overrides,
             )
         else:
             inputs = token_inputs(
@@ -399,6 +415,8 @@ class InputPreprocessor:
         parsed_content: TextPrompt,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
+        *,
+        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
     ) -> Union[TokenInputs, MultiModalInputs]:
         prompt_text = parsed_content["prompt"]
 
@@ -410,6 +428,7 @@ class InputPreprocessor:
                 parsed_content.get("mm_processor_kwargs"),
                 tokenization_kwargs=tokenization_kwargs,
                 lora_request=lora_request,
+                mm_hash_overrides=mm_hash_overrides,
             )
         else:
             prompt_token_ids = self._tokenize_prompt(
@@ -432,6 +451,8 @@ class InputPreprocessor:
         parsed_content: TextPrompt,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
+        *,
+        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
     ) -> Union[TokenInputs, MultiModalInputs]:
         prompt_text = parsed_content["prompt"]
 
@@ -443,6 +464,7 @@ class InputPreprocessor:
                 parsed_content.get("mm_processor_kwargs"),
                 tokenization_kwargs=tokenization_kwargs,
                 lora_request=lora_request,
+                mm_hash_overrides=mm_hash_overrides,
             )
         else:
             prompt_token_ids = await self._tokenize_prompt_async(
@@ -465,6 +487,8 @@ class InputPreprocessor:
         prompt: SingletonPrompt,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
+        *,
+        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
     ) -> SingletonInputs:
         """
         Extract the singleton inputs from a prompt.
@@ -486,18 +510,21 @@ class InputPreprocessor:
             return self._process_tokens(
                 parsed["content"],
                 lora_request=lora_request,
+                mm_hash_overrides=mm_hash_overrides,
             )
         if parsed["type"] == "text":
             return self._process_text(
                 parsed["content"],
                 tokenization_kwargs=tokenization_kwargs,
                 lora_request=lora_request,
+                mm_hash_overrides=mm_hash_overrides,
             )
         if parsed["type"] == "str":
             return self._process_text(
                 TextPrompt(prompt=parsed["content"]),
                 tokenization_kwargs=tokenization_kwargs,
                 lora_request=lora_request,
+                mm_hash_overrides=mm_hash_overrides,
             )
 
         assert_never(parsed)
@@ -507,6 +534,8 @@ class InputPreprocessor:
         prompt: SingletonPrompt,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
+        *,
+        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
     ) -> SingletonInputs:
         """
         Async version of
@@ -520,18 +549,21 @@ class InputPreprocessor:
             return await self._process_tokens_async(
                 parsed["content"],
                 lora_request=lora_request,
+                mm_hash_overrides=mm_hash_overrides,
             )
         if parsed["type"] == "text":
             return await self._process_text_async(
                 parsed["content"],
                 tokenization_kwargs=tokenization_kwargs,
                 lora_request=lora_request,
+                mm_hash_overrides=mm_hash_overrides,
             )
         if parsed["type"] == "str":
             return await self._process_text_async(
                 TextPrompt(prompt=parsed["content"]),
                 tokenization_kwargs=tokenization_kwargs,
                 lora_request=lora_request,
+                mm_hash_overrides=mm_hash_overrides,
             )
 
         assert_never(parsed)
@@ -641,6 +673,8 @@ class InputPreprocessor:
         self,
         prompt: PromptType,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
+        *,
+        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
     ) -> EncoderDecoderInputs:
         """
         For encoder/decoder models only:
@@ -682,6 +716,7 @@ class InputPreprocessor:
             encoder_inputs = self._prompt_to_llm_inputs(
                 prompt["encoder_prompt"],
                 tokenization_kwargs=tokenization_kwargs,
+                mm_hash_overrides=mm_hash_overrides,
             )
             if (decoder_input := prompt["decoder_prompt"]) is None:
                 decoder_inputs = None
@@ -697,6 +732,7 @@ class InputPreprocessor:
             inputs = self._prompt_to_llm_inputs(
                 prompt,
                 tokenization_kwargs=tokenization_kwargs,
+                mm_hash_overrides=mm_hash_overrides,
             )
             if self.model_config.is_multimodal_model:
                 # Encoder-Decoder Multimodal model
@@ -712,6 +748,8 @@ class InputPreprocessor:
         self,
         prompt: PromptType,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
+        *,
+        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
     ) -> EncoderDecoderInputs:
         """
         Async version of
@@ -724,6 +762,7 @@ class InputPreprocessor:
             encoder_task = self._prompt_to_llm_inputs_async(
                 prompt["encoder_prompt"],
                 tokenization_kwargs=tokenization_kwargs,
+                mm_hash_overrides=mm_hash_overrides,
             )
 
             if (decoder_input := prompt["decoder_prompt"]) is None:
@@ -733,6 +772,7 @@ class InputPreprocessor:
             decoder_task = self._prompt_to_llm_inputs_async(
                 decoder_input,
                 tokenization_kwargs=tokenization_kwargs,
+                mm_hash_overrides=mm_hash_overrides,
             )
 
             encoder_inputs, decoder_inputs = await asyncio.gather(
@@ -748,6 +788,7 @@ class InputPreprocessor:
             inputs = await self._prompt_to_llm_inputs_async(
                 prompt,
                 tokenization_kwargs=tokenization_kwargs,
+                mm_hash_overrides=mm_hash_overrides,
             )
             if self.model_config.is_multimodal_model:
                 # Encoder-Decoder Multimodal model
@@ -774,6 +815,8 @@ class InputPreprocessor:
         prompt: SingletonPrompt,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
+        *,
+        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
     ) -> DecoderOnlyInputs:
         """
         For decoder-only models:
@@ -794,6 +837,7 @@ class InputPreprocessor:
             prompt,
             tokenization_kwargs=tokenization_kwargs,
             lora_request=lora_request,
+            mm_hash_overrides=mm_hash_overrides,
         )
 
         return self._build_decoder_only_llm_inputs(prompt_comps)
@@ -803,6 +847,8 @@ class InputPreprocessor:
         prompt: SingletonPrompt,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
+        *,
+        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
     ) -> DecoderOnlyInputs:
         """
         Async version of
@@ -812,6 +858,7 @@ class InputPreprocessor:
             prompt,
             tokenization_kwargs=tokenization_kwargs,
             lora_request=lora_request,
+            mm_hash_overrides=mm_hash_overrides,
         )
 
         return self._build_decoder_only_llm_inputs(prompt_comps)
@@ -821,6 +868,8 @@ class InputPreprocessor:
         prompt: PromptType,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
+        *,
+        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
     ) -> ProcessorInputs:
         """Preprocess the input prompt."""
         if self.model_config.is_encoder_decoder:
@@ -829,6 +878,7 @@ class InputPreprocessor:
             return self._process_encoder_decoder_prompt(
                 prompt,
                 tokenization_kwargs,
+                mm_hash_overrides=mm_hash_overrides,
             )
 
         if is_explicit_encoder_decoder_prompt(prompt):
@@ -840,6 +890,7 @@ class InputPreprocessor:
             prompt,
             tokenization_kwargs=tokenization_kwargs,
             lora_request=lora_request,
+            mm_hash_overrides=mm_hash_overrides,
         )
 
     async def preprocess_async(
@@ -847,6 +898,8 @@ class InputPreprocessor:
         prompt: PromptType,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
+        *,
+        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
     ) -> ProcessorInputs:
         """
         Async version of
@@ -858,6 +911,7 @@ class InputPreprocessor:
             return await self._process_encoder_decoder_prompt_async(
                 prompt,
                 tokenization_kwargs,
+                mm_hash_overrides=mm_hash_overrides,
             )
 
         if is_explicit_encoder_decoder_prompt(prompt):
@@ -869,6 +923,7 @@ class InputPreprocessor:
             prompt,
             tokenization_kwargs=tokenization_kwargs,
             lora_request=lora_request,
+            mm_hash_overrides=mm_hash_overrides,
         )
 
     def clear_cache(self) -> None:
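
To make the consuming side concrete: a hedged sketch of the override-versus-hash decision. This is not vLLM's actual implementation; resolve_mm_hashes and the SHA-256 fallback are assumptions used only to illustrate the semantics of mm_hash_overrides.

import hashlib
from typing import Optional

# Hypothetical consumer: when overrides are present they win outright;
# otherwise fall back to hashing the raw item bytes (the expensive path
# that request-metadata overrides let a cache-less deployment skip).
def resolve_mm_hashes(
    mm_items: dict[str, list[bytes]],
    mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> dict[str, list[str]]:
    if mm_hash_overrides is not None:
        return mm_hash_overrides
    return {
        modality: [hashlib.sha256(item).hexdigest() for item in items]
        for modality, items in mm_items.items()
    }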