[Multimodal] Generate mm_hash based on request metadata when caching is turned off (#23690)

Signed-off-by: Roger Wang <hey@rogerw.io>
Roger Wang
2025-08-27 13:24:31 -07:00
committed by GitHub
parent 0585a9e73c
commit 8bf6266a17
12 changed files with 179 additions and 24 deletions

View File

@@ -257,6 +257,8 @@ class InputPreprocessor:
mm_processor_kwargs: Optional[Mapping[str, object]],
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> MultiModalInputs:
"""
Apply the model's multi-modal processor to a multi-modal prompt,
@@ -273,10 +275,13 @@ class InputPreprocessor:
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
return mm_processor.apply(prompt,
mm_data,
hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs)
return mm_processor.apply(
prompt,
mm_data,
hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
async def _process_multimodal_async(
self,
@@ -285,6 +290,8 @@ class InputPreprocessor:
mm_processor_kwargs: Optional[Mapping[str, object]],
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> MultiModalInputs:
"""
Async version of
@@ -301,10 +308,13 @@ class InputPreprocessor:
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
return mm_processor.apply(prompt,
mm_data,
hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs)
return mm_processor.apply(
prompt,
mm_data,
hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
def _process_embeds(
self,
@@ -341,6 +351,8 @@ class InputPreprocessor:
parsed_content: TokensPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = parsed_content["prompt_token_ids"]
token_type_ids = parsed_content.get("token_type_ids")
@@ -353,6 +365,7 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
)
else:
inputs = token_inputs(
@@ -370,6 +383,8 @@ class InputPreprocessor:
parsed_content: TokensPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = parsed_content["prompt_token_ids"]
token_type_ids = parsed_content.get("token_type_ids")
@@ -382,6 +397,7 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
)
else:
inputs = token_inputs(
@@ -399,6 +415,8 @@ class InputPreprocessor:
parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_text = parsed_content["prompt"]
@@ -410,6 +428,7 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
)
else:
prompt_token_ids = self._tokenize_prompt(
@@ -432,6 +451,8 @@ class InputPreprocessor:
parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_text = parsed_content["prompt"]
@@ -443,6 +464,7 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
)
else:
prompt_token_ids = await self._tokenize_prompt_async(
@@ -465,6 +487,8 @@ class InputPreprocessor:
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> SingletonInputs:
"""
Extract the singleton inputs from a prompt.
@@ -486,18 +510,21 @@ class InputPreprocessor:
return self._process_tokens(
parsed["content"],
lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
)
if parsed["type"] == "text":
return self._process_text(
parsed["content"],
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
)
if parsed["type"] == "str":
return self._process_text(
TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
)
assert_never(parsed)
@@ -507,6 +534,8 @@ class InputPreprocessor:
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> SingletonInputs:
"""
Async version of
@@ -520,18 +549,21 @@ class InputPreprocessor:
return await self._process_tokens_async(
parsed["content"],
lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
)
if parsed["type"] == "text":
return await self._process_text_async(
parsed["content"],
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
)
if parsed["type"] == "str":
return await self._process_text_async(
TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
)
assert_never(parsed)
@@ -641,6 +673,8 @@ class InputPreprocessor:
self,
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> EncoderDecoderInputs:
"""
For encoder/decoder models only:
@@ -682,6 +716,7 @@ class InputPreprocessor:
encoder_inputs = self._prompt_to_llm_inputs(
prompt["encoder_prompt"],
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
if (decoder_input := prompt["decoder_prompt"]) is None:
decoder_inputs = None
@@ -697,6 +732,7 @@ class InputPreprocessor:
inputs = self._prompt_to_llm_inputs(
prompt,
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model
@@ -712,6 +748,8 @@ class InputPreprocessor:
self,
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> EncoderDecoderInputs:
"""
Async version of
@@ -724,6 +762,7 @@ class InputPreprocessor:
encoder_task = self._prompt_to_llm_inputs_async(
prompt["encoder_prompt"],
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
if (decoder_input := prompt["decoder_prompt"]) is None:
@@ -733,6 +772,7 @@ class InputPreprocessor:
decoder_task = self._prompt_to_llm_inputs_async(
decoder_input,
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
encoder_inputs, decoder_inputs = await asyncio.gather(
@@ -748,6 +788,7 @@ class InputPreprocessor:
inputs = await self._prompt_to_llm_inputs_async(
prompt,
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model
@@ -774,6 +815,8 @@ class InputPreprocessor:
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> DecoderOnlyInputs:
"""
For decoder-only models:
@@ -794,6 +837,7 @@ class InputPreprocessor:
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
)
return self._build_decoder_only_llm_inputs(prompt_comps)
@@ -803,6 +847,8 @@ class InputPreprocessor:
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> DecoderOnlyInputs:
"""
Async version of
@@ -812,6 +858,7 @@ class InputPreprocessor:
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
)
return self._build_decoder_only_llm_inputs(prompt_comps)
@@ -821,6 +868,8 @@ class InputPreprocessor:
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> ProcessorInputs:
"""Preprocess the input prompt."""
if self.model_config.is_encoder_decoder:
@@ -829,6 +878,7 @@ class InputPreprocessor:
return self._process_encoder_decoder_prompt(
prompt,
tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
if is_explicit_encoder_decoder_prompt(prompt):
@@ -840,6 +890,7 @@ class InputPreprocessor:
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
)
async def preprocess_async(
@@ -847,6 +898,8 @@ class InputPreprocessor:
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> ProcessorInputs:
"""
Async version of
@@ -858,6 +911,7 @@ class InputPreprocessor:
return await self._process_encoder_decoder_prompt_async(
prompt,
tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
if is_explicit_encoder_decoder_prompt(prompt):
@@ -869,6 +923,7 @@ class InputPreprocessor:
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
)
def clear_cache(self) -> None:
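For illustration, the shape that `mm_hash_overrides` takes throughout this file (per its `Optional[dict[str, list[str]]]` annotation) is a dict keyed by modality with one identifier string per data item; the example values below are made up:
# Illustrative values only; the real strings are built by the Processor from
# request metadata (see the last file in this commit).
mm_hash_overrides = {
    "image": ["req-0-image-0", "req-0-image-1"],  # one entry per image item
    "audio": ["req-0-audio-0"],                   # one entry per audio item
}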

View File

@@ -290,6 +290,7 @@ class DeepseekVL2MultiModalProcessor(
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
# The processor logic is different for len(images) <= 2 vs > 2
# Since the processing cache assumes that the processor output is
@@ -301,6 +302,7 @@ class DeepseekVL2MultiModalProcessor(
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
return super()._cached_apply_hf_processor(
@@ -308,6 +310,7 @@ class DeepseekVL2MultiModalProcessor(
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)

View File

@@ -479,6 +479,7 @@ class H2OVLMultiModalProcessor(
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
# The processor logic is different for len(images) <= 1 vs > 1
# Since the processing cache assumes that the processor output is
@@ -490,6 +491,7 @@ class H2OVLMultiModalProcessor(
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
return super()._cached_apply_hf_processor(
@@ -497,6 +499,7 @@ class H2OVLMultiModalProcessor(
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)

View File

@@ -795,6 +795,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> MultiModalInputs:
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
@@ -805,8 +806,11 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
image_height=-1,
)
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
tokenization_kwargs)
result = super().apply(prompt,
mm_data,
hf_processor_mm_kwargs,
tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides)
mm_items = self._to_mm_items(mm_data)
mm_item_counts = mm_items.get_all_counts()

View File

@@ -184,9 +184,13 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> MultiModalEncDecInputs:
mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
tokenization_kwargs)
mm_inputs = super().apply(prompt,
mm_data,
hf_processor_mm_kwargs,
tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides)
image_token_id = self.info.get_hf_config().image_token_index
# Check that the number of image tokens in the decoder prompt matches

View File

@@ -203,9 +203,13 @@ class PaliGemmaMultiModalProcessor(
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> MultiModalInputs:
mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
tokenization_kwargs)
mm_inputs = super().apply(prompt,
mm_data,
hf_processor_mm_kwargs,
tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides)
prompt_token_ids = mm_inputs["prompt_token_ids"]
tokenizer = self.info.get_tokenizer()

View File

@@ -314,12 +314,14 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
# NOTE: The tokens are already inserted by the chat template

View File

@@ -138,6 +138,7 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> MultiModalInputs:
if "image" in mm_data:
image_data = mm_data["image"]
@@ -146,8 +147,10 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
mm_data = {"image": mm_data}
mm_items = self._to_mm_items(mm_data)
mm_hashes = self._hash_mm_items(mm_items, hf_processor_mm_kwargs,
tokenization_kwargs or {})
tokenization_kwargs = tokenization_kwargs or {}
mm_hashes = (mm_hash_overrides if mm_hash_overrides is not None else
self._hash_mm_items(mm_items, hf_processor_mm_kwargs,
tokenization_kwargs))
mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
mm_processed_data = BatchFeature(image_data)

View File

@@ -327,6 +327,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> MultiModalInputs:
"""
Process multi-modal inputs to be used in vLLM.
@@ -393,9 +394,11 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs,
num_image_patches),
)
mm_hashes = self._hash_mm_items(mm_items, hf_processor_mm_kwargs,
tokenization_kwargs)
# Use overrides if provided; fallback to data-dependent hashing.
mm_hashes = (mm_hash_overrides if mm_hash_overrides is not None else
self._hash_mm_items(mm_items, hf_processor_mm_kwargs,
tokenization_kwargs))
return MultiModalInputs(
type="multimodal",
prompt=prompt,

View File

@@ -288,12 +288,14 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
# NOTE: The tokens are already inserted by the chat template

View File

@@ -1020,8 +1020,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt: str,
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
*,
mm_hash_overrides: Optional[MultiModalHashes] = None,
) -> MultiModalInputs:
return self.apply(prompt, mm_data, hf_processor_mm_kwargs)
return self.apply(prompt,
mm_data,
hf_processor_mm_kwargs,
mm_hash_overrides=mm_hash_overrides)
def _get_data_parser(self) -> MultiModalDataParser:
"""
@@ -1357,7 +1362,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> MultiModalHashes:
"""Create MM hashes to be returned (only used in V1)."""
"""Create MM hashes to be returned (only used in V1).
Note: When hash overrides are provided by callers of `apply`,
`_hash_mm_items` is bypassed and the overrides are used instead.
"""
model_id = self.info.model_id
return {
@@ -1464,6 +1473,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
*,
mm_hash_overrides: Optional[MultiModalHashes] = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
(
prompt_ids,
@@ -1483,8 +1494,10 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
hf_processor_mm_kwargs),
)
mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
tokenization_kwargs)
# Use overrides if provided; fallback to data-dependent hashing.
mm_hashes = (mm_hash_overrides if mm_hash_overrides is not None else
self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
tokenization_kwargs))
mm_prompt_updates = self._get_mm_prompt_updates(
mm_data_items,
@@ -1506,6 +1519,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
*,
mm_hash_overrides: Optional[MultiModalHashes] = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
"""
Apply the HF processor on the full prompt text,
@@ -1520,10 +1535,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
tokenization_kwargs)
# Use overrides if provided; fallback to data-dependent hashing.
mm_hashes = (mm_hash_overrides if mm_hash_overrides is not None else
self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
tokenization_kwargs))
mm_missing_data_items = self._get_cache_missing_items(
cache=cache,
@@ -1723,6 +1741,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> MultiModalInputs:
"""
Process multi-modal inputs to be used in vLLM.
@@ -1751,6 +1771,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_items,
hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
# NOTE: tokenization_kwargs are not required to init processor
@@ -1835,6 +1856,8 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None,
*,
mm_hash_overrides: Optional[MultiModalHashes] = None,
) -> MultiModalEncDecInputs:
"""
Process multi-modal inputs to be used in vLLM.
@@ -1849,6 +1872,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
mm_data,
hf_processor_mm_kwargs,
tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
return self._get_enc_dec_inputs(

View File

@@ -225,6 +225,41 @@ class Processor:
# Remember that this backend was set automatically
params.guided_decoding.backend_was_auto = True
def _maybe_build_mm_hash_overrides(
self,
request_id: str,
prompt: PromptType,
) -> Optional[dict[str, list[str]]]:
"""Build per-item multimodal hash overrides when enabled. In this case,
multimodal data items are identified by their request id, modality and
index rather than their content.
Returns a dictionary of modality -> list[str] of overrides, or None if
disabled or no multimodal data is present.
"""
def _extract_mm_data(p: PromptType):
if isinstance(p, dict) and "encoder_prompt" in p:
enc = p.get("encoder_prompt")
if isinstance(enc, dict):
return enc.get("multi_modal_data")
return None
if isinstance(p, dict):
return p.get("multi_modal_data")
return None
mm_data = _extract_mm_data(prompt)
if not mm_data:
return None
overrides: dict[str, list[str]] = {}
for modality, data in mm_data.items():
n = len(data) if isinstance(data, list) else 1
overrides[modality] = [
f"{request_id}-{modality}-{i}" for i in range(n)
]
return overrides
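For illustration, a standalone sketch of what the method above returns for a hypothetical request carrying two images and one audio clip (request id and data items are made up):
# Self-contained sketch mirroring _maybe_build_mm_hash_overrides.
request_id = "req-42"
mm_data = {"image": [object(), object()], "audio": object()}  # placeholder items
overrides = {
    modality: [
        f"{request_id}-{modality}-{i}"
        for i in range(len(data) if isinstance(data, list) else 1)
    ]
    for modality, data in mm_data.items()
}
# overrides == {"image": ["req-42-image-0", "req-42-image-1"],
#               "audio": ["req-42-audio-0"]}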
def process_inputs(
self,
request_id: str,
@@ -254,6 +289,18 @@ class Processor:
if arrival_time is None:
arrival_time = time.time()
# Optionally generate multimodal hash overrides based on request id.
# NOTE: when users explicitly turn off BOTH prefix caching and input
# processing caching, no multimodal features or embeddings will be
# reused across requests, therefore hashing is no longer necessary.
if (self.model_config.multimodal_config and
self.model_config.multimodal_config.mm_processor_cache_gb == 0
and not self.cache_config.enable_prefix_caching):
mm_hash_overrides = self._maybe_build_mm_hash_overrides(
request_id, prompt)
else:
mm_hash_overrides = None
# Process inputs, which includes:
# 1. Tokenize text prompt, with LoRA request if one exists.
# 2. For multimodal models with a merged preprocessor, preprocess
@@ -262,6 +309,7 @@ class Processor:
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
)
from vllm.platforms import current_platform
current_platform.validate_request(
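As a usage note, the override path only activates when both caches are explicitly disabled. A minimal sketch of such a configuration, assuming the offline `LLM` entrypoint and the engine-argument names referenced in this commit (`mm_processor_cache_gb`, `enable_prefix_caching`); exact spellings may vary between versions:
from vllm import LLM

# With both caches off, multimodal items are identified by request metadata
# (request id, modality, index) instead of content hashes.
llm = LLM(
    model="Qwen/Qwen2.5-VL-3B-Instruct",  # any multimodal model, for example
    mm_processor_cache_gb=0,              # turn off the input-processing cache
    enable_prefix_caching=False,          # turn off prefix caching
)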