feat(frontend): early-fail tokenization guard for user requests (#31366)

Signed-off-by: limingliang <limingliang@stepfun.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: limingliang <limingliang@stepfun.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Mingliang Li
2026-02-06 11:38:02 +08:00
committed by GitHub
parent 20d7454c9b
commit a32cb49b60
7 changed files with 315 additions and 209 deletions

View File

@@ -4,7 +4,6 @@
import io
from dataclasses import dataclass
from typing import Any
from unittest.mock import AsyncMock
import pybase64
import pytest
@@ -28,7 +27,6 @@ class MockModelConfig:
model: str = MODEL_NAME
tokenizer: str = MODEL_NAME
trust_remote_code: bool = False
max_model_len: int = 100
tokenizer_revision = None
tokenizer_mode = "auto"
hf_config = MockHFConfig()
@@ -37,25 +35,50 @@ class MockModelConfig:
skip_tokenizer_init: bool = False
@pytest.fixture
def mock_model_config():
return MockModelConfig()
@dataclass
class DummyTokenizer:
truncation_side: str = "left"
max_chars_per_token: int = 1
def __post_init__(self) -> None:
self._captured_encode_kwargs: dict = {}
def decode(self, tokens: list[int]):
return str(tokens)
def encode(self, text: str, **kwargs):
self._captured_encode_kwargs = kwargs
in_length = len(text)
truncation = kwargs.get("truncation")
max_length = kwargs.get("max_length")
if truncation and max_length is not None:
return list(range(min(in_length, max_length)))
return list(range(in_length))
@pytest.fixture
def mock_async_tokenizer():
return AsyncMock()
def _build_renderer(
model_config: MockModelConfig,
*,
truncation_side: str = "left",
max_chars_per_token: int = 1,
):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
@pytest.fixture
def renderer(mock_model_config):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(mock_model_config)
return HfRenderer(
mock_model_config,
renderer = HfRenderer(
model_config,
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
)
if not model_config.skip_tokenizer_init:
renderer._tokenizer = DummyTokenizer(
truncation_side=truncation_side,
max_chars_per_token=max_chars_per_token,
)
return renderer
class TestValidatePrompt:
STRING_INPUTS = [
@@ -81,39 +104,50 @@ class TestValidatePrompt:
]
# Test that a nested mixed-type list of lists raises a TypeError.
def test_empty_input(self, renderer):
def test_empty_input(self):
renderer = _build_renderer(MockModelConfig())
with pytest.raises(ValueError, match="at least one prompt"):
renderer.render_completions([])
def test_invalid_type(self, renderer):
def test_invalid_type(self):
renderer = _build_renderer(MockModelConfig())
with pytest.raises(TypeError, match="string or an array of tokens"):
renderer.render_completions([[1, 2], ["foo", "bar"]])
@pytest.mark.parametrize("string_input", STRING_INPUTS)
def test_string_consistent(self, renderer, string_input: str):
def test_string_consistent(self, string_input: str):
renderer = _build_renderer(MockModelConfig())
assert renderer.render_completions(string_input) == renderer.render_completions(
[string_input]
)
@pytest.mark.parametrize("token_input", TOKEN_INPUTS)
def test_token_consistent(self, renderer, token_input: list[int]):
def test_token_consistent(self, token_input: list[int]):
renderer = _build_renderer(MockModelConfig())
assert renderer.render_completions(token_input) == renderer.render_completions(
[token_input]
)
@pytest.mark.parametrize("inputs_slice", INPUTS_SLICES)
def test_string_slice(self, renderer, inputs_slice: slice):
def test_string_slice(self, inputs_slice: slice):
renderer = _build_renderer(MockModelConfig())
assert renderer.render_completions(self.STRING_INPUTS)[
inputs_slice
] == renderer.render_completions(self.STRING_INPUTS[inputs_slice])
class TestRenderPrompt:
@pytest.mark.asyncio
async def test_token_input(self, renderer):
def test_token_input(self):
renderer = _build_renderer(MockModelConfig())
tokens = [101, 7592, 2088]
prompts = await renderer.render_completions_async(tokens)
results = await renderer.tokenize_prompts_async(
prompts = renderer.render_completions(tokens)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100),
)
@@ -121,11 +155,12 @@ class TestRenderPrompt:
assert len(results) == 1
assert results[0]["prompt_token_ids"] == tokens
@pytest.mark.asyncio
async def test_token_list_input(self, renderer):
def test_token_list_input(self):
renderer = _build_renderer(MockModelConfig())
token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
prompts = await renderer.render_completions_async(token_lists)
results = await renderer.tokenize_prompts_async(
prompts = renderer.render_completions(token_lists)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100),
)
@@ -135,167 +170,178 @@ class TestRenderPrompt:
assert results[1]["prompt_token_ids"] == [102, 1234, 5678, 9012]
assert results[2]["prompt_token_ids"] == [103, 4567]
@pytest.mark.asyncio
async def test_text_input(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.encode.return_value = [101, 7592, 2088]
renderer._async_tokenizer = mock_async_tokenizer
def test_text_input(self):
renderer = _build_renderer(MockModelConfig())
prompts = await renderer.render_completions_async("Hello world")
results = await renderer.tokenize_prompts_async(
text_input = "x" * 10
prompts = renderer.render_completions(text_input)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100),
)
assert len(results) == 1
assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
mock_async_tokenizer.encode.assert_called_once()
assert len(results[0]["prompt_token_ids"]) == 10
@pytest.mark.asyncio
async def test_text_list_input(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.encode.return_value = [101, 7592, 2088]
renderer._async_tokenizer = mock_async_tokenizer
def test_text_list_input(self):
renderer = _build_renderer(MockModelConfig())
text_list_input = ["Hello world", "How are you?", "Good morning"]
prompts = await renderer.render_completions_async(text_list_input)
results = await renderer.tokenize_prompts_async(
text_list_input = ["x" * 10, "x" * 12, "x" * 14]
prompts = renderer.render_completions(text_list_input)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100),
)
assert len(results) == 3
for result in results:
assert result["prompt_token_ids"] == [101, 7592, 2088]
assert mock_async_tokenizer.encode.call_count == 3
for text_input, result in zip(text_list_input, results):
assert len(result["prompt_token_ids"]) == len(text_input)
@pytest.mark.asyncio
async def test_no_truncation(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.encode.return_value = [101, 7592, 2088]
renderer._async_tokenizer = mock_async_tokenizer
def test_zero_truncation(self):
renderer = _build_renderer(MockModelConfig())
prompts = await renderer.render_completions_async("Hello world")
results = await renderer.tokenize_prompts_async(
prompts = renderer.render_completions("x" * 200)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100),
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=0),
)
assert len(results) == 1
call_args = mock_async_tokenizer.encode.call_args
assert (
"truncation" not in call_args.kwargs
or call_args.kwargs["truncation"] is False
)
assert len(results[0]["prompt_token_ids"]) == 0
@pytest.mark.asyncio
async def test_truncation_positive(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.encode.return_value = [101, 7592, 2088] # Truncated
renderer._async_tokenizer = mock_async_tokenizer
def test_pos_truncation(self):
renderer = _build_renderer(MockModelConfig())
prompts = await renderer.render_completions_async("Hello world")
results = await renderer.tokenize_prompts_async(
prompts = renderer.render_completions("x" * 200)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(
max_total_tokens=200,
truncate_prompt_tokens=50,
),
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=50),
)
assert len(results) == 1
call_args = mock_async_tokenizer.encode.call_args
assert call_args.kwargs["truncation"] is True
assert call_args.kwargs["max_length"] == 50
assert len(results[0]["prompt_token_ids"]) == 50
@pytest.mark.asyncio
async def test_truncation_negative(self, renderer, mock_async_tokenizer):
# Test that negative truncation uses model's max_model_len
mock_async_tokenizer.encode.return_value = [
101,
7592,
2088,
] # Truncated to max_model_len
renderer._async_tokenizer = mock_async_tokenizer
def test_neg_truncation(self):
renderer = _build_renderer(MockModelConfig())
prompts = await renderer.render_completions_async("Hello world")
results = await renderer.tokenize_prompts_async(
prompts = renderer.render_completions("x" * 200)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(
max_total_tokens=200,
truncate_prompt_tokens=-1,
),
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=-1),
)
assert len(results) == 1
call_args = mock_async_tokenizer.encode.call_args
assert call_args.kwargs["truncation"] is True
assert call_args.kwargs["max_length"] == 200
assert len(results[0]["prompt_token_ids"]) == 100 # max_total_tokens
def test_truncation_left(self):
renderer = _build_renderer(MockModelConfig(), truncation_side="left")
@pytest.mark.asyncio
async def test_token_truncation_last_elements(self, renderer):
# Test that token truncation keeps the last N elements
long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens
prompts = await renderer.render_completions_async(long_tokens)
results = await renderer.tokenize_prompts_async(
prompts = renderer.render_completions(long_tokens)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(
max_total_tokens=100,
truncate_prompt_tokens=5,
),
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5),
)
assert len(results) == 1
# Should keep the last 5 tokens: [105, 106, 107, 108, 109]
assert results[0]["prompt_token_ids"] == [105, 106, 107, 108, 109]
@pytest.mark.asyncio
async def test_max_length_exceeded(self, renderer):
long_tokens = list(range(150)) # Exceeds max_model_len=100
def test_truncation_right(self):
renderer = _build_renderer(MockModelConfig(), truncation_side="right")
prompts = await renderer.render_completions_async(long_tokens)
with pytest.raises(ValueError, match="context length is only"):
await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=100),
)
@pytest.mark.asyncio
async def test_no_tokenizer_for_text(self, renderer):
renderer_no_tokenizer = HfRenderer.from_config(
MockModelConfig(skip_tokenizer_init=True),
tokenizer_kwargs={},
long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens
prompts = renderer.render_completions(long_tokens)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5),
)
prompts = await renderer_no_tokenizer.render_completions_async("Hello world")
assert len(results) == 1
# Should keep the first 5 tokens: [100, 101, 102, 103, 104]
assert results[0]["prompt_token_ids"] == [100, 101, 102, 103, 104]
with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"):
await renderer_no_tokenizer.tokenize_prompts_async(
def test_text_max_length_exceeded_obvious(self):
renderer = _build_renderer(MockModelConfig(), max_chars_per_token=1)
# Exceeds max_total_tokens and max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
long_tokens = "x" * 150
prompts = renderer.render_completions(long_tokens)
with pytest.raises(
ValueError,
match="input characters and requested .* context length is only",
):
renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100),
)
@pytest.mark.asyncio
async def test_token_input_with_needs_detokenization(
self, renderer, mock_async_tokenizer
):
# When needs_detokenization=True for token inputs, renderer should
# use the async tokenizer to decode and include the original text
# in the returned prompt object.
mock_async_tokenizer.decode = AsyncMock(return_value="decoded text")
renderer._async_tokenizer = mock_async_tokenizer
# Should not even attempt tokenization
assert renderer._tokenizer._captured_encode_kwargs == {}
def test_text_max_length_exceeded_nonobvious(self):
renderer = _build_renderer(MockModelConfig(), max_chars_per_token=2)
# Exceeds max_total_tokens but not max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
long_tokens = "x" * 150
prompts = renderer.render_completions(long_tokens)
with pytest.raises(
ValueError,
match="input tokens and requested .* context length is only",
):
renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100),
)
# Should only tokenize the first max_total_tokens + 1 tokens
assert renderer._tokenizer._captured_encode_kwargs["truncation"] is True
assert renderer._tokenizer._captured_encode_kwargs["max_length"] == 101
def test_token_max_length_exceeded(self):
renderer = _build_renderer(MockModelConfig())
long_tokens = list(range(150)) # Exceeds max_total_tokens=100
prompts = renderer.render_completions(long_tokens)
with pytest.raises(
ValueError,
match="input tokens and requested .* context length is only",
):
renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=None),
)
def test_no_tokenizer_for_text(self):
renderer = _build_renderer(MockModelConfig(skip_tokenizer_init=True))
prompts = renderer.render_completions("Hello world")
with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"):
renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100),
)
def test_token_input_with_needs_detokenization(self):
renderer = _build_renderer(MockModelConfig())
tokens = [1, 2, 3, 4]
prompts = await renderer.render_completions_async(tokens)
results = await renderer.tokenize_prompts_async(
prompts = renderer.render_completions(tokens)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(
max_total_tokens=renderer.config.max_model_len,
max_total_tokens=100,
needs_detokenization=True,
),
)
assert len(results) == 1
assert results[0]["prompt_token_ids"] == tokens
assert results[0]["prompt"] == "decoded text"
mock_async_tokenizer.decode.assert_awaited_once()
assert results[0]["prompt"] == "[1, 2, 3, 4]"
class TestRenderEmbedPrompt:
@@ -306,118 +352,121 @@ class TestRenderEmbedPrompt:
buffer.seek(0)
return pybase64.b64encode(buffer.read())
@pytest.mark.asyncio
async def test_single_prompt_embed(self, renderer):
# Create a test tensor
test_tensor = torch.randn(10, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor)
def test_single_prompt_embed(self):
renderer = _build_renderer(MockModelConfig())
prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes)
results = await renderer.tokenize_prompts_async(
# Create a test tensor
tensor_input = torch.randn(10, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(tensor_input)
prompts = renderer.render_completions(prompt_embeds=embed_bytes)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=renderer.config.max_model_len),
TokenizeParams(max_total_tokens=100),
)
assert len(results) == 1
assert torch.allclose(results[0]["prompt_embeds"], test_tensor)
assert torch.equal(results[0]["prompt_embeds"], tensor_input)
def test_multiple_prompt_embeds(self):
renderer = _build_renderer(MockModelConfig())
@pytest.mark.asyncio
async def test_multiple_prompt_embeds(self, renderer):
# Create multiple test tensors
test_tensors = [
tensor_inputs = [
torch.randn(8, 512, dtype=torch.float32),
torch.randn(12, 512, dtype=torch.float32),
]
embed_bytes_list = [self._create_test_embed_bytes(t) for t in test_tensors]
prompts = await renderer.render_completions_async(
prompt_embeds=embed_bytes_list
prompts = renderer.render_completions(
prompt_embeds=[self._create_test_embed_bytes(t) for t in tensor_inputs],
)
results = await renderer.tokenize_prompts_async(
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=renderer.config.max_model_len),
TokenizeParams(max_total_tokens=100),
)
assert len(results) == 2
for i, result in enumerate(results):
assert torch.allclose(result["prompt_embeds"], test_tensors[i])
assert torch.allclose(result["prompt_embeds"], tensor_inputs[i])
def test_prompt_embed_truncation(self):
renderer = _build_renderer(MockModelConfig())
@pytest.mark.asyncio
async def test_prompt_embed_truncation(self, renderer):
# Create tensor with more tokens than truncation limit
test_tensor = torch.randn(20, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor)
tensor_input = torch.randn(20, 768, dtype=torch.float32)
prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes)
results = await renderer.tokenize_prompts_async(
prompts = renderer.render_completions(
prompt_embeds=self._create_test_embed_bytes(tensor_input),
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(
max_total_tokens=renderer.config.max_model_len,
max_total_tokens=100,
truncate_prompt_tokens=10,
),
)
assert len(results) == 1
# Should keep last 10 tokens
expected = test_tensor[-10:]
assert torch.allclose(results[0]["prompt_embeds"], expected)
expected = tensor_input[-10:]
assert torch.equal(results[0]["prompt_embeds"], expected)
def test_prompt_embed_different_dtypes(self):
renderer = _build_renderer(MockModelConfig())
@pytest.mark.asyncio
async def test_prompt_embed_different_dtypes(self, renderer):
# Test different supported dtypes
dtypes = [torch.float32, torch.float16, torch.bfloat16]
for dtype in dtypes:
test_tensor = torch.randn(5, 256, dtype=dtype)
embed_bytes = self._create_test_embed_bytes(test_tensor)
tensor_input = torch.randn(5, 256, dtype=dtype)
prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes)
results = await renderer.tokenize_prompts_async(
prompts = renderer.render_completions(
prompt_embeds=self._create_test_embed_bytes(tensor_input),
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=renderer.config.max_model_len),
TokenizeParams(max_total_tokens=100),
)
assert len(results) == 1
assert results[0]["prompt_embeds"].dtype == dtype
@pytest.mark.asyncio
async def test_prompt_embed_squeeze_batch_dim(self, renderer):
# Test tensor with batch dimension gets squeezed
test_tensor = torch.randn(1, 10, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor)
def test_prompt_embed_squeeze_batch_dim(self):
renderer = _build_renderer(MockModelConfig())
prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes)
results = await renderer.tokenize_prompts_async(
# Test tensor with batch dimension gets squeezed
tensor_input = torch.randn(1, 10, 768, dtype=torch.float32)
prompts = renderer.render_completions(
prompt_embeds=self._create_test_embed_bytes(tensor_input),
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=renderer.config.max_model_len),
TokenizeParams(max_total_tokens=100),
)
assert len(results) == 1
# Should be squeezed to 2D
assert results[0]["prompt_embeds"].shape == (10, 768)
@pytest.mark.asyncio
async def test_both_prompts_and_embeds(self, renderer, mock_async_tokenizer):
# Set up text tokenization
mock_async_tokenizer.encode.return_value = [101, 102, 103]
renderer._async_tokenizer = mock_async_tokenizer
def test_both_prompts_and_embeds(self):
renderer = _build_renderer(MockModelConfig())
# Create embed
test_tensor = torch.randn(5, 256, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor)
text_input = "Hello world"
tensor_input = torch.randn(5, 256, dtype=torch.float32)
prompts = await renderer.render_completions_async(
"Hello world",
prompt_embeds=embed_bytes,
prompts = renderer.render_completions(
text_input,
prompt_embeds=self._create_test_embed_bytes(tensor_input),
)
results = await renderer.tokenize_prompts_async(
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=renderer.config.max_model_len),
TokenizeParams(max_total_tokens=100),
)
assert len(results) == 2
# First should be embed prompt
assert torch.allclose(results[0]["prompt_embeds"], test_tensor)
assert torch.equal(results[0]["prompt_embeds"], tensor_input)
# Second should be tokens prompt
assert "prompt_token_ids" in results[1]
assert results[1]["prompt_token_ids"] == [101, 102, 103]
assert len(results[1]["prompt_token_ids"]) == len(text_input)