[Core] Prevent side-channel attacks via cache salting (#17045)

Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com>
Author: Marko Rosenmueller
Date: 2025-04-30 14:27:21 +02:00
Committed by: GitHub
parent a7d5b016bd
commit 77073c77bc
18 changed files with 328 additions and 126 deletions
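
Background: with prefix caching enabled, KV-cache blocks are keyed by hashes of the token prefix, so an attacker who shares the cache with a victim can detect reuse (e.g. via time-to-first-token) and confirm guesses about another tenant's prompt. This change lets a request carry a cache_salt that is folded into the block hashes, so only requests with the same salt can share cache entries. A minimal sketch of the idea, with a hypothetical helper rather than vLLM's actual hashing code:

import hashlib
import pickle

def block_hashes(token_ids, block_size=16, cache_salt=None):
    # Hash fixed-size blocks of tokens, chaining each block's hash into
    # the next. Seeding the chain with the salt makes every downstream
    # hash salt-dependent, so identically-prompted requests with
    # different salts can never hit each other's cached blocks.
    hashes = []
    parent = cache_salt  # None = unsalted chain
    for start in range(0, len(token_ids) - block_size + 1, block_size):
        block = tuple(token_ids[start:start + block_size])
        digest = hashlib.sha256(pickle.dumps((parent, block))).hexdigest()
        hashes.append(digest)
        parent = digest
    return hashes

# Same prompt, different salts -> disjoint cache keys.
assert block_hashes(list(range(32)), cache_salt="tenant-a") != \
       block_hashes(list(range(32)), cache_salt="tenant-b")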

vllm/inputs/preprocess.py

@@ -17,7 +17,8 @@ from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 
 from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs,
                    PromptType, SingletonInputs, SingletonPrompt, token_inputs)
-from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
+from .parse import (ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt,
+                    is_explicit_encoder_decoder_prompt, parse_singleton_prompt)
 
 logger = init_logger(__name__)
@@ -283,6 +284,29 @@ class InputPreprocessor:
         return mm_processor.apply(prompt, mm_data, mm_processor_kwargs,
                                   return_mm_hashes)
 
+    def _get_prompt_data(self, parsed_prompt: Union[ParsedStrPrompt,
+                                                    ParsedTextPrompt,
+                                                    ParsedTokensPrompt]):
+        prompt_text = None
+        prompt_token_ids = None
+        token_type_ids = None
+        cache_salt = None
+
+        if parsed_prompt["type"] == "str":
+            prompt_text = parsed_prompt["content"]
+        else:
+            cache_salt = parsed_prompt["content"].get("cache_salt")
+            if parsed_prompt["type"] == "text":
+                prompt_text = parsed_prompt["content"]["prompt"]
+            elif parsed_prompt["type"] == "tokens":
+                prompt_token_ids = parsed_prompt["content"].get(
+                    "prompt_token_ids")
+                token_type_ids = parsed_prompt["content"].get(
+                    "token_type_ids")
+            else:
+                assert_never(parsed_prompt)
+
+        return prompt_text, prompt_token_ids, token_type_ids, cache_salt
+
     def _prompt_to_llm_inputs(
         self,
         prompt: SingletonPrompt,
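
The new helper normalizes the three parsed prompt variants into one tuple; note that cache_salt can only come from dict prompts, never from a bare string. A standalone sketch of its behavior, with plain dicts standing in for vLLM's ParsedStrPrompt/ParsedTextPrompt/ParsedTokensPrompt types:

# Plain-dict stand-ins for the parsed prompt variants (sketch only).
str_prompt = {"type": "str", "content": "Hello"}
tokens_prompt = {
    "type": "tokens",
    "content": {"prompt_token_ids": [1, 2, 3], "cache_salt": "req-42"},
}

def get_prompt_data(parsed):
    # Mirrors _get_prompt_data above: one tuple out, regardless of variant.
    prompt_text = prompt_token_ids = token_type_ids = cache_salt = None
    if parsed["type"] == "str":
        prompt_text = parsed["content"]
    else:
        cache_salt = parsed["content"].get("cache_salt")
        if parsed["type"] == "text":
            prompt_text = parsed["content"]["prompt"]
        elif parsed["type"] == "tokens":
            prompt_token_ids = parsed["content"].get("prompt_token_ids")
            token_type_ids = parsed["content"].get("token_type_ids")
    return prompt_text, prompt_token_ids, token_type_ids, cache_salt

print(get_prompt_data(str_prompt))     # ('Hello', None, None, None)
print(get_prompt_data(tokens_prompt))  # (None, [1, 2, 3], None, 'req-42')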
@@ -304,70 +328,36 @@ class InputPreprocessor:
         * :class:`SingletonInputs` instance
         """
         parsed = parse_singleton_prompt(prompt)
 
-        if parsed["type"] == "str":
-            prompt_text = parsed["content"]
+        prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
+            self._get_prompt_data(parsed)
 
+        # If multimodal data is present, process and return immediately
+        if parsed["type"] != "str" and parsed["content"].get(
+                "multi_modal_data") is not None:
+            inputs = self._process_multimodal(
+                prompt_text if prompt_text is not None else prompt_token_ids,
+                parsed["content"]["multi_modal_data"],
+                parsed["content"].get("mm_processor_kwargs"),
+                lora_request=lora_request,
+                return_mm_hashes=return_mm_hashes,
+            )
+            if cache_salt is not None:
+                inputs["cache_salt"] = cache_salt
+            return inputs
+
+        if prompt_token_ids is None:
             prompt_token_ids = self._tokenize_prompt(
                 prompt_text,
                 lora_request=lora_request,
                 tokenization_kwargs=tokenization_kwargs,
             )
 
-            return token_inputs(
-                prompt=prompt_text,
-                prompt_token_ids=prompt_token_ids,
-            )
-
-        if parsed["type"] == "tokens":
-            tokens_content = parsed["content"]
-            prompt_token_ids = tokens_content["prompt_token_ids"]
-            token_type_ids = tokens_content.get("token_type_ids")
-            multi_modal_data = tokens_content.get("multi_modal_data")
-            mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")
-
-            if multi_modal_data is not None:
-                return self._process_multimodal(
-                    prompt_token_ids,
-                    multi_modal_data,
-                    mm_processor_kwargs,
-                    lora_request=lora_request,
-                    return_mm_hashes=return_mm_hashes,
-                )
-
-            return token_inputs(
-                prompt_token_ids=prompt_token_ids,
-                token_type_ids=token_type_ids,
-            )
-
-        if parsed["type"] == "text":
-            text_content = parsed["content"]
-            prompt_text = text_content["prompt"]
-            multi_modal_data = text_content.get("multi_modal_data")
-            mm_processor_kwargs = text_content.get("mm_processor_kwargs")
-
-            if multi_modal_data is not None:
-                return self._process_multimodal(
-                    prompt_text,
-                    multi_modal_data,
-                    mm_processor_kwargs,
-                    lora_request=lora_request,
-                    return_mm_hashes=return_mm_hashes,
-                )
-
-            prompt_token_ids = self._tokenize_prompt(
-                prompt_text,
-                lora_request=lora_request,
-                tokenization_kwargs=tokenization_kwargs,
-            )
-
-            return token_inputs(
-                prompt=prompt_text,
-                prompt_token_ids=prompt_token_ids,
-            )
-
-        assert_never(parsed)
+        return token_inputs(
+            prompt=prompt_text,
+            prompt_token_ids=prompt_token_ids,
+            token_type_ids=token_type_ids,
+            cache_salt=cache_salt,
+        )
 
     async def _prompt_to_llm_inputs_async(
         self,
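
From the caller's side, the salt rides along in the prompt dict; this PR also adds a cache_salt field to the TextPrompt/TokensPrompt input types (in one of the other changed files), which the example below assumes, along with the v1 engine and prefix caching:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)

# Two tenants send the same text; distinct salts keep their KV-cache
# entries separate, closing the cache-timing side channel.
for salt in ("tenant-a", "tenant-b"):
    out = llm.generate(
        {"prompt": "My bank PIN is", "cache_salt": salt},
        SamplingParams(max_tokens=8),
    )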
@@ -379,64 +369,35 @@ class InputPreprocessor:
         """Async version of :meth:`_prompt_to_llm_inputs`."""
         parsed = parse_singleton_prompt(prompt)
 
-        if parsed["type"] == "str":
-            prompt_text = parsed["content"]
+        prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
+            self._get_prompt_data(parsed)
 
+        if parsed["type"] != "str" and parsed["content"].get(
+                "multi_modal_data") is not None:
+            inputs = await self._process_multimodal_async(
+                prompt_token_ids if prompt_text is None else prompt_text,
+                parsed["content"]["multi_modal_data"],
+                parsed["content"].get("mm_processor_kwargs"),
+                lora_request=lora_request,
+                return_mm_hashes=return_mm_hashes,
+            )
+            if cache_salt is not None:
+                inputs["cache_salt"] = cache_salt
+            return inputs
+
+        if prompt_token_ids is None:
             prompt_token_ids = await self._tokenize_prompt_async(
                 prompt_text,
                 lora_request=lora_request,
                 tokenization_kwargs=tokenization_kwargs,
             )
 
-            return token_inputs(
-                prompt=prompt_text,
-                prompt_token_ids=prompt_token_ids,
-            )
-
-        if parsed["type"] == "tokens":
-            tokens_content = parsed["content"]
-            prompt_token_ids = tokens_content["prompt_token_ids"]
-            multi_modal_data = tokens_content.get("multi_modal_data")
-            mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")
-
-            if multi_modal_data is not None:
-                return await self._process_multimodal_async(
-                    prompt_token_ids,
-                    multi_modal_data,
-                    mm_processor_kwargs,
-                    lora_request=lora_request,
-                    return_mm_hashes=return_mm_hashes,
-                )
-
-            return token_inputs(prompt_token_ids=prompt_token_ids)
-
-        if parsed["type"] == "text":
-            text_content = parsed["content"]
-            prompt_text = text_content["prompt"]
-            multi_modal_data = text_content.get("multi_modal_data")
-            mm_processor_kwargs = text_content.get("mm_processor_kwargs")
-
-            if multi_modal_data is not None:
-                return await self._process_multimodal_async(
-                    prompt_text,
-                    multi_modal_data,
-                    mm_processor_kwargs,
-                    lora_request=lora_request,
-                    return_mm_hashes=return_mm_hashes,
-                )
-
-            prompt_token_ids = await self._tokenize_prompt_async(
-                prompt_text,
-                lora_request=lora_request,
-            )
-
-            return token_inputs(
-                prompt=prompt_text,
-                prompt_token_ids=prompt_token_ids,
-            )
-
-        assert_never(parsed)
+        return token_inputs(
+            prompt=prompt_text,
+            prompt_token_ids=prompt_token_ids,
+            token_type_ids=token_type_ids,
+            cache_salt=cache_salt,
+        )
 
     def _build_enc_dec_llm_inputs(
         self,
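
The async variant mirrors the sync path line for line; only tokenization and multimodal processing are awaited. A toy reduction of the shared control flow, with illustrative names and no vLLM imports:

import asyncio

async def tokenize_async(text):
    # Stand-in for _tokenize_prompt_async: tokenize off the event loop.
    return [ord(c) for c in text]

async def prompt_to_llm_inputs_async(prompt_text, prompt_token_ids=None,
                                     cache_salt=None):
    # Tokenize only when the caller did not supply token ids, then
    # attach the salt so it survives into the engine-facing inputs.
    if prompt_token_ids is None:
        prompt_token_ids = await tokenize_async(prompt_text)
    inputs = {"prompt": prompt_text, "prompt_token_ids": prompt_token_ids}
    if cache_salt is not None:
        inputs["cache_salt"] = cache_salt
    return inputs

print(asyncio.run(prompt_to_llm_inputs_async("hi", cache_salt="s1")))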
@@ -516,6 +477,11 @@ class InputPreprocessor:
                 mm_hashes=inputs["mm_hashes"],
                 mm_placeholders=inputs["mm_placeholders"],
             )
+
+            cache_salt = inputs.get("cache_salt")
+            if cache_salt is not None:
+                decoder_inputs["cache_salt"] = cache_salt
+
         elif inputs["type"] == "token":
             # Text-only inputs
             encoder_inputs = token_inputs(prompt="", prompt_token_ids=[])
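
In the encoder-decoder split above, the salt is attached to the decoder inputs only, since the decoder prompt is what participates in prefix caching. A condensed sketch of that propagation, with hypothetical dicts in place of vLLM's input types:

def separate_enc_dec_inputs(inputs):
    # The encoder side carries no salt; only decoder inputs inherit it.
    encoder_inputs = {"prompt": "", "prompt_token_ids": []}
    decoder_inputs = {"prompt_token_ids": inputs["prompt_token_ids"]}
    cache_salt = inputs.get("cache_salt")
    if cache_salt is not None:
        decoder_inputs["cache_salt"] = cache_salt
    return encoder_inputs, decoder_inputs

enc, dec = separate_enc_dec_inputs(
    {"prompt_token_ids": [1, 2, 3], "cache_salt": "req-7"})
assert "cache_salt" not in enc and dec["cache_salt"] == "req-7"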