[Core] Prevent side-channel attacks via cache salting (#17045)
Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
a7d5b016bd
commit
77073c77bc
@@ -17,7 +17,8 @@ from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||
|
||||
from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs,
|
||||
PromptType, SingletonInputs, SingletonPrompt, token_inputs)
|
||||
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
|
||||
from .parse import (ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt,
|
||||
is_explicit_encoder_decoder_prompt, parse_singleton_prompt)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -283,6 +284,29 @@ class InputPreprocessor:
|
||||
return mm_processor.apply(prompt, mm_data, mm_processor_kwargs,
|
||||
return_mm_hashes)
|
||||
|
||||
def _get_prompt_data(self, parsed_prompt: Union[ParsedStrPrompt,
|
||||
ParsedTextPrompt,
|
||||
ParsedTokensPrompt]):
|
||||
prompt_text = None
|
||||
prompt_token_ids = None
|
||||
token_type_ids = None
|
||||
cache_salt = None
|
||||
|
||||
if parsed_prompt["type"] == "str":
|
||||
prompt_text = parsed_prompt["content"]
|
||||
else:
|
||||
cache_salt = parsed_prompt["content"].get("cache_salt")
|
||||
if parsed_prompt["type"] == "text":
|
||||
prompt_text = parsed_prompt["content"]["prompt"]
|
||||
elif parsed_prompt["type"] == "tokens":
|
||||
prompt_token_ids = parsed_prompt["content"].get(
|
||||
"prompt_token_ids")
|
||||
token_type_ids = parsed_prompt["content"].get("token_type_ids")
|
||||
else:
|
||||
assert_never(parsed_prompt)
|
||||
|
||||
return prompt_text, prompt_token_ids, token_type_ids, cache_salt
|
||||
|
||||
def _prompt_to_llm_inputs(
|
||||
self,
|
||||
prompt: SingletonPrompt,
|
||||
@@ -304,70 +328,36 @@ class InputPreprocessor:
|
||||
* :class:`SingletonInputs` instance
|
||||
"""
|
||||
parsed = parse_singleton_prompt(prompt)
|
||||
prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
|
||||
self._get_prompt_data(parsed)
|
||||
|
||||
if parsed["type"] == "str":
|
||||
prompt_text = parsed["content"]
|
||||
# If multimodal data is present, process and return immediately
|
||||
if parsed["type"] != "str" and parsed["content"].get(
|
||||
"multi_modal_data") is not None:
|
||||
inputs = self._process_multimodal(
|
||||
prompt_text if prompt_text is not None else prompt_token_ids,
|
||||
parsed["content"]["multi_modal_data"],
|
||||
parsed["content"].get("mm_processor_kwargs"),
|
||||
lora_request=lora_request,
|
||||
return_mm_hashes=return_mm_hashes,
|
||||
)
|
||||
if cache_salt is not None:
|
||||
inputs["cache_salt"] = cache_salt
|
||||
return inputs
|
||||
|
||||
if prompt_token_ids is None:
|
||||
prompt_token_ids = self._tokenize_prompt(
|
||||
prompt_text,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
return token_inputs(
|
||||
prompt=prompt_text,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
)
|
||||
|
||||
if parsed["type"] == "tokens":
|
||||
tokens_content = parsed["content"]
|
||||
|
||||
prompt_token_ids = tokens_content["prompt_token_ids"]
|
||||
token_type_ids = tokens_content.get("token_type_ids")
|
||||
multi_modal_data = tokens_content.get("multi_modal_data")
|
||||
mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")
|
||||
|
||||
if multi_modal_data is not None:
|
||||
return self._process_multimodal(
|
||||
prompt_token_ids,
|
||||
multi_modal_data,
|
||||
mm_processor_kwargs,
|
||||
lora_request=lora_request,
|
||||
return_mm_hashes=return_mm_hashes,
|
||||
)
|
||||
|
||||
return token_inputs(
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
)
|
||||
|
||||
if parsed["type"] == "text":
|
||||
text_content = parsed["content"]
|
||||
|
||||
prompt_text = text_content["prompt"]
|
||||
multi_modal_data = text_content.get("multi_modal_data")
|
||||
mm_processor_kwargs = text_content.get("mm_processor_kwargs")
|
||||
|
||||
if multi_modal_data is not None:
|
||||
return self._process_multimodal(
|
||||
prompt_text,
|
||||
multi_modal_data,
|
||||
mm_processor_kwargs,
|
||||
lora_request=lora_request,
|
||||
return_mm_hashes=return_mm_hashes,
|
||||
)
|
||||
|
||||
prompt_token_ids = self._tokenize_prompt(
|
||||
prompt_text,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
return token_inputs(
|
||||
prompt=prompt_text,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
)
|
||||
|
||||
assert_never(parsed)
|
||||
return token_inputs(
|
||||
prompt=prompt_text,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
cache_salt=cache_salt,
|
||||
)
|
||||
|
||||
async def _prompt_to_llm_inputs_async(
|
||||
self,
|
||||
@@ -379,64 +369,35 @@ class InputPreprocessor:
|
||||
"""Async version of :meth:`_extract_prompt_components`."""
|
||||
parsed = parse_singleton_prompt(prompt)
|
||||
|
||||
if parsed["type"] == "str":
|
||||
prompt_text = parsed["content"]
|
||||
prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
|
||||
self._get_prompt_data(parsed)
|
||||
|
||||
if parsed["type"] != "str" and parsed["content"].get(
|
||||
"multi_modal_data") is not None:
|
||||
inputs = await self._process_multimodal_async(
|
||||
prompt_token_ids if prompt_text is None else prompt_text,
|
||||
parsed["content"]["multi_modal_data"],
|
||||
parsed["content"].get("mm_processor_kwargs"),
|
||||
lora_request=lora_request,
|
||||
return_mm_hashes=return_mm_hashes,
|
||||
)
|
||||
if cache_salt is not None:
|
||||
inputs["cache_salt"] = cache_salt
|
||||
return inputs
|
||||
|
||||
if prompt_token_ids is None:
|
||||
prompt_token_ids = await self._tokenize_prompt_async(
|
||||
prompt_text,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
return token_inputs(
|
||||
prompt=prompt_text,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
)
|
||||
|
||||
if parsed["type"] == "tokens":
|
||||
tokens_content = parsed["content"]
|
||||
|
||||
prompt_token_ids = tokens_content["prompt_token_ids"]
|
||||
multi_modal_data = tokens_content.get("multi_modal_data")
|
||||
mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")
|
||||
|
||||
if multi_modal_data is not None:
|
||||
return await self._process_multimodal_async(
|
||||
prompt_token_ids,
|
||||
multi_modal_data,
|
||||
mm_processor_kwargs,
|
||||
lora_request=lora_request,
|
||||
return_mm_hashes=return_mm_hashes,
|
||||
)
|
||||
|
||||
return token_inputs(prompt_token_ids=prompt_token_ids)
|
||||
|
||||
if parsed["type"] == "text":
|
||||
text_content = parsed["content"]
|
||||
|
||||
prompt_text = text_content["prompt"]
|
||||
multi_modal_data = text_content.get("multi_modal_data")
|
||||
mm_processor_kwargs = text_content.get("mm_processor_kwargs")
|
||||
|
||||
if multi_modal_data is not None:
|
||||
return await self._process_multimodal_async(
|
||||
prompt_text,
|
||||
multi_modal_data,
|
||||
mm_processor_kwargs,
|
||||
lora_request=lora_request,
|
||||
return_mm_hashes=return_mm_hashes,
|
||||
)
|
||||
|
||||
prompt_token_ids = await self._tokenize_prompt_async(
|
||||
prompt_text,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
return token_inputs(
|
||||
prompt=prompt_text,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
)
|
||||
|
||||
assert_never(parsed)
|
||||
return token_inputs(
|
||||
prompt=prompt_text,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
cache_salt=cache_salt,
|
||||
)
|
||||
|
||||
def _build_enc_dec_llm_inputs(
|
||||
self,
|
||||
@@ -516,6 +477,11 @@ class InputPreprocessor:
|
||||
mm_hashes=inputs["mm_hashes"],
|
||||
mm_placeholders=inputs["mm_placeholders"],
|
||||
)
|
||||
|
||||
cache_salt = inputs.get("cache_salt")
|
||||
if cache_salt is not None:
|
||||
decoder_inputs["cache_salt"] = cache_salt
|
||||
|
||||
elif inputs["type"] == "token":
|
||||
# Text-only inputs
|
||||
encoder_inputs = token_inputs(prompt="", prompt_token_ids=[])
|
||||
|
||||
Reference in New Issue
Block a user