[Multimodal] Generate mm_hash based on request metadata when caching is turned off (#23690)

Signed-off-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Roger Wang
2025-08-27 13:24:31 -07:00
committed by GitHub
parent 0585a9e73c
commit 8bf6266a17
12 changed files with 179 additions and 24 deletions

View File

@@ -257,6 +257,8 @@ class InputPreprocessor:
mm_processor_kwargs: Optional[Mapping[str, object]], mm_processor_kwargs: Optional[Mapping[str, object]],
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
""" """
Apply the model's multi-modal processor to a multi-modal prompt, Apply the model's multi-modal processor to a multi-modal prompt,
@@ -273,10 +275,13 @@ class InputPreprocessor:
if mm_processor_kwargs is None: if mm_processor_kwargs is None:
mm_processor_kwargs = {} mm_processor_kwargs = {}
return mm_processor.apply(prompt, return mm_processor.apply(
mm_data, prompt,
hf_processor_mm_kwargs=mm_processor_kwargs, mm_data,
tokenization_kwargs=tokenization_kwargs) hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
async def _process_multimodal_async( async def _process_multimodal_async(
self, self,
@@ -285,6 +290,8 @@ class InputPreprocessor:
mm_processor_kwargs: Optional[Mapping[str, object]], mm_processor_kwargs: Optional[Mapping[str, object]],
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
""" """
Async version of Async version of
@@ -301,10 +308,13 @@ class InputPreprocessor:
if mm_processor_kwargs is None: if mm_processor_kwargs is None:
mm_processor_kwargs = {} mm_processor_kwargs = {}
return mm_processor.apply(prompt, return mm_processor.apply(
mm_data, prompt,
hf_processor_mm_kwargs=mm_processor_kwargs, mm_data,
tokenization_kwargs=tokenization_kwargs) hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
def _process_embeds( def _process_embeds(
self, self,
@@ -341,6 +351,8 @@ class InputPreprocessor:
parsed_content: TokensPrompt, parsed_content: TokensPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = parsed_content["prompt_token_ids"] prompt_token_ids = parsed_content["prompt_token_ids"]
token_type_ids = parsed_content.get("token_type_ids") token_type_ids = parsed_content.get("token_type_ids")
@@ -353,6 +365,7 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
else: else:
inputs = token_inputs( inputs = token_inputs(
@@ -370,6 +383,8 @@ class InputPreprocessor:
parsed_content: TokensPrompt, parsed_content: TokensPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = parsed_content["prompt_token_ids"] prompt_token_ids = parsed_content["prompt_token_ids"]
token_type_ids = parsed_content.get("token_type_ids") token_type_ids = parsed_content.get("token_type_ids")
@@ -382,6 +397,7 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
else: else:
inputs = token_inputs( inputs = token_inputs(
@@ -399,6 +415,8 @@ class InputPreprocessor:
parsed_content: TextPrompt, parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
prompt_text = parsed_content["prompt"] prompt_text = parsed_content["prompt"]
@@ -410,6 +428,7 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
else: else:
prompt_token_ids = self._tokenize_prompt( prompt_token_ids = self._tokenize_prompt(
@@ -432,6 +451,8 @@ class InputPreprocessor:
parsed_content: TextPrompt, parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
prompt_text = parsed_content["prompt"] prompt_text = parsed_content["prompt"]
@@ -443,6 +464,7 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
else: else:
prompt_token_ids = await self._tokenize_prompt_async( prompt_token_ids = await self._tokenize_prompt_async(
@@ -465,6 +487,8 @@ class InputPreprocessor:
prompt: SingletonPrompt, prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> SingletonInputs: ) -> SingletonInputs:
""" """
Extract the singleton inputs from a prompt. Extract the singleton inputs from a prompt.
@@ -486,18 +510,21 @@ class InputPreprocessor:
return self._process_tokens( return self._process_tokens(
parsed["content"], parsed["content"],
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
if parsed["type"] == "text": if parsed["type"] == "text":
return self._process_text( return self._process_text(
parsed["content"], parsed["content"],
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
if parsed["type"] == "str": if parsed["type"] == "str":
return self._process_text( return self._process_text(
TextPrompt(prompt=parsed["content"]), TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
assert_never(parsed) assert_never(parsed)
@@ -507,6 +534,8 @@ class InputPreprocessor:
prompt: SingletonPrompt, prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> SingletonInputs: ) -> SingletonInputs:
""" """
Async version of Async version of
@@ -520,18 +549,21 @@ class InputPreprocessor:
return await self._process_tokens_async( return await self._process_tokens_async(
parsed["content"], parsed["content"],
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
if parsed["type"] == "text": if parsed["type"] == "text":
return await self._process_text_async( return await self._process_text_async(
parsed["content"], parsed["content"],
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
if parsed["type"] == "str": if parsed["type"] == "str":
return await self._process_text_async( return await self._process_text_async(
TextPrompt(prompt=parsed["content"]), TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
assert_never(parsed) assert_never(parsed)
@@ -641,6 +673,8 @@ class InputPreprocessor:
self, self,
prompt: PromptType, prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> EncoderDecoderInputs: ) -> EncoderDecoderInputs:
""" """
For encoder/decoder models only: For encoder/decoder models only:
@@ -682,6 +716,7 @@ class InputPreprocessor:
encoder_inputs = self._prompt_to_llm_inputs( encoder_inputs = self._prompt_to_llm_inputs(
prompt["encoder_prompt"], prompt["encoder_prompt"],
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
if (decoder_input := prompt["decoder_prompt"]) is None: if (decoder_input := prompt["decoder_prompt"]) is None:
decoder_inputs = None decoder_inputs = None
@@ -697,6 +732,7 @@ class InputPreprocessor:
inputs = self._prompt_to_llm_inputs( inputs = self._prompt_to_llm_inputs(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
if self.model_config.is_multimodal_model: if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model # Encoder-Decoder Multimodal model
@@ -712,6 +748,8 @@ class InputPreprocessor:
self, self,
prompt: PromptType, prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> EncoderDecoderInputs: ) -> EncoderDecoderInputs:
""" """
Async version of Async version of
@@ -724,6 +762,7 @@ class InputPreprocessor:
encoder_task = self._prompt_to_llm_inputs_async( encoder_task = self._prompt_to_llm_inputs_async(
prompt["encoder_prompt"], prompt["encoder_prompt"],
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
if (decoder_input := prompt["decoder_prompt"]) is None: if (decoder_input := prompt["decoder_prompt"]) is None:
@@ -733,6 +772,7 @@ class InputPreprocessor:
decoder_task = self._prompt_to_llm_inputs_async( decoder_task = self._prompt_to_llm_inputs_async(
decoder_input, decoder_input,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
encoder_inputs, decoder_inputs = await asyncio.gather( encoder_inputs, decoder_inputs = await asyncio.gather(
@@ -748,6 +788,7 @@ class InputPreprocessor:
inputs = await self._prompt_to_llm_inputs_async( inputs = await self._prompt_to_llm_inputs_async(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
if self.model_config.is_multimodal_model: if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model # Encoder-Decoder Multimodal model
@@ -774,6 +815,8 @@ class InputPreprocessor:
prompt: SingletonPrompt, prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> DecoderOnlyInputs: ) -> DecoderOnlyInputs:
""" """
For decoder-only models: For decoder-only models:
@@ -794,6 +837,7 @@ class InputPreprocessor:
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
return self._build_decoder_only_llm_inputs(prompt_comps) return self._build_decoder_only_llm_inputs(prompt_comps)
@@ -803,6 +847,8 @@ class InputPreprocessor:
prompt: SingletonPrompt, prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> DecoderOnlyInputs: ) -> DecoderOnlyInputs:
""" """
Async version of Async version of
@@ -812,6 +858,7 @@ class InputPreprocessor:
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
return self._build_decoder_only_llm_inputs(prompt_comps) return self._build_decoder_only_llm_inputs(prompt_comps)
@@ -821,6 +868,8 @@ class InputPreprocessor:
prompt: PromptType, prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> ProcessorInputs: ) -> ProcessorInputs:
"""Preprocess the input prompt.""" """Preprocess the input prompt."""
if self.model_config.is_encoder_decoder: if self.model_config.is_encoder_decoder:
@@ -829,6 +878,7 @@ class InputPreprocessor:
return self._process_encoder_decoder_prompt( return self._process_encoder_decoder_prompt(
prompt, prompt,
tokenization_kwargs, tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
if is_explicit_encoder_decoder_prompt(prompt): if is_explicit_encoder_decoder_prompt(prompt):
@@ -840,6 +890,7 @@ class InputPreprocessor:
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
async def preprocess_async( async def preprocess_async(
@@ -847,6 +898,8 @@ class InputPreprocessor:
prompt: PromptType, prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> ProcessorInputs: ) -> ProcessorInputs:
""" """
Async version of Async version of
@@ -858,6 +911,7 @@ class InputPreprocessor:
return await self._process_encoder_decoder_prompt_async( return await self._process_encoder_decoder_prompt_async(
prompt, prompt,
tokenization_kwargs, tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
if is_explicit_encoder_decoder_prompt(prompt): if is_explicit_encoder_decoder_prompt(prompt):
@@ -869,6 +923,7 @@ class InputPreprocessor:
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
def clear_cache(self) -> None: def clear_cache(self) -> None:

View File

@@ -290,6 +290,7 @@ class DeepseekVL2MultiModalProcessor(
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
# The processor logic is different for len(images) <= 2 vs > 2 # The processor logic is different for len(images) <= 2 vs > 2
# Since the processing cache assumes that the processor output is # Since the processing cache assumes that the processor output is
@@ -301,6 +302,7 @@ class DeepseekVL2MultiModalProcessor(
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
return super()._cached_apply_hf_processor( return super()._cached_apply_hf_processor(
@@ -308,6 +310,7 @@ class DeepseekVL2MultiModalProcessor(
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )

View File

@@ -479,6 +479,7 @@ class H2OVLMultiModalProcessor(
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
# The processor logic is different for len(images) <= 1 vs > 1 # The processor logic is different for len(images) <= 1 vs > 1
# Since the processing cache assumes that the processor output is # Since the processing cache assumes that the processor output is
@@ -490,6 +491,7 @@ class H2OVLMultiModalProcessor(
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
return super()._cached_apply_hf_processor( return super()._cached_apply_hf_processor(
@@ -497,6 +499,7 @@ class H2OVLMultiModalProcessor(
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )

View File

@@ -795,6 +795,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None, tokenization_kwargs: Optional[Mapping[str, object]] = None,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
hf_config = self.info.get_hf_config() hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index image_token_id = hf_config.image_token_index
@@ -805,8 +806,11 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
image_height=-1, image_height=-1,
) )
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs, result = super().apply(prompt,
tokenization_kwargs) mm_data,
hf_processor_mm_kwargs,
tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides)
mm_items = self._to_mm_items(mm_data) mm_items = self._to_mm_items(mm_data)
mm_item_counts = mm_items.get_all_counts() mm_item_counts = mm_items.get_all_counts()

View File

@@ -184,9 +184,13 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None, tokenization_kwargs: Optional[Mapping[str, object]] = None,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> MultiModalEncDecInputs: ) -> MultiModalEncDecInputs:
mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs, mm_inputs = super().apply(prompt,
tokenization_kwargs) mm_data,
hf_processor_mm_kwargs,
tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides)
image_token_id = self.info.get_hf_config().image_token_index image_token_id = self.info.get_hf_config().image_token_index
# Check that the number of image tokens in the decoder prompt matches # Check that the number of image tokens in the decoder prompt matches

View File

@@ -203,9 +203,13 @@ class PaliGemmaMultiModalProcessor(
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None, tokenization_kwargs: Optional[Mapping[str, object]] = None,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs, mm_inputs = super().apply(prompt,
tokenization_kwargs) mm_data,
hf_processor_mm_kwargs,
tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides)
prompt_token_ids = mm_inputs["prompt_token_ids"] prompt_token_ids = mm_inputs["prompt_token_ids"]
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()

View File

@@ -314,12 +314,14 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor( prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
prompt=prompt, prompt=prompt,
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
# NOTE: The tokens are already inserted by the chat template # NOTE: The tokens are already inserted by the chat template

View File

@@ -138,6 +138,7 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None, tokenization_kwargs: Optional[Mapping[str, object]] = None,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
if "image" in mm_data: if "image" in mm_data:
image_data = mm_data["image"] image_data = mm_data["image"]
@@ -146,8 +147,10 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
mm_data = {"image": mm_data} mm_data = {"image": mm_data}
mm_items = self._to_mm_items(mm_data) mm_items = self._to_mm_items(mm_data)
mm_hashes = self._hash_mm_items(mm_items, hf_processor_mm_kwargs, tokenization_kwargs = tokenization_kwargs or {}
tokenization_kwargs or {}) mm_hashes = (mm_hash_overrides if mm_hash_overrides is not None else
self._hash_mm_items(mm_items, hf_processor_mm_kwargs,
tokenization_kwargs))
mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]} mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
mm_processed_data = BatchFeature(image_data) mm_processed_data = BatchFeature(image_data)

View File

@@ -327,6 +327,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None, tokenization_kwargs: Optional[Mapping[str, object]] = None,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
""" """
Process multi-modal inputs to be used in vLLM. Process multi-modal inputs to be used in vLLM.
@@ -393,9 +394,11 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs, self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs,
num_image_patches), num_image_patches),
) )
# Use overrides if provided; fallback to data-dependent hashing.
mm_hashes = (mm_hash_overrides if mm_hash_overrides is not None else
self._hash_mm_items(mm_items, hf_processor_mm_kwargs,
tokenization_kwargs))
mm_hashes = self._hash_mm_items(mm_items, hf_processor_mm_kwargs,
tokenization_kwargs)
return MultiModalInputs( return MultiModalInputs(
type="multimodal", type="multimodal",
prompt=prompt, prompt=prompt,

View File

@@ -288,12 +288,14 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor( prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
prompt=prompt, prompt=prompt,
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
# NOTE: The tokens are already inserted by the chat template # NOTE: The tokens are already inserted by the chat template

View File

@@ -1020,8 +1020,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt: str, prompt: str,
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
*,
mm_hash_overrides: Optional[MultiModalHashes] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
return self.apply(prompt, mm_data, hf_processor_mm_kwargs) return self.apply(prompt,
mm_data,
hf_processor_mm_kwargs,
mm_hash_overrides=mm_hash_overrides)
def _get_data_parser(self) -> MultiModalDataParser: def _get_data_parser(self) -> MultiModalDataParser:
""" """
@@ -1357,7 +1362,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
) -> MultiModalHashes: ) -> MultiModalHashes:
"""Create MM hashes to be returned (only used in V1).""" """Create MM hashes to be returned (only used in V1).
Note: When overrides are provided via callers of `apply`,
`_hash_mm_items` will be bypassed and the overrides will be used.
"""
model_id = self.info.model_id model_id = self.info.model_id
return { return {
@@ -1464,6 +1473,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
*,
mm_hash_overrides: Optional[MultiModalHashes] = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
( (
prompt_ids, prompt_ids,
@@ -1483,8 +1494,10 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
hf_processor_mm_kwargs), hf_processor_mm_kwargs),
) )
mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs, # Use overrides if provided; fallback to data-dependent hashing.
tokenization_kwargs) mm_hashes = (mm_hash_overrides if mm_hash_overrides is not None else
self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
tokenization_kwargs))
mm_prompt_updates = self._get_mm_prompt_updates( mm_prompt_updates = self._get_mm_prompt_updates(
mm_data_items, mm_data_items,
@@ -1506,6 +1519,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
*,
mm_hash_overrides: Optional[MultiModalHashes] = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
""" """
Apply the HF processor on the full prompt text, Apply the HF processor on the full prompt text,
@@ -1520,10 +1535,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs, # Use overrides if provided; fallback to data-dependent hashing.
tokenization_kwargs) mm_hashes = (mm_hash_overrides if mm_hash_overrides is not None else
self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
tokenization_kwargs))
mm_missing_data_items = self._get_cache_missing_items( mm_missing_data_items = self._get_cache_missing_items(
cache=cache, cache=cache,
@@ -1723,6 +1741,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None, tokenization_kwargs: Optional[Mapping[str, object]] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
""" """
Process multi-modal inputs to be used in vLLM. Process multi-modal inputs to be used in vLLM.
@@ -1751,6 +1771,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_items, mm_items,
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
# NOTE: tokenization_kwargs are not required to init processor # NOTE: tokenization_kwargs are not required to init processor
@@ -1835,6 +1856,8 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None, tokenization_kwargs: Optional[Mapping[str, object]] = None,
*,
mm_hash_overrides: Optional[MultiModalHashes] = None,
) -> MultiModalEncDecInputs: ) -> MultiModalEncDecInputs:
""" """
Process multi-modal inputs to be used in vLLM. Process multi-modal inputs to be used in vLLM.
@@ -1849,6 +1872,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
mm_data, mm_data,
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
tokenization_kwargs, tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
return self._get_enc_dec_inputs( return self._get_enc_dec_inputs(

View File

@@ -225,6 +225,41 @@ class Processor:
# Remember that this backend was set automatically # Remember that this backend was set automatically
params.guided_decoding.backend_was_auto = True params.guided_decoding.backend_was_auto = True
def _maybe_build_mm_hash_overrides(
self,
request_id: str,
prompt: PromptType,
) -> Optional[dict[str, list[str]]]:
"""Build per-item multimodal hash overrides when enabled. In this case,
multimodal data items are identified by their request id, modality and
index rather than their content.
Returns a dictionary of modality -> list[str] of overrides, or None if
disabled or no multimodal data is present.
"""
def _extract_mm_data(p: PromptType):
if isinstance(p, dict) and "encoder_prompt" in p:
enc = p.get("encoder_prompt")
if isinstance(enc, dict):
return enc.get("multi_modal_data")
return None
if isinstance(p, dict):
return p.get("multi_modal_data")
return None
mm_data = _extract_mm_data(prompt)
if not mm_data:
return None
overrides: dict[str, list[str]] = {}
for modality, data in mm_data.items():
n = len(data) if isinstance(data, list) else 1
overrides[modality] = [
f"{request_id}-{modality}-{i}" for i in range(n)
]
return overrides
def process_inputs( def process_inputs(
self, self,
request_id: str, request_id: str,
@@ -254,6 +289,18 @@ class Processor:
if arrival_time is None: if arrival_time is None:
arrival_time = time.time() arrival_time = time.time()
# Optionally generate multimodal hash overrides based on request id.
# NOTE: when users explicitly turn off BOTH prefix caching and input
# processing caching, no multimodal features or embeddings will be
# reused across requests, therefore hashing is no longer necessary.
if (self.model_config.multimodal_config and
self.model_config.multimodal_config.mm_processor_cache_gb == 0
and not self.cache_config.enable_prefix_caching):
mm_hash_overrides = self._maybe_build_mm_hash_overrides(
request_id, prompt)
else:
mm_hash_overrides = None
# Process inputs, which includes: # Process inputs, which includes:
# 1. Tokenize text prompt, with LoRA request if one exists. # 1. Tokenize text prompt, with LoRA request if one exists.
# 2. For multimodal models with a merged preprocessor, preprocess # 2. For multimodal models with a merged preprocessor, preprocess
@@ -262,6 +309,7 @@ class Processor:
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
current_platform.validate_request( current_platform.validate_request(