[Frontend] Use new Renderer for Completions and Tokenize API (#32863)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

Author: Cyrus Leung
Date: 2026-01-31 20:51:15 +08:00 (committed by GitHub)
Commit: f0a1c8453a
Parent: 8980001c93

64 changed files with 2116 additions and 2003 deletions

View File

@@ -205,7 +205,7 @@ def test_chat_batch_failure_cleanup(llm_for_failure_test):
valid_msg,
]
sampling_params = SamplingParams(temperature=0, max_tokens=10)
with pytest.raises(ValueError, match="longer than the maximum model length"):
with pytest.raises(ValueError, match="context length is only"):
llm.chat(batch_1, sampling_params=sampling_params)
outputs_2 = llm.chat(batch_2, sampling_params=sampling_params)
assert len(outputs_2) == len(batch_2)

View File

@@ -15,7 +15,8 @@ from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.tokenizers import get_tokenizer
from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.v1.engine.async_llm import AsyncLLM
MODEL_NAME = "openai-community/gpt2"
@@ -57,6 +58,15 @@ class MockModelConfig:
return self.diff_sampling_param or {}
def _build_renderer(model_config: MockModelConfig):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer(
model_config,
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
)
def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
models = OpenAIServingModels(
engine_client=engine,
@@ -71,18 +81,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
chat_template_content_format="auto",
)
async def _fake_process_inputs(
request_id,
engine_prompt,
sampling_params,
*,
lora_request,
trace_headers,
priority,
data_parallel_rank,
):
return dict(engine_prompt), {}
async def _fake_preprocess_chat(*args, **kwargs):
# return conversation, engine_prompts
return (
@@ -90,7 +88,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
[{"prompt_token_ids": [1, 2, 3]}],
)
serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
serving_chat._preprocess_chat = AsyncMock(side_effect=_fake_preprocess_chat)
return serving_chat
@@ -99,11 +96,11 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
async def test_chat_error_non_stream():
"""test finish_reason='error' returns 500 InternalServerError (non-streaming)"""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine)
@@ -153,11 +150,11 @@ async def test_chat_error_non_stream():
async def test_chat_error_stream():
"""test finish_reason='error' returns 500 InternalServerError (streaming)"""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine)
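
Note: these hunks replace the old mock_engine.get_tokenizer stub with a real HfRenderer hung off the mock engine. A minimal sketch of the new wiring, pulled together from the hunks above (the build_mock_engine name is hypothetical; model_config stands in for the tests' MockModelConfig and must carry the tokenizer fields that tokenizer_args_from_config reads):

from unittest.mock import MagicMock

from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.v1.engine.async_llm import AsyncLLM


def build_mock_engine(model_config) -> MagicMock:
    # Derive tokenizer settings from the model config, as the updated tests do.
    _, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
    renderer = HfRenderer(
        model_config,
        tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
    )
    # The serving layer now reads the renderer from the engine client.
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.errored = False
    mock_engine.model_config = model_config
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()
    mock_engine.renderer = renderer
    return mock_engine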

View File

@@ -4,7 +4,7 @@
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import MagicMock
import pytest
@@ -15,7 +15,8 @@ from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.tokenizers import get_tokenizer
from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.v1.engine.async_llm import AsyncLLM
MODEL_NAME = "openai-community/gpt2"
@@ -61,37 +62,31 @@ def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
engine_client=engine,
base_model_paths=BASE_MODEL_PATHS,
)
serving_completion = OpenAIServingCompletion(
return OpenAIServingCompletion(
engine,
models,
request_logger=None,
)
async def _fake_process_inputs(
request_id,
engine_prompt,
sampling_params,
*,
lora_request,
trace_headers,
priority,
data_parallel_rank,
):
return dict(engine_prompt), {}
serving_completion._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
return serving_completion
def _build_renderer(model_config: MockModelConfig):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer(
model_config,
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
)
@pytest.mark.asyncio
async def test_completion_error_non_stream():
"""test finish_reason='error' returns 500 InternalServerError (non-streaming)"""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_completion = _build_serving_completion(mock_engine)
@@ -141,11 +136,11 @@ async def test_completion_error_non_stream():
async def test_completion_error_stream():
"""test finish_reason='error' returns 500 InternalServerError (streaming)"""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_completion = _build_serving_completion(mock_engine)

View File

@@ -110,7 +110,7 @@ async def test_completions_with_prompt_embeds(
# Test case: Single prompt embeds input
completion = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt="", # Add empty prompt as required parameter
prompt=None,
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": encoded_embeds},
@@ -121,7 +121,7 @@ async def test_completions_with_prompt_embeds(
# Test case: batch completion with prompt_embeds
completion = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt="", # Add empty prompt as required parameter
prompt=None,
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]},
@@ -133,7 +133,7 @@ async def test_completions_with_prompt_embeds(
# Test case: streaming with prompt_embeds
single_completion = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt="", # Add empty prompt as required parameter
prompt=None,
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": encoded_embeds},
@@ -142,7 +142,7 @@ async def test_completions_with_prompt_embeds(
stream = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt="", # Add empty prompt as required parameter
prompt=None,
max_tokens=5,
temperature=0.0,
stream=True,
@@ -162,7 +162,7 @@ async def test_completions_with_prompt_embeds(
# Test case: batch streaming with prompt_embeds
stream = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt="", # Add empty prompt as required parameter
prompt=None,
max_tokens=5,
temperature=0.0,
stream=True,
@@ -197,7 +197,7 @@ async def test_completions_with_prompt_embeds(
)
completion_embeds_only = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt="",
prompt=None,
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": encoded_embeds},
@@ -215,7 +215,7 @@ async def test_completions_errors_with_prompt_embeds(
# Test error case: invalid prompt_embeds
with pytest.raises(BadRequestError):
await client_with_prompt_embeds.completions.create(
prompt="",
prompt=None,
model=model_name,
max_tokens=5,
temperature=0.0,
@@ -237,7 +237,7 @@ async def test_completions_with_logprobs_and_prompt_embeds(
# Test case: Logprobs using prompt_embeds
completion = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt="", # Add empty prompt as required parameter
prompt=None,
max_tokens=5,
temperature=0.0,
echo=False,
@@ -257,7 +257,7 @@ async def test_completions_with_logprobs_and_prompt_embeds(
# Test case: Log probs with batch completion and prompt_embeds
completion = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt="", # Add empty prompt as required parameter
prompt=None,
max_tokens=5,
temperature=0.0,
echo=False,
@@ -287,7 +287,7 @@ async def test_prompt_logprobs_raises_error(
with pytest.raises(BadRequestError, match="not compatible"):
await client_with_prompt_embeds.completions.create(
model=MODEL_NAME,
prompt="",
prompt=None,
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": encoded_embeds, "prompt_logprobs": True},

View File

@@ -7,7 +7,7 @@ Tests verify that embeddings with correct ndim but incorrect hidden_size
are rejected before they can cause crashes during model inference.
Validation is performed by the parser (MultiModalDataParser) and EmbeddingItems
classes, not by CompletionRenderer or MediaIO classes.
classes, not by MediaIO classes.
"""
import pytest

View File

@@ -16,7 +16,8 @@ from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
from vllm.tokenizers import get_tokenizer
from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.v1.engine.async_llm import AsyncLLM
MODEL_NAME = "openai-community/gpt2"
@@ -35,6 +36,7 @@ class MockModelConfig:
"""Minimal mock ModelConfig for testing."""
model: str = MODEL_NAME
runner_type = "generate"
tokenizer: str = MODEL_NAME
trust_remote_code: bool = False
tokenizer_mode: str = "auto"
@@ -85,15 +87,21 @@ def register_mock_resolver():
del LoRAResolverRegistry.resolvers[MOCK_RESOLVER_NAME]
def _build_renderer(model_config: MockModelConfig):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer(
model_config,
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
)
@pytest.fixture
def mock_serving_setup():
"""Provides a mocked engine and serving completion instance."""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.errored = False
tokenizer = get_tokenizer(MODEL_NAME)
mock_engine.get_tokenizer = AsyncMock(return_value=tokenizer)
async def mock_add_lora_side_effect(lora_request: LoRARequest):
"""Simulate engine behavior when adding LoRAs."""
if lora_request.lora_name == "test-lora":
@@ -118,6 +126,7 @@ def mock_serving_setup():
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
models = OpenAIServingModels(
engine_client=mock_engine,
@@ -128,10 +137,6 @@ def mock_serving_setup():
mock_engine, models, request_logger=None
)
serving_completion._process_inputs = AsyncMock(
return_value=(MagicMock(name="engine_request"), {})
)
return mock_engine, serving_completion

View File

@@ -12,7 +12,7 @@ import regex as re
import torch
from vllm.config import ModelConfig
from vllm.entrypoints.renderer import CompletionRenderer
from vllm.renderers.embed_utils import safe_load_prompt_embeds
from ...utils import RemoteOpenAIServer
@@ -30,7 +30,7 @@ async def test_empty_prompt():
):
await client.completions.create(
model=model_name,
prompt="",
prompt=None,
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": []},
@@ -63,7 +63,6 @@ def test_load_prompt_embeds(
):
model_config = Mock(spec=ModelConfig)
model_config.enable_prompt_embeds = True
renderer = CompletionRenderer(model_config, tokenizer=None)
# construct arbitrary tensors of various dtypes, layouts, and sizes.
# We need to check against different layouts to make sure that if a user
@@ -89,9 +88,7 @@ def test_load_prompt_embeds(
buffer.seek(0)
encoded_tensor = pybase64.b64encode(buffer.getvalue())
loaded_prompt_embeds = renderer.load_prompt_embeds(encoded_tensor)
assert len(loaded_prompt_embeds) == 1
loaded_tensor = loaded_prompt_embeds[0]["prompt_embeds"]
loaded_tensor = safe_load_prompt_embeds(model_config, encoded_tensor)
assert loaded_tensor.device.type == "cpu"
assert loaded_tensor.layout == torch.strided
torch.testing.assert_close(
@@ -105,7 +102,6 @@ def test_load_prompt_embeds(
def test_disable_prompt_embeds(dtype: torch.dtype, seq_len: int, hidden_size: int):
model_config = Mock(spec=ModelConfig)
model_config.enable_prompt_embeds = False
renderer = CompletionRenderer(model_config, tokenizer=None)
tensor = torch.randn((seq_len, hidden_size), dtype=dtype)
@@ -115,4 +111,4 @@ def test_disable_prompt_embeds(dtype: torch.dtype, seq_len: int, hidden_size: in
encoded_tensor = pybase64.b64encode(buffer.getvalue())
with pytest.raises(ValueError, match="--enable-prompt-embeds"):
renderer.load_prompt_embeds(encoded_tensor)
safe_load_prompt_embeds(model_config, encoded_tensor)
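
Note: prompt-embeds decoding moved out of CompletionRenderer into a standalone helper. A minimal round-trip sketch of the new entry point, following the updated test above (the Mock-based model_config and the tensor shape are illustrative):

import io
from unittest.mock import Mock

import pybase64
import torch

from vllm.config import ModelConfig
from vllm.renderers.embed_utils import safe_load_prompt_embeds

model_config = Mock(spec=ModelConfig)
model_config.enable_prompt_embeds = True  # if False, a ValueError mentioning --enable-prompt-embeds is raised

tensor = torch.randn(4, 16, dtype=torch.float32)
buffer = io.BytesIO()
torch.save(tensor, buffer)
buffer.seek(0)
encoded = pybase64.b64encode(buffer.getvalue())

loaded = safe_load_prompt_embeds(model_config, encoded)
assert loaded.device.type == "cpu"
assert loaded.layout == torch.strided
torch.testing.assert_close(loaded, tensor)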

View File

@@ -556,19 +556,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
request_logger=None,
)
async def _fake_process_inputs(
request_id,
engine_prompt,
sampling_params,
*,
lora_request,
trace_headers,
priority,
data_parallel_rank,
):
return dict(engine_prompt), {}
serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
return serving_chat
@@ -784,7 +771,7 @@ async def test_serving_chat_mistral_token_ids_prompt_is_validated():
resp = await serving_chat.create_chat_completion(req)
assert isinstance(resp, ErrorResponse)
assert "max_tokens" in resp.error.message
assert "context length is only" in resp.error.message
@pytest.mark.asyncio
@@ -824,7 +811,7 @@ async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected():
resp = await serving_chat.create_chat_completion(req)
assert isinstance(resp, ErrorResponse)
assert "maximum context length" in resp.error.message
assert "context length is only" in resp.error.message
@pytest.mark.asyncio
@@ -890,6 +877,20 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
serving_chat = _build_serving_chat(mock_engine)
orig_render_chat_request = serving_chat.render_chat_request
captured_prompts = []
async def render_chat_request(request):
result = await orig_render_chat_request(request)
assert isinstance(result, tuple)
conversation, engine_prompts = result
captured_prompts.extend(engine_prompts)
return result
serving_chat.render_chat_request = render_chat_request
# Test cache_salt
req = ChatCompletionRequest(
model=MODEL_NAME,
@@ -899,15 +900,19 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
# By default, cache_salt in the engine prompt is not set
with suppress(Exception):
await serving_chat.create_chat_completion(req)
engine_prompt = serving_chat._process_inputs.await_args_list[0].args[1]
assert "cache_salt" not in engine_prompt
assert len(captured_prompts) == 1
assert "cache_salt" not in captured_prompts[0]
captured_prompts.clear()
# Test with certain cache_salt
req.cache_salt = "test_salt"
with suppress(Exception):
await serving_chat.create_chat_completion(req)
engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
assert engine_prompt.get("cache_salt") == "test_salt"
assert len(captured_prompts) == 1
assert captured_prompts[0]["cache_salt"] == "test_salt"
@pytest.mark.asyncio
@@ -1007,11 +1012,11 @@ class TestServingChatWithHarmony:
@pytest.fixture()
def mock_engine(self) -> AsyncLLM:
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
return mock_engine
@pytest.fixture()
@@ -1618,11 +1623,11 @@ async def test_tool_choice_validation_without_parser():
"""Test that tool_choice='required' or named tool without tool_parser
returns an appropriate error message."""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
models = OpenAIServingModels(
engine_client=mock_engine,

View File

@@ -1,341 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Sparse tensor validation in embedding APIs.
Tests verify that malicious sparse tensors are rejected before they can trigger
out-of-bounds memory writes during to_dense() operations.
"""
import base64
import io
import pytest
import torch
from vllm.entrypoints.renderer import CompletionRenderer
from vllm.multimodal.media import AudioEmbeddingMediaIO, ImageEmbeddingMediaIO
def _encode_tensor(tensor: torch.Tensor) -> bytes:
"""Helper to encode a tensor as base64 bytes."""
buffer = io.BytesIO()
torch.save(tensor, buffer)
buffer.seek(0)
return base64.b64encode(buffer.read())
def _create_malicious_sparse_tensor() -> torch.Tensor:
"""
Create a malicious sparse COO tensor with out-of-bounds indices.
This tensor has indices that point beyond the declared shape, which would
cause an out-of-bounds write when converted to dense format without
validation.
"""
# Create a 3x3 sparse tensor but with indices pointing to (10, 10)
indices = torch.tensor([[10], [10]]) # Out of bounds for 3x3 shape
values = torch.tensor([1.0])
shape = (3, 3)
# Create sparse tensor (this will be invalid)
sparse_tensor = torch.sparse_coo_tensor(indices, values, shape, dtype=torch.float32)
return sparse_tensor
def _create_valid_sparse_tensor() -> torch.Tensor:
"""Create a valid sparse COO tensor for baseline testing."""
indices = torch.tensor([[0, 1, 2], [0, 1, 2]])
values = torch.tensor([1.0, 2.0, 3.0])
shape = (3, 3)
sparse_tensor = torch.sparse_coo_tensor(indices, values, shape, dtype=torch.float32)
return sparse_tensor
def _create_valid_dense_tensor() -> torch.Tensor:
"""Create a valid dense tensor for baseline testing."""
return torch.randn(10, 768, dtype=torch.float32) # (seq_len, hidden_size)
class TestPromptEmbedsValidation:
"""Test sparse tensor validation in prompt embeddings (Completions API)."""
def test_valid_dense_tensor_accepted(self, model_config):
"""Baseline: Valid dense tensors should work normally."""
renderer = CompletionRenderer(model_config)
valid_tensor = _create_valid_dense_tensor()
encoded = _encode_tensor(valid_tensor)
# Should not raise any exception
result = renderer.load_prompt_embeds(encoded)
assert len(result) == 1
assert result[0]["prompt_embeds"].shape == valid_tensor.shape
def test_valid_sparse_tensor_accepted(self):
"""Baseline: Valid sparse tensors should load successfully."""
io_handler = ImageEmbeddingMediaIO()
valid_sparse = _create_valid_sparse_tensor()
encoded = _encode_tensor(valid_sparse)
# Should not raise any exception (sparse tensors remain sparse)
result = io_handler.load_base64("", encoded.decode("utf-8"))
assert result.shape == valid_sparse.shape
def test_malicious_sparse_tensor_rejected(self, model_config):
"""Security: Malicious sparse tensors should be rejected."""
renderer = CompletionRenderer(model_config)
malicious_tensor = _create_malicious_sparse_tensor()
encoded = _encode_tensor(malicious_tensor)
# Should raise RuntimeError due to invalid sparse tensor
with pytest.raises((RuntimeError, ValueError)) as exc_info:
renderer.load_prompt_embeds(encoded)
# Error should indicate sparse tensor validation failure
error_msg = str(exc_info.value).lower()
assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg
def test_extremely_large_indices_rejected(self, model_config):
"""Security: Sparse tensors with extremely large indices should be rejected."""
renderer = CompletionRenderer(model_config)
# Create tensor with indices far beyond reasonable bounds
indices = torch.tensor([[999999], [999999]])
values = torch.tensor([1.0])
shape = (10, 10)
malicious_tensor = torch.sparse_coo_tensor(
indices, values, shape, dtype=torch.float32
)
encoded = _encode_tensor(malicious_tensor)
with pytest.raises((RuntimeError, ValueError)):
renderer.load_prompt_embeds(encoded)
def test_negative_indices_rejected(self, model_config):
"""Security: Sparse tensors with negative indices should be rejected."""
renderer = CompletionRenderer(model_config)
# Create tensor with negative indices
indices = torch.tensor([[-1], [-1]])
values = torch.tensor([1.0])
shape = (10, 10)
malicious_tensor = torch.sparse_coo_tensor(
indices, values, shape, dtype=torch.float32
)
encoded = _encode_tensor(malicious_tensor)
with pytest.raises((RuntimeError, ValueError)):
renderer.load_prompt_embeds(encoded)
class TestImageEmbedsValidation:
"""Test sparse tensor validation in image embeddings (Chat API)."""
def test_valid_dense_tensor_accepted(self):
"""Baseline: Valid dense tensors should work normally."""
io_handler = ImageEmbeddingMediaIO()
valid_tensor = _create_valid_dense_tensor()
encoded = _encode_tensor(valid_tensor)
# Should not raise any exception
result = io_handler.load_base64("", encoded.decode("utf-8"))
assert result.shape == valid_tensor.shape
def test_valid_sparse_tensor_accepted(self):
"""Baseline: Valid sparse tensors should load successfully."""
io_handler = AudioEmbeddingMediaIO()
valid_sparse = _create_valid_sparse_tensor()
encoded = _encode_tensor(valid_sparse)
# Should not raise any exception (sparse tensors remain sparse)
result = io_handler.load_base64("", encoded.decode("utf-8"))
assert result.shape == valid_sparse.shape
def test_malicious_sparse_tensor_rejected(self):
"""Security: Malicious sparse tensors should be rejected."""
io_handler = ImageEmbeddingMediaIO()
malicious_tensor = _create_malicious_sparse_tensor()
encoded = _encode_tensor(malicious_tensor)
# Should raise RuntimeError due to invalid sparse tensor
with pytest.raises((RuntimeError, ValueError)) as exc_info:
io_handler.load_base64("", encoded.decode("utf-8"))
error_msg = str(exc_info.value).lower()
assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg
def test_load_bytes_validates(self):
"""Security: Validation should also work for load_bytes method."""
io_handler = ImageEmbeddingMediaIO()
malicious_tensor = _create_malicious_sparse_tensor()
buffer = io.BytesIO()
torch.save(malicious_tensor, buffer)
buffer.seek(0)
with pytest.raises((RuntimeError, ValueError)):
io_handler.load_bytes(buffer.read())
class TestAudioEmbedsValidation:
"""Test sparse tensor validation in audio embeddings (Chat API)."""
def test_valid_dense_tensor_accepted(self):
"""Baseline: Valid dense tensors should work normally."""
io_handler = AudioEmbeddingMediaIO()
valid_tensor = _create_valid_dense_tensor()
encoded = _encode_tensor(valid_tensor)
# Should not raise any exception
result = io_handler.load_base64("", encoded.decode("utf-8"))
assert result.shape == valid_tensor.shape
def test_valid_sparse_tensor_accepted(self):
"""Baseline: Valid sparse tensors should be converted successfully."""
io_handler = AudioEmbeddingMediaIO()
valid_sparse = _create_valid_sparse_tensor()
encoded = _encode_tensor(valid_sparse)
# Should not raise any exception
result = io_handler.load_base64("", encoded.decode("utf-8"))
assert result.is_sparse is False
def test_malicious_sparse_tensor_rejected(self):
"""Security: Malicious sparse tensors should be rejected."""
io_handler = AudioEmbeddingMediaIO()
malicious_tensor = _create_malicious_sparse_tensor()
encoded = _encode_tensor(malicious_tensor)
# Should raise RuntimeError due to invalid sparse tensor
with pytest.raises((RuntimeError, ValueError)) as exc_info:
io_handler.load_base64("", encoded.decode("utf-8"))
error_msg = str(exc_info.value).lower()
assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg
def test_load_bytes_validates(self):
"""Security: Validation should also work for load_bytes method."""
io_handler = AudioEmbeddingMediaIO()
malicious_tensor = _create_malicious_sparse_tensor()
buffer = io.BytesIO()
torch.save(malicious_tensor, buffer)
buffer.seek(0)
with pytest.raises((RuntimeError, ValueError)):
io_handler.load_bytes(buffer.read())
class TestSparseTensorValidationIntegration:
"""
These tests verify the complete attack chain is blocked at all entry points.
"""
def test_attack_scenario_completions_api(self, model_config):
"""
Simulate a complete attack through the Completions API.
Attack scenario:
1. Attacker crafts malicious sparse tensor
2. Encodes it as base64
3. Sends to /v1/completions with prompt_embeds parameter
4. Server should reject before memory corruption occurs
"""
renderer = CompletionRenderer(model_config)
# Step 1-2: Attacker creates malicious payload
attack_payload = _encode_tensor(_create_malicious_sparse_tensor())
# Step 3-4: Server processes and should reject
with pytest.raises((RuntimeError, ValueError)):
renderer.load_prompt_embeds(attack_payload)
def test_attack_scenario_chat_api_image(self):
"""
Simulate attack through Chat API with image_embeds.
Verifies the image embeddings path is protected.
"""
io_handler = ImageEmbeddingMediaIO()
attack_payload = _encode_tensor(_create_malicious_sparse_tensor())
with pytest.raises((RuntimeError, ValueError)):
io_handler.load_base64("", attack_payload.decode("utf-8"))
def test_attack_scenario_chat_api_audio(self):
"""
Simulate attack through Chat API with audio_embeds.
Verifies the audio embeddings path is protected.
"""
io_handler = AudioEmbeddingMediaIO()
attack_payload = _encode_tensor(_create_malicious_sparse_tensor())
with pytest.raises((RuntimeError, ValueError)):
io_handler.load_base64("", attack_payload.decode("utf-8"))
def test_multiple_valid_embeddings_in_batch(self, model_config):
"""
Regression test: Multiple valid embeddings should still work.
Ensures the fix doesn't break legitimate batch processing.
"""
renderer = CompletionRenderer(model_config)
valid_tensors = [
_encode_tensor(_create_valid_dense_tensor()),
_encode_tensor(_create_valid_dense_tensor()),
_encode_tensor(_create_valid_dense_tensor()),
]
# Should process all without error
result = renderer.load_prompt_embeds(valid_tensors)
assert len(result) == 3
def test_mixed_valid_and_malicious_rejected(self, model_config):
"""
Security: Batch with one malicious tensor should be rejected.
Even if most tensors are valid, a single malicious one should
cause rejection of the entire batch.
"""
renderer = CompletionRenderer(model_config)
mixed_batch = [
_encode_tensor(_create_valid_dense_tensor()),
_encode_tensor(_create_malicious_sparse_tensor()), # Malicious
_encode_tensor(_create_valid_dense_tensor()),
]
# Should fail on the malicious tensor
with pytest.raises((RuntimeError, ValueError)):
renderer.load_prompt_embeds(mixed_batch)
# Pytest fixtures
@pytest.fixture
def model_config():
"""Mock ModelConfig for testing."""
from vllm.config import ModelConfig
return ModelConfig(
model="facebook/opt-125m",
tokenizer="facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=False,
dtype="float32",
seed=0,
enable_prompt_embeds=True, # Required for prompt embeds tests
)

View File

@@ -67,20 +67,6 @@ async def test_smaller_truncation_size(client: openai.AsyncOpenAI):
assert response["usage"]["prompt_tokens"] == truncation_size
@pytest.mark.asyncio
async def test_zero_truncation_size(client: openai.AsyncOpenAI):
truncation_size = 0
kwargs: dict[str, Any] = {
"model": MODEL_NAME,
"input": input,
"truncate_prompt_tokens": truncation_size,
}
response = await client.post(path="embeddings", cast_to=object, body={**kwargs})
assert response["usage"]["prompt_tokens"] == truncation_size
@pytest.mark.asyncio
async def test_bigger_truncation_size(client: openai.AsyncOpenAI):
truncation_size = max_model_len + 1

View File

@@ -128,12 +128,10 @@ def test_empty_input_error(server: RemoteOpenAIServer, model_name: str):
server.url_for("classify"),
json={"model": model_name, "input": []},
)
classification_response.raise_for_status()
output = ClassificationResponse.model_validate(classification_response.json())
assert output.object == "list"
assert isinstance(output.data, list)
assert len(output.data) == 0
error = classification_response.json()
assert classification_response.status_code == 400
assert "error" in error
@pytest.mark.parametrize("model_name", [MODEL_NAME])

View File

@@ -247,7 +247,7 @@ class TestModel:
},
)
assert score_response.status_code == 400
assert "Please, select a smaller truncation size." in score_response.text
assert "Please request a smaller truncation size." in score_response.text
def test_invocations(self, server: RemoteOpenAIServer, model: dict[str, Any]):
queries = "What is the capital of France?"

View File

@@ -1,325 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import io
from dataclasses import dataclass
from unittest.mock import AsyncMock, MagicMock
import pybase64
import pytest
import torch
from vllm.entrypoints.renderer import CompletionRenderer, RenderConfig
from vllm.inputs.data import is_embeds_prompt
@dataclass
class MockModelConfig:
max_model_len: int = 100
encoder_config: dict | None = None
enable_prompt_embeds: bool = True
class MockTokenizerResult:
def __init__(self, input_ids):
self.input_ids = input_ids
@pytest.fixture
def mock_model_config():
return MockModelConfig()
@pytest.fixture
def mock_tokenizer():
tokenizer = MagicMock()
return tokenizer
@pytest.fixture
def mock_async_tokenizer():
async_tokenizer = AsyncMock()
return async_tokenizer
@pytest.fixture
def renderer(mock_model_config, mock_tokenizer):
return CompletionRenderer(
model_config=mock_model_config,
tokenizer=mock_tokenizer,
async_tokenizer_pool={},
)
class TestRenderPrompt:
"""Test Category A: Basic Functionality Tests"""
@pytest.mark.asyncio
async def test_token_input(self, renderer):
tokens = [101, 7592, 2088]
results = await renderer.render_prompt(
prompt_or_prompts=tokens, config=RenderConfig(max_length=100)
)
assert len(results) == 1
assert results[0]["prompt_token_ids"] == tokens
@pytest.mark.asyncio
async def test_token_list_input(self, renderer):
token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
results = await renderer.render_prompt(
prompt_or_prompts=token_lists, config=RenderConfig(max_length=100)
)
assert len(results) == 3
assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
assert results[1]["prompt_token_ids"] == [102, 1234, 5678, 9012]
assert results[2]["prompt_token_ids"] == [103, 4567]
@pytest.mark.asyncio
async def test_text_input(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.return_value = MockTokenizerResult([101, 7592, 2088])
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer
results = await renderer.render_prompt(
prompt_or_prompts="Hello world", config=RenderConfig(max_length=100)
)
assert len(results) == 1
assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
mock_async_tokenizer.assert_called_once()
@pytest.mark.asyncio
async def test_text_list_input(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.return_value = MockTokenizerResult([101, 7592, 2088])
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer
text_list_input = ["Hello world", "How are you?", "Good morning"]
results = await renderer.render_prompt(
prompt_or_prompts=text_list_input, config=RenderConfig(max_length=100)
)
assert len(results) == 3
for result in results:
assert result["prompt_token_ids"] == [101, 7592, 2088]
assert mock_async_tokenizer.call_count == 3
@pytest.mark.asyncio
async def test_no_truncation(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.return_value = MockTokenizerResult([101, 7592, 2088])
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer
results = await renderer.render_prompt(
prompt_or_prompts="Hello world", config=RenderConfig(max_length=100)
)
assert len(results) == 1
call_args = mock_async_tokenizer.call_args
assert (
"truncation" not in call_args.kwargs
or call_args.kwargs["truncation"] is False
)
@pytest.mark.asyncio
async def test_truncation_positive(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.return_value = MockTokenizerResult(
[101, 7592, 2088]
) # Truncated
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer
results = await renderer.render_prompt(
prompt_or_prompts="Hello world",
config=RenderConfig(max_length=100, truncate_prompt_tokens=50),
)
assert len(results) == 1
call_args = mock_async_tokenizer.call_args
assert call_args.kwargs["truncation"] is True
assert call_args.kwargs["max_length"] == 50
@pytest.mark.asyncio
async def test_truncation_negative(self, renderer, mock_async_tokenizer):
# Test that negative truncation uses model's max_model_len
mock_async_tokenizer.return_value = MockTokenizerResult(
[101, 7592, 2088]
) # Truncated to max_model_len
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer
results = await renderer.render_prompt(
prompt_or_prompts="Hello world",
config=RenderConfig(max_length=200, truncate_prompt_tokens=-1),
)
assert len(results) == 1
call_args = mock_async_tokenizer.call_args
assert call_args.kwargs["truncation"] is True
assert call_args.kwargs["max_length"] == 100 # model's max_model_len
@pytest.mark.asyncio
async def test_token_truncation_last_elements(self, renderer):
# Test that token truncation keeps the last N elements
long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens
results = await renderer.render_prompt(
prompt_or_prompts=long_tokens,
config=RenderConfig(max_length=100, truncate_prompt_tokens=5),
)
assert len(results) == 1
# Should keep the last 5 tokens: [105, 106, 107, 108, 109]
assert results[0]["prompt_token_ids"] == [105, 106, 107, 108, 109]
@pytest.mark.asyncio
async def test_max_length_exceeded(self, renderer):
long_tokens = list(range(150)) # Exceeds max_model_len=100
with pytest.raises(ValueError, match="maximum context length"):
await renderer.render_prompt(
prompt_or_prompts=long_tokens, config=RenderConfig(max_length=100)
)
@pytest.mark.asyncio
async def test_no_tokenizer_for_text(self, mock_model_config):
renderer_no_tokenizer = CompletionRenderer(
model_config=mock_model_config, tokenizer=None, async_tokenizer_pool={}
)
with pytest.raises(ValueError, match="No tokenizer available"):
await renderer_no_tokenizer.render_prompt(
prompt_or_prompts="Hello world", config=RenderConfig(max_length=100)
)
@pytest.mark.asyncio
async def test_token_input_with_needs_detokenization(
self, renderer, mock_async_tokenizer
):
# When needs_detokenization=True for token inputs, renderer should
# use the async tokenizer to decode and include the original text
# in the returned prompt object.
mock_async_tokenizer.decode = AsyncMock(return_value="decoded text")
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer
tokens = [1, 2, 3, 4]
results = await renderer.render_prompt(
prompt_or_prompts=tokens,
config=RenderConfig(needs_detokenization=True),
)
assert len(results) == 1
assert results[0]["prompt_token_ids"] == tokens
assert results[0]["prompt"] == "decoded text"
mock_async_tokenizer.decode.assert_awaited_once()
class TestRenderEmbedPrompt:
def _create_test_embed_bytes(self, tensor: torch.Tensor) -> bytes:
"""Helper to create base64-encoded tensor bytes"""
buffer = io.BytesIO()
torch.save(tensor, buffer)
buffer.seek(0)
return pybase64.b64encode(buffer.read())
@pytest.mark.asyncio
async def test_single_prompt_embed(self, renderer):
# Create a test tensor
test_tensor = torch.randn(10, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds(
prompt_embeds=embed_bytes,
config=RenderConfig(cache_salt="test_salt"),
)
assert len(results) == 1
assert is_embeds_prompt(results[0])
assert torch.allclose(results[0]["prompt_embeds"], test_tensor)
assert results[0]["cache_salt"] == "test_salt"
@pytest.mark.asyncio
async def test_multiple_prompt_embeds(self, renderer):
# Create multiple test tensors
test_tensors = [
torch.randn(8, 512, dtype=torch.float32),
torch.randn(12, 512, dtype=torch.float32),
]
embed_bytes_list = [self._create_test_embed_bytes(t) for t in test_tensors]
results = await renderer.render_prompt_and_embeds(
prompt_embeds=embed_bytes_list,
config=RenderConfig(),
)
assert len(results) == 2
for i, result in enumerate(results):
assert is_embeds_prompt(result)
assert torch.allclose(result["prompt_embeds"], test_tensors[i])
@pytest.mark.asyncio
async def test_prompt_embed_truncation(self, renderer):
# Create tensor with more tokens than truncation limit
test_tensor = torch.randn(20, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds(
prompt_embeds=embed_bytes,
config=RenderConfig(truncate_prompt_tokens=10),
)
assert len(results) == 1
# Should keep last 10 tokens
expected = test_tensor[-10:]
assert torch.allclose(results[0]["prompt_embeds"], expected)
@pytest.mark.asyncio
async def test_prompt_embed_different_dtypes(self, renderer):
# Test different supported dtypes
dtypes = [torch.float32, torch.float16, torch.bfloat16]
for dtype in dtypes:
test_tensor = torch.randn(5, 256, dtype=dtype)
embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds(
prompt_embeds=embed_bytes,
config=RenderConfig(),
)
assert len(results) == 1
assert results[0]["prompt_embeds"].dtype == dtype
@pytest.mark.asyncio
async def test_prompt_embed_squeeze_batch_dim(self, renderer):
# Test tensor with batch dimension gets squeezed
test_tensor = torch.randn(1, 10, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds(
prompt_embeds=embed_bytes,
config=RenderConfig(),
)
assert len(results) == 1
# Should be squeezed to 2D
assert results[0]["prompt_embeds"].shape == (10, 768)
@pytest.mark.asyncio
async def test_both_prompts_and_embeds(self, renderer, mock_async_tokenizer):
# Set up text tokenization
mock_async_tokenizer.return_value = MockTokenizerResult([101, 102, 103])
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer
# Create embed
test_tensor = torch.randn(5, 256, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds(
prompt_or_prompts="Hello world",
prompt_embeds=embed_bytes,
config=RenderConfig(),
)
assert len(results) == 2
# First should be embed prompt
assert is_embeds_prompt(results[0])
# Second should be tokens prompt
assert "prompt_token_ids" in results[1]
assert results[1]["prompt_token_ids"] == [101, 102, 103]