[Frontend] Use new Renderer for Completions and Tokenize API (#32863)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
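With the new renderer, the updated tests below follow three recurring changes: each mocked `AsyncLLM` engine is given an explicit `HfRenderer` built from its model config, prompt-embeds requests send `prompt=None` instead of a placeholder empty string, and prompt-embeds decoding moves from `CompletionRenderer.load_prompt_embeds` to the standalone `safe_load_prompt_embeds` helper. A minimal sketch of the renderer helper used throughout these tests, assuming the `tokenizer_args_from_config` and `HfRenderer` signatures exactly as they appear in the hunks below (the plain `model_config` argument here stands in for the tests' `MockModelConfig`):

```python
from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config


def _build_renderer(model_config):
    # Resolve the tokenizer name and kwargs from the (mock) model config,
    # then hand them to the HF-backed renderer used by the new frontend.
    _, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
    return HfRenderer(
        model_config,
        tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
    )


# The mocked engine then exposes the renderer directly, as the tests do:
# mock_engine.renderer = _build_renderer(mock_engine.model_config)
```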
@@ -15,7 +15,8 @@ from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.tokenizers import get_tokenizer
from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.v1.engine.async_llm import AsyncLLM

MODEL_NAME = "openai-community/gpt2"

@@ -57,6 +58,15 @@ class MockModelConfig:
        return self.diff_sampling_param or {}


def _build_renderer(model_config: MockModelConfig):
    _, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)

    return HfRenderer(
        model_config,
        tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
    )


def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
    models = OpenAIServingModels(
        engine_client=engine,

@@ -71,18 +81,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
        chat_template_content_format="auto",
    )

    async def _fake_process_inputs(
        request_id,
        engine_prompt,
        sampling_params,
        *,
        lora_request,
        trace_headers,
        priority,
        data_parallel_rank,
    ):
        return dict(engine_prompt), {}

    async def _fake_preprocess_chat(*args, **kwargs):
        # return conversation, engine_prompts
        return (

@@ -90,7 +88,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
            [{"prompt_token_ids": [1, 2, 3]}],
        )

    serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
    serving_chat._preprocess_chat = AsyncMock(side_effect=_fake_preprocess_chat)
    return serving_chat

@@ -99,11 +96,11 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
async def test_chat_error_non_stream():
    """test finish_reason='error' returns 500 InternalServerError (non-streaming)"""
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = MockModelConfig()
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()
    mock_engine.renderer = _build_renderer(mock_engine.model_config)

    serving_chat = _build_serving_chat(mock_engine)

@@ -153,11 +150,11 @@ async def test_chat_error_non_stream():
async def test_chat_error_stream():
    """test finish_reason='error' returns 500 InternalServerError (streaming)"""
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = MockModelConfig()
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()
    mock_engine.renderer = _build_renderer(mock_engine.model_config)

    serving_chat = _build_serving_chat(mock_engine)
@@ -4,7 +4,7 @@
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import MagicMock

import pytest

@@ -15,7 +15,8 @@ from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.tokenizers import get_tokenizer
from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.v1.engine.async_llm import AsyncLLM

MODEL_NAME = "openai-community/gpt2"

@@ -61,37 +62,31 @@ def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
        engine_client=engine,
        base_model_paths=BASE_MODEL_PATHS,
    )
    serving_completion = OpenAIServingCompletion(
    return OpenAIServingCompletion(
        engine,
        models,
        request_logger=None,
    )

    async def _fake_process_inputs(
        request_id,
        engine_prompt,
        sampling_params,
        *,
        lora_request,
        trace_headers,
        priority,
        data_parallel_rank,
    ):
        return dict(engine_prompt), {}

    serving_completion._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
    return serving_completion
def _build_renderer(model_config: MockModelConfig):
    _, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)

    return HfRenderer(
        model_config,
        tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
    )


@pytest.mark.asyncio
async def test_completion_error_non_stream():
    """test finish_reason='error' returns 500 InternalServerError (non-streaming)"""
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = MockModelConfig()
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()
    mock_engine.renderer = _build_renderer(mock_engine.model_config)

    serving_completion = _build_serving_completion(mock_engine)

@@ -141,11 +136,11 @@ async def test_completion_error_non_stream():
async def test_completion_error_stream():
    """test finish_reason='error' returns 500 InternalServerError (streaming)"""
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = MockModelConfig()
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()
    mock_engine.renderer = _build_renderer(mock_engine.model_config)

    serving_completion = _build_serving_completion(mock_engine)
@@ -110,7 +110,7 @@ async def test_completions_with_prompt_embeds(
    # Test case: Single prompt embeds input
    completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        prompt=None,
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds},

@@ -121,7 +121,7 @@ async def test_completions_with_prompt_embeds(
    # Test case: batch completion with prompt_embeds
    completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        prompt=None,
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]},

@@ -133,7 +133,7 @@ async def test_completions_with_prompt_embeds(
    # Test case: streaming with prompt_embeds
    single_completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        prompt=None,
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds},

@@ -142,7 +142,7 @@ async def test_completions_with_prompt_embeds(

    stream = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        prompt=None,
        max_tokens=5,
        temperature=0.0,
        stream=True,

@@ -162,7 +162,7 @@ async def test_completions_with_prompt_embeds(
    # Test case: batch streaming with prompt_embeds
    stream = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        prompt=None,
        max_tokens=5,
        temperature=0.0,
        stream=True,

@@ -197,7 +197,7 @@ async def test_completions_with_prompt_embeds(
    )
    completion_embeds_only = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",
        prompt=None,
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds},

@@ -215,7 +215,7 @@ async def test_completions_errors_with_prompt_embeds(
    # Test error case: invalid prompt_embeds
    with pytest.raises(BadRequestError):
        await client_with_prompt_embeds.completions.create(
            prompt="",
            prompt=None,
            model=model_name,
            max_tokens=5,
            temperature=0.0,

@@ -237,7 +237,7 @@ async def test_completions_with_logprobs_and_prompt_embeds(
    # Test case: Logprobs using prompt_embeds
    completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        prompt=None,
        max_tokens=5,
        temperature=0.0,
        echo=False,

@@ -257,7 +257,7 @@ async def test_completions_with_logprobs_and_prompt_embeds(
    # Test case: Log probs with batch completion and prompt_embeds
    completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        prompt=None,
        max_tokens=5,
        temperature=0.0,
        echo=False,

@@ -287,7 +287,7 @@ async def test_prompt_logprobs_raises_error(
    with pytest.raises(BadRequestError, match="not compatible"):
        await client_with_prompt_embeds.completions.create(
            model=MODEL_NAME,
            prompt="",
            prompt=None,
            max_tokens=5,
            temperature=0.0,
            extra_body={"prompt_embeds": encoded_embeds, "prompt_logprobs": True},
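After this change, prompt-embeds requests no longer carry a placeholder empty prompt; the tests above pass `prompt=None` alongside the base64-encoded embeddings. A hedged sketch of an equivalent client call, mirroring the encoding these tests use elsewhere (the client fixture, model name, and tensor shape are illustrative assumptions, not part of the diff):

```python
import io

import pybase64
import torch
from openai import AsyncOpenAI


async def send_prompt_embeds(client: AsyncOpenAI, model_name: str) -> None:
    # Encode a (seq_len, hidden_size) tensor the same way the tests do:
    # torch.save into a buffer, then base64-encode the bytes.
    embeds = torch.randn(10, 768, dtype=torch.float32)
    buffer = io.BytesIO()
    torch.save(embeds, buffer)
    buffer.seek(0)
    encoded_embeds = pybase64.b64encode(buffer.getvalue()).decode("utf-8")

    completion = await client.completions.create(
        model=model_name,
        prompt=None,  # no placeholder prompt needed anymore
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds},
    )
    print(completion.choices[0].text)
```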
@@ -7,7 +7,7 @@ Tests verify that embeddings with correct ndim but incorrect hidden_size
are rejected before they can cause crashes during model inference.

Validation is performed by the parser (MultiModalDataParser) and EmbeddingItems
classes, not by CompletionRenderer or MediaIO classes.
classes, not by MediaIO classes.
"""

import pytest

@@ -16,7 +16,8 @@ from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
from vllm.tokenizers import get_tokenizer
from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.v1.engine.async_llm import AsyncLLM

MODEL_NAME = "openai-community/gpt2"

@@ -35,6 +36,7 @@ class MockModelConfig:
    """Minimal mock ModelConfig for testing."""

    model: str = MODEL_NAME
    runner_type = "generate"
    tokenizer: str = MODEL_NAME
    trust_remote_code: bool = False
    tokenizer_mode: str = "auto"

@@ -85,15 +87,21 @@ def register_mock_resolver():
        del LoRAResolverRegistry.resolvers[MOCK_RESOLVER_NAME]


def _build_renderer(model_config: MockModelConfig):
    _, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)

    return HfRenderer(
        model_config,
        tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
    )


@pytest.fixture
def mock_serving_setup():
    """Provides a mocked engine and serving completion instance."""
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.errored = False

    tokenizer = get_tokenizer(MODEL_NAME)
    mock_engine.get_tokenizer = AsyncMock(return_value=tokenizer)

    async def mock_add_lora_side_effect(lora_request: LoRARequest):
        """Simulate engine behavior when adding LoRAs."""
        if lora_request.lora_name == "test-lora":

@@ -118,6 +126,7 @@ def mock_serving_setup():
    mock_engine.model_config = MockModelConfig()
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()
    mock_engine.renderer = _build_renderer(mock_engine.model_config)

    models = OpenAIServingModels(
        engine_client=mock_engine,

@@ -128,10 +137,6 @@ def mock_serving_setup():
        mock_engine, models, request_logger=None
    )

    serving_completion._process_inputs = AsyncMock(
        return_value=(MagicMock(name="engine_request"), {})
    )

    return mock_engine, serving_completion
@@ -12,7 +12,7 @@ import regex as re
import torch

from vllm.config import ModelConfig
from vllm.entrypoints.renderer import CompletionRenderer
from vllm.renderers.embed_utils import safe_load_prompt_embeds

from ...utils import RemoteOpenAIServer

@@ -30,7 +30,7 @@ async def test_empty_prompt():
    ):
        await client.completions.create(
            model=model_name,
            prompt="",
            prompt=None,
            max_tokens=5,
            temperature=0.0,
            extra_body={"prompt_embeds": []},

@@ -63,7 +63,6 @@ def test_load_prompt_embeds(
):
    model_config = Mock(spec=ModelConfig)
    model_config.enable_prompt_embeds = True
    renderer = CompletionRenderer(model_config, tokenizer=None)

    # construct arbitrary tensors of various dtypes, layouts, and sizes.
    # We need to check against different layouts to make sure that if a user

@@ -89,9 +88,7 @@ def test_load_prompt_embeds(
    buffer.seek(0)
    encoded_tensor = pybase64.b64encode(buffer.getvalue())

    loaded_prompt_embeds = renderer.load_prompt_embeds(encoded_tensor)
    assert len(loaded_prompt_embeds) == 1
    loaded_tensor = loaded_prompt_embeds[0]["prompt_embeds"]
    loaded_tensor = safe_load_prompt_embeds(model_config, encoded_tensor)
    assert loaded_tensor.device.type == "cpu"
    assert loaded_tensor.layout == torch.strided
    torch.testing.assert_close(

@@ -105,7 +102,6 @@ def test_load_prompt_embeds(
def test_disable_prompt_embeds(dtype: torch.dtype, seq_len: int, hidden_size: int):
    model_config = Mock(spec=ModelConfig)
    model_config.enable_prompt_embeds = False
    renderer = CompletionRenderer(model_config, tokenizer=None)

    tensor = torch.randn((seq_len, hidden_size), dtype=dtype)

@@ -115,4 +111,4 @@ def test_disable_prompt_embeds(dtype: torch.dtype, seq_len: int, hidden_size: in
    encoded_tensor = pybase64.b64encode(buffer.getvalue())

    with pytest.raises(ValueError, match="--enable-prompt-embeds"):
        renderer.load_prompt_embeds(encoded_tensor)
        safe_load_prompt_embeds(model_config, encoded_tensor)
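For reference, the decode path these tests now exercise is the standalone helper rather than a renderer method. A small sketch of the round trip, assuming `safe_load_prompt_embeds(model_config, encoded)` returns a CPU tensor as the updated assertions above imply (the mocked config and tensor shape are stand-ins):

```python
import io
from unittest.mock import Mock

import pybase64
import torch

from vllm.config import ModelConfig
from vllm.renderers.embed_utils import safe_load_prompt_embeds

# The helper checks the enable_prompt_embeds gate on the model config.
model_config = Mock(spec=ModelConfig)
model_config.enable_prompt_embeds = True

# Round-trip a tensor the same way the test does.
tensor = torch.randn(4, 16, dtype=torch.float32)
buffer = io.BytesIO()
torch.save(tensor, buffer)
buffer.seek(0)
encoded_tensor = pybase64.b64encode(buffer.getvalue())

loaded = safe_load_prompt_embeds(model_config, encoded_tensor)
# Per the assertions in the test above, the result is a strided CPU tensor;
# value comparisons may additionally depend on the configured model dtype.
assert loaded.device.type == "cpu"
assert loaded.layout == torch.strided
```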
@@ -556,19 +556,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
        request_logger=None,
    )

    async def _fake_process_inputs(
        request_id,
        engine_prompt,
        sampling_params,
        *,
        lora_request,
        trace_headers,
        priority,
        data_parallel_rank,
    ):
        return dict(engine_prompt), {}

    serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
    return serving_chat

@@ -784,7 +771,7 @@ async def test_serving_chat_mistral_token_ids_prompt_is_validated():

    resp = await serving_chat.create_chat_completion(req)
    assert isinstance(resp, ErrorResponse)
    assert "max_tokens" in resp.error.message
    assert "context length is only" in resp.error.message


@pytest.mark.asyncio

@@ -824,7 +811,7 @@ async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected():

    resp = await serving_chat.create_chat_completion(req)
    assert isinstance(resp, ErrorResponse)
    assert "maximum context length" in resp.error.message
    assert "context length is only" in resp.error.message


@pytest.mark.asyncio

@@ -890,6 +877,20 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):

    serving_chat = _build_serving_chat(mock_engine)

    orig_render_chat_request = serving_chat.render_chat_request
    captured_prompts = []

    async def render_chat_request(request):
        result = await orig_render_chat_request(request)

        assert isinstance(result, tuple)
        conversation, engine_prompts = result
        captured_prompts.extend(engine_prompts)

        return result

    serving_chat.render_chat_request = render_chat_request

    # Test cache_salt
    req = ChatCompletionRequest(
        model=MODEL_NAME,

@@ -899,15 +900,19 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
    # By default, cache_salt in the engine prompt is not set
    with suppress(Exception):
        await serving_chat.create_chat_completion(req)
    engine_prompt = serving_chat._process_inputs.await_args_list[0].args[1]
    assert "cache_salt" not in engine_prompt

    assert len(captured_prompts) == 1
    assert "cache_salt" not in captured_prompts[0]

    captured_prompts.clear()

    # Test with certain cache_salt
    req.cache_salt = "test_salt"
    with suppress(Exception):
        await serving_chat.create_chat_completion(req)
    engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
    assert engine_prompt.get("cache_salt") == "test_salt"

    assert len(captured_prompts) == 1
    assert captured_prompts[0]["cache_salt"] == "test_salt"


@pytest.mark.asyncio

@@ -1007,11 +1012,11 @@ class TestServingChatWithHarmony:
    @pytest.fixture()
    def mock_engine(self) -> AsyncLLM:
        mock_engine = MagicMock(spec=AsyncLLM)
        mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
        mock_engine.errored = False
        mock_engine.model_config = MockModelConfig()
        mock_engine.input_processor = MagicMock()
        mock_engine.io_processor = MagicMock()
        mock_engine.renderer = _build_renderer(mock_engine.model_config)
        return mock_engine

    @pytest.fixture()

@@ -1618,11 +1623,11 @@ async def test_tool_choice_validation_without_parser():
    """Test that tool_choice='required' or named tool without tool_parser
    returns an appropriate error message."""
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False
    mock_engine.model_config = MockModelConfig()
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()
    mock_engine.renderer = _build_renderer(mock_engine.model_config)

    models = OpenAIServingModels(
        engine_client=mock_engine,
@@ -1,341 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Sparse tensor validation in embedding APIs.

Tests verify that malicious sparse tensors are rejected before they can trigger
out-of-bounds memory writes during to_dense() operations.
"""

import base64
import io

import pytest
import torch

from vllm.entrypoints.renderer import CompletionRenderer
from vllm.multimodal.media import AudioEmbeddingMediaIO, ImageEmbeddingMediaIO


def _encode_tensor(tensor: torch.Tensor) -> bytes:
    """Helper to encode a tensor as base64 bytes."""
    buffer = io.BytesIO()
    torch.save(tensor, buffer)
    buffer.seek(0)
    return base64.b64encode(buffer.read())


def _create_malicious_sparse_tensor() -> torch.Tensor:
    """
    Create a malicious sparse COO tensor with out-of-bounds indices.

    This tensor has indices that point beyond the declared shape, which would
    cause an out-of-bounds write when converted to dense format without
    validation.
    """
    # Create a 3x3 sparse tensor but with indices pointing to (10, 10)
    indices = torch.tensor([[10], [10]])  # Out of bounds for 3x3 shape
    values = torch.tensor([1.0])
    shape = (3, 3)

    # Create sparse tensor (this will be invalid)
    sparse_tensor = torch.sparse_coo_tensor(indices, values, shape, dtype=torch.float32)
    return sparse_tensor


def _create_valid_sparse_tensor() -> torch.Tensor:
    """Create a valid sparse COO tensor for baseline testing."""
    indices = torch.tensor([[0, 1, 2], [0, 1, 2]])
    values = torch.tensor([1.0, 2.0, 3.0])
    shape = (3, 3)

    sparse_tensor = torch.sparse_coo_tensor(indices, values, shape, dtype=torch.float32)
    return sparse_tensor


def _create_valid_dense_tensor() -> torch.Tensor:
    """Create a valid dense tensor for baseline testing."""
    return torch.randn(10, 768, dtype=torch.float32)  # (seq_len, hidden_size)


class TestPromptEmbedsValidation:
    """Test sparse tensor validation in prompt embeddings (Completions API)."""

    def test_valid_dense_tensor_accepted(self, model_config):
        """Baseline: Valid dense tensors should work normally."""
        renderer = CompletionRenderer(model_config)

        valid_tensor = _create_valid_dense_tensor()
        encoded = _encode_tensor(valid_tensor)

        # Should not raise any exception
        result = renderer.load_prompt_embeds(encoded)
        assert len(result) == 1
        assert result[0]["prompt_embeds"].shape == valid_tensor.shape

    def test_valid_sparse_tensor_accepted(self):
        """Baseline: Valid sparse tensors should load successfully."""
        io_handler = ImageEmbeddingMediaIO()

        valid_sparse = _create_valid_sparse_tensor()
        encoded = _encode_tensor(valid_sparse)

        # Should not raise any exception (sparse tensors remain sparse)
        result = io_handler.load_base64("", encoded.decode("utf-8"))
        assert result.shape == valid_sparse.shape

    def test_malicious_sparse_tensor_rejected(self, model_config):
        """Security: Malicious sparse tensors should be rejected."""
        renderer = CompletionRenderer(model_config)

        malicious_tensor = _create_malicious_sparse_tensor()
        encoded = _encode_tensor(malicious_tensor)

        # Should raise RuntimeError due to invalid sparse tensor
        with pytest.raises((RuntimeError, ValueError)) as exc_info:
            renderer.load_prompt_embeds(encoded)

        # Error should indicate sparse tensor validation failure
        error_msg = str(exc_info.value).lower()
        assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg

    def test_extremely_large_indices_rejected(self, model_config):
        """Security: Sparse tensors with extremely large indices should be rejected."""
        renderer = CompletionRenderer(model_config)

        # Create tensor with indices far beyond reasonable bounds
        indices = torch.tensor([[999999], [999999]])
        values = torch.tensor([1.0])
        shape = (10, 10)

        malicious_tensor = torch.sparse_coo_tensor(
            indices, values, shape, dtype=torch.float32
        )
        encoded = _encode_tensor(malicious_tensor)

        with pytest.raises((RuntimeError, ValueError)):
            renderer.load_prompt_embeds(encoded)

    def test_negative_indices_rejected(self, model_config):
        """Security: Sparse tensors with negative indices should be rejected."""
        renderer = CompletionRenderer(model_config)

        # Create tensor with negative indices
        indices = torch.tensor([[-1], [-1]])
        values = torch.tensor([1.0])
        shape = (10, 10)

        malicious_tensor = torch.sparse_coo_tensor(
            indices, values, shape, dtype=torch.float32
        )
        encoded = _encode_tensor(malicious_tensor)

        with pytest.raises((RuntimeError, ValueError)):
            renderer.load_prompt_embeds(encoded)


class TestImageEmbedsValidation:
    """Test sparse tensor validation in image embeddings (Chat API)."""

    def test_valid_dense_tensor_accepted(self):
        """Baseline: Valid dense tensors should work normally."""
        io_handler = ImageEmbeddingMediaIO()

        valid_tensor = _create_valid_dense_tensor()
        encoded = _encode_tensor(valid_tensor)

        # Should not raise any exception
        result = io_handler.load_base64("", encoded.decode("utf-8"))
        assert result.shape == valid_tensor.shape

    def test_valid_sparse_tensor_accepted(self):
        """Baseline: Valid sparse tensors should load successfully."""
        io_handler = AudioEmbeddingMediaIO()

        valid_sparse = _create_valid_sparse_tensor()
        encoded = _encode_tensor(valid_sparse)

        # Should not raise any exception (sparse tensors remain sparse)
        result = io_handler.load_base64("", encoded.decode("utf-8"))
        assert result.shape == valid_sparse.shape

    def test_malicious_sparse_tensor_rejected(self):
        """Security: Malicious sparse tensors should be rejected."""
        io_handler = ImageEmbeddingMediaIO()

        malicious_tensor = _create_malicious_sparse_tensor()
        encoded = _encode_tensor(malicious_tensor)

        # Should raise RuntimeError due to invalid sparse tensor
        with pytest.raises((RuntimeError, ValueError)) as exc_info:
            io_handler.load_base64("", encoded.decode("utf-8"))

        error_msg = str(exc_info.value).lower()
        assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg

    def test_load_bytes_validates(self):
        """Security: Validation should also work for load_bytes method."""
        io_handler = ImageEmbeddingMediaIO()

        malicious_tensor = _create_malicious_sparse_tensor()
        buffer = io.BytesIO()
        torch.save(malicious_tensor, buffer)
        buffer.seek(0)

        with pytest.raises((RuntimeError, ValueError)):
            io_handler.load_bytes(buffer.read())


class TestAudioEmbedsValidation:
    """Test sparse tensor validation in audio embeddings (Chat API)."""

    def test_valid_dense_tensor_accepted(self):
        """Baseline: Valid dense tensors should work normally."""
        io_handler = AudioEmbeddingMediaIO()

        valid_tensor = _create_valid_dense_tensor()
        encoded = _encode_tensor(valid_tensor)

        # Should not raise any exception
        result = io_handler.load_base64("", encoded.decode("utf-8"))
        assert result.shape == valid_tensor.shape

    def test_valid_sparse_tensor_accepted(self):
        """Baseline: Valid sparse tensors should be converted successfully."""
        io_handler = AudioEmbeddingMediaIO()

        valid_sparse = _create_valid_sparse_tensor()
        encoded = _encode_tensor(valid_sparse)

        # Should not raise any exception
        result = io_handler.load_base64("", encoded.decode("utf-8"))
        assert result.is_sparse is False

    def test_malicious_sparse_tensor_rejected(self):
        """Security: Malicious sparse tensors should be rejected."""
        io_handler = AudioEmbeddingMediaIO()

        malicious_tensor = _create_malicious_sparse_tensor()
        encoded = _encode_tensor(malicious_tensor)

        # Should raise RuntimeError due to invalid sparse tensor
        with pytest.raises((RuntimeError, ValueError)) as exc_info:
            io_handler.load_base64("", encoded.decode("utf-8"))

        error_msg = str(exc_info.value).lower()
        assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg

    def test_load_bytes_validates(self):
        """Security: Validation should also work for load_bytes method."""
        io_handler = AudioEmbeddingMediaIO()

        malicious_tensor = _create_malicious_sparse_tensor()
        buffer = io.BytesIO()
        torch.save(malicious_tensor, buffer)
        buffer.seek(0)

        with pytest.raises((RuntimeError, ValueError)):
            io_handler.load_bytes(buffer.read())


class TestSparseTensorValidationIntegration:
    """
    These tests verify the complete attack chain is blocked at all entry points.
    """

    def test_attack_scenario_completions_api(self, model_config):
        """
        Simulate a complete attack through the Completions API.

        Attack scenario:
        1. Attacker crafts malicious sparse tensor
        2. Encodes it as base64
        3. Sends to /v1/completions with prompt_embeds parameter
        4. Server should reject before memory corruption occurs
        """
        renderer = CompletionRenderer(model_config)

        # Step 1-2: Attacker creates malicious payload
        attack_payload = _encode_tensor(_create_malicious_sparse_tensor())

        # Step 3-4: Server processes and should reject
        with pytest.raises((RuntimeError, ValueError)):
            renderer.load_prompt_embeds(attack_payload)

    def test_attack_scenario_chat_api_image(self):
        """
        Simulate attack through Chat API with image_embeds.

        Verifies the image embeddings path is protected.
        """
        io_handler = ImageEmbeddingMediaIO()
        attack_payload = _encode_tensor(_create_malicious_sparse_tensor())

        with pytest.raises((RuntimeError, ValueError)):
            io_handler.load_base64("", attack_payload.decode("utf-8"))

    def test_attack_scenario_chat_api_audio(self):
        """
        Simulate attack through Chat API with audio_embeds.

        Verifies the audio embeddings path is protected.
        """
        io_handler = AudioEmbeddingMediaIO()
        attack_payload = _encode_tensor(_create_malicious_sparse_tensor())

        with pytest.raises((RuntimeError, ValueError)):
            io_handler.load_base64("", attack_payload.decode("utf-8"))

    def test_multiple_valid_embeddings_in_batch(self, model_config):
        """
        Regression test: Multiple valid embeddings should still work.

        Ensures the fix doesn't break legitimate batch processing.
        """
        renderer = CompletionRenderer(model_config)

        valid_tensors = [
            _encode_tensor(_create_valid_dense_tensor()),
            _encode_tensor(_create_valid_dense_tensor()),
            _encode_tensor(_create_valid_dense_tensor()),
        ]

        # Should process all without error
        result = renderer.load_prompt_embeds(valid_tensors)
        assert len(result) == 3

    def test_mixed_valid_and_malicious_rejected(self, model_config):
        """
        Security: Batch with one malicious tensor should be rejected.

        Even if most tensors are valid, a single malicious one should
        cause rejection of the entire batch.
        """
        renderer = CompletionRenderer(model_config)

        mixed_batch = [
            _encode_tensor(_create_valid_dense_tensor()),
            _encode_tensor(_create_malicious_sparse_tensor()),  # Malicious
            _encode_tensor(_create_valid_dense_tensor()),
        ]

        # Should fail on the malicious tensor
        with pytest.raises((RuntimeError, ValueError)):
            renderer.load_prompt_embeds(mixed_batch)


# Pytest fixtures
@pytest.fixture
def model_config():
    """Mock ModelConfig for testing."""
    from vllm.config import ModelConfig

    return ModelConfig(
        model="facebook/opt-125m",
        tokenizer="facebook/opt-125m",
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float32",
        seed=0,
        enable_prompt_embeds=True,  # Required for prompt embeds tests
    )