# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
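"""Tests for completion prompt rendering in vLLM's HfRenderer: input
validation, tokenization, truncation, and prompt-embedding handling."""
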
import io
from dataclasses import dataclass
from typing import Any
from unittest.mock import AsyncMock

import pybase64
import pytest
import torch

from vllm.inputs.data import is_embeds_prompt
from vllm.renderers import TokenizeParams
from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config

MODEL_NAME = "openai-community/gpt2"


@dataclass
class MockHFConfig:
    model_type: str = "any"


@dataclass
class MockModelConfig:
runner_type = "generate"
model: str = MODEL_NAME
tokenizer: str = MODEL_NAME
trust_remote_code: bool = False
max_model_len: int = 100
tokenizer_revision = None
tokenizer_mode = "auto"
hf_config = MockHFConfig()
encoder_config: dict[str, Any] | None = None
enable_prompt_embeds: bool = True
    skip_tokenizer_init: bool = False


@pytest.fixture
def mock_model_config():
    return MockModelConfig()


@pytest.fixture
def mock_async_tokenizer():
    return AsyncMock()


@pytest.fixture
def renderer(mock_model_config):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(mock_model_config)
return HfRenderer(
mock_model_config,
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
    )


class TestValidatePrompt:
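    """Validation of raw inputs accepted by render_completions()."""
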
STRING_INPUTS = [
"",
"foo",
"foo bar",
"foo baz bar",
"foo bar qux baz",
    ]

    TOKEN_INPUTS = [
[-1],
[1],
[1, 2],
[1, 3, 4],
[1, 2, 4, 3],
    ]

    INPUTS_SLICES = [
slice(None, None, -1),
slice(None, None, 2),
slice(None, None, -2),
    ]

    def test_empty_input(self, renderer):
        with pytest.raises(ValueError, match="at least one prompt"):
            renderer.render_completions([])

    def test_invalid_type(self, renderer):
        # A nested list mixing token lists and strings should raise a TypeError.
        with pytest.raises(TypeError, match="string or an array of tokens"):
            renderer.render_completions([[1, 2], ["foo", "bar"]])
@pytest.mark.parametrize("string_input", STRING_INPUTS)
def test_string_consistent(self, renderer, string_input: str):
assert renderer.render_completions(string_input) == renderer.render_completions(
[string_input]
)
@pytest.mark.parametrize("token_input", TOKEN_INPUTS)
def test_token_consistent(self, renderer, token_input: list[int]):
assert renderer.render_completions(token_input) == renderer.render_completions(
[token_input]
)
@pytest.mark.parametrize("inputs_slice", INPUTS_SLICES)
def test_string_slice(self, renderer, inputs_slice: slice):
assert renderer.render_completions(self.STRING_INPUTS)[
inputs_slice
] == renderer.render_completions(self.STRING_INPUTS[inputs_slice])


class TestRenderPrompt:
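    """Tokenization, truncation, and length validation of rendered prompts."""
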
@pytest.mark.asyncio
async def test_token_input(self, renderer):
tokens = [101, 7592, 2088]
prompts = await renderer.render_completions_async(tokens)
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=100),
)
assert len(results) == 1
assert results[0]["prompt_token_ids"] == tokens

    @pytest.mark.asyncio
async def test_token_list_input(self, renderer):
token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
prompts = await renderer.render_completions_async(token_lists)
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=100),
)
assert len(results) == 3
assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
assert results[1]["prompt_token_ids"] == [102, 1234, 5678, 9012]
assert results[2]["prompt_token_ids"] == [103, 4567]

    @pytest.mark.asyncio
async def test_text_input(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.encode.return_value = [101, 7592, 2088]
renderer._async_tokenizer = mock_async_tokenizer
prompts = await renderer.render_completions_async("Hello world")
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=100),
)
assert len(results) == 1
assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
mock_async_tokenizer.encode.assert_called_once()

    @pytest.mark.asyncio
async def test_text_list_input(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.encode.return_value = [101, 7592, 2088]
renderer._async_tokenizer = mock_async_tokenizer
text_list_input = ["Hello world", "How are you?", "Good morning"]
prompts = await renderer.render_completions_async(text_list_input)
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=100),
)
assert len(results) == 3
for result in results:
assert result["prompt_token_ids"] == [101, 7592, 2088]
assert mock_async_tokenizer.encode.call_count == 3

    @pytest.mark.asyncio
async def test_no_truncation(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.encode.return_value = [101, 7592, 2088]
renderer._async_tokenizer = mock_async_tokenizer
prompts = await renderer.render_completions_async("Hello world")
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=100),
)
assert len(results) == 1
call_args = mock_async_tokenizer.encode.call_args
assert (
"truncation" not in call_args.kwargs
or call_args.kwargs["truncation"] is False
)

    @pytest.mark.asyncio
async def test_truncation_positive(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.encode.return_value = [101, 7592, 2088] # Truncated
renderer._async_tokenizer = mock_async_tokenizer
prompts = await renderer.render_completions_async("Hello world")
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(
max_total_tokens=200,
truncate_prompt_tokens=50,
),
)
assert len(results) == 1
call_args = mock_async_tokenizer.encode.call_args
assert call_args.kwargs["truncation"] is True
assert call_args.kwargs["max_length"] == 50

    @pytest.mark.asyncio
    async def test_truncation_negative(self, renderer, mock_async_tokenizer):
        # A negative truncate_prompt_tokens falls back to the max_total_tokens
        # limit (here 200), as asserted on max_length below.
        mock_async_tokenizer.encode.return_value = [101, 7592, 2088]  # Truncated
        renderer._async_tokenizer = mock_async_tokenizer
        prompts = await renderer.render_completions_async("Hello world")
        results = await renderer.tokenize_prompts_async(
            prompts,
            TokenizeParams(
                max_total_tokens=200,
                truncate_prompt_tokens=-1,
            ),
        )
        assert len(results) == 1
        call_args = mock_async_tokenizer.encode.call_args
        assert call_args.kwargs["truncation"] is True
        assert call_args.kwargs["max_length"] == 200

    @pytest.mark.asyncio
async def test_token_truncation_last_elements(self, renderer):
# Test that token truncation keeps the last N elements
long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens
prompts = await renderer.render_completions_async(long_tokens)
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(
max_total_tokens=100,
truncate_prompt_tokens=5,
),
)
assert len(results) == 1
# Should keep the last 5 tokens: [105, 106, 107, 108, 109]
assert results[0]["prompt_token_ids"] == [105, 106, 107, 108, 109]

    @pytest.mark.asyncio
async def test_max_length_exceeded(self, renderer):
long_tokens = list(range(150)) # Exceeds max_model_len=100
prompts = await renderer.render_completions_async(long_tokens)
with pytest.raises(ValueError, match="context length is only"):
await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=100),
)

    @pytest.mark.asyncio
async def test_no_tokenizer_for_text(self, renderer):
renderer_no_tokenizer = HfRenderer.from_config(
MockModelConfig(skip_tokenizer_init=True),
tokenizer_kwargs={},
)
prompts = await renderer_no_tokenizer.render_completions_async("Hello world")
with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"):
await renderer_no_tokenizer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=100),
)

    @pytest.mark.asyncio
async def test_token_input_with_needs_detokenization(
self, renderer, mock_async_tokenizer
):
# When needs_detokenization=True for token inputs, renderer should
# use the async tokenizer to decode and include the original text
# in the returned prompt object.
mock_async_tokenizer.decode = AsyncMock(return_value="decoded text")
renderer._async_tokenizer = mock_async_tokenizer
tokens = [1, 2, 3, 4]
prompts = await renderer.render_completions_async(tokens)
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(
max_total_tokens=renderer.config.max_model_len,
needs_detokenization=True,
),
)
assert len(results) == 1
assert results[0]["prompt_token_ids"] == tokens
assert results[0]["prompt"] == "decoded text"
mock_async_tokenizer.decode.assert_awaited_once()


class TestRenderEmbedPrompt:
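    """Rendering of base64-encoded prompt embeddings."""
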
def _create_test_embed_bytes(self, tensor: torch.Tensor) -> bytes:
"""Helper to create base64-encoded tensor bytes"""
buffer = io.BytesIO()
torch.save(tensor, buffer)
buffer.seek(0)
return pybase64.b64encode(buffer.read())
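    # For reference, a payload produced above can be decoded back into a
    # tensor with the same libraries. A minimal sketch, assuming the renderer
    # performs the inverse of this helper (not the actual implementation):
    #
    #     buffer = io.BytesIO(pybase64.b64decode(embed_bytes))
    #     tensor = torch.load(buffer, weights_only=True)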

    @pytest.mark.asyncio
async def test_single_prompt_embed(self, renderer):
# Create a test tensor
test_tensor = torch.randn(10, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor)
prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes)
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=renderer.config.max_model_len),
)
assert len(results) == 1
assert is_embeds_prompt(results[0])
assert torch.allclose(results[0]["prompt_embeds"], test_tensor)

    @pytest.mark.asyncio
async def test_multiple_prompt_embeds(self, renderer):
# Create multiple test tensors
test_tensors = [
torch.randn(8, 512, dtype=torch.float32),
torch.randn(12, 512, dtype=torch.float32),
]
embed_bytes_list = [self._create_test_embed_bytes(t) for t in test_tensors]
prompts = await renderer.render_completions_async(
prompt_embeds=embed_bytes_list
)
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=renderer.config.max_model_len),
)
assert len(results) == 2
for i, result in enumerate(results):
assert is_embeds_prompt(result)
assert torch.allclose(result["prompt_embeds"], test_tensors[i])

    @pytest.mark.asyncio
async def test_prompt_embed_truncation(self, renderer):
# Create tensor with more tokens than truncation limit
test_tensor = torch.randn(20, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor)
prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes)
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(
max_total_tokens=renderer.config.max_model_len,
truncate_prompt_tokens=10,
),
)
assert len(results) == 1
# Should keep last 10 tokens
expected = test_tensor[-10:]
assert torch.allclose(results[0]["prompt_embeds"], expected)

    @pytest.mark.asyncio
async def test_prompt_embed_different_dtypes(self, renderer):
# Test different supported dtypes
dtypes = [torch.float32, torch.float16, torch.bfloat16]
for dtype in dtypes:
test_tensor = torch.randn(5, 256, dtype=dtype)
embed_bytes = self._create_test_embed_bytes(test_tensor)
prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes)
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=renderer.config.max_model_len),
)
assert len(results) == 1
assert results[0]["prompt_embeds"].dtype == dtype

    @pytest.mark.asyncio
async def test_prompt_embed_squeeze_batch_dim(self, renderer):
# Test tensor with batch dimension gets squeezed
test_tensor = torch.randn(1, 10, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor)
prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes)
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=renderer.config.max_model_len),
)
assert len(results) == 1
# Should be squeezed to 2D
assert results[0]["prompt_embeds"].shape == (10, 768)

    @pytest.mark.asyncio
async def test_both_prompts_and_embeds(self, renderer, mock_async_tokenizer):
# Set up text tokenization
mock_async_tokenizer.encode.return_value = [101, 102, 103]
renderer._async_tokenizer = mock_async_tokenizer
# Create embed
test_tensor = torch.randn(5, 256, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor)
prompts = await renderer.render_completions_async(
"Hello world",
prompt_embeds=embed_bytes,
)
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=renderer.config.max_model_len),
)
assert len(results) == 2
# First should be embed prompt
assert is_embeds_prompt(results[0])
# Second should be tokens prompt
assert "prompt_token_ids" in results[1]
assert results[1]["prompt_token_ids"] == [101, 102, 103]