feat(frontend): early-fail tokenization guard for user requests (#31366)
Signed-off-by: limingliang <limingliang@stepfun.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: limingliang <limingliang@stepfun.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
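This change makes the completions frontend reject obviously oversized text prompts before tokenizing them. Every tokenizer now exposes a max_chars_per_token property (the length of its longest vocabulary entry), and TokenizeParams gains a _text_len_check validator that multiplies it by the token budget to get a hard upper bound on acceptable input length in characters; anything longer raises VLLMValidationError without encode() ever being called. When the character count alone cannot prove the prompt is too long, tokenization still runs, but capped at max_input_tokens + 1 tokens so the existing token-level length check can fail the request cheaply. A worked example of the character bound (hypothetical numbers, not taken from the diff):

    # Suppose the budget allows max_input_tokens = 100 and the longest
    # vocabulary entry is 16 characters. No tokenization of `text` can
    # produce fewer than len(text) / max_chars_per_token tokens, so:
    max_input_tokens = 100
    max_chars_per_token = 16
    max_input_chars = max_input_tokens * max_chars_per_token  # 1600
    # A 2000-character prompt can be rejected immediately: even in the
    # best case it needs ceil(2000 / 16) = 125 > 100 tokens.
    assert 2000 > max_input_chars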
@@ -4,7 +4,6 @@
 import io
 from dataclasses import dataclass
 from typing import Any
-from unittest.mock import AsyncMock
 
 import pybase64
 import pytest
@@ -28,7 +27,6 @@ class MockModelConfig:
     model: str = MODEL_NAME
     tokenizer: str = MODEL_NAME
     trust_remote_code: bool = False
-    max_model_len: int = 100
     tokenizer_revision = None
     tokenizer_mode = "auto"
     hf_config = MockHFConfig()
@@ -37,25 +35,50 @@ class MockModelConfig:
     skip_tokenizer_init: bool = False
 
 
-@pytest.fixture
-def mock_model_config():
-    return MockModelConfig()
+@dataclass
+class DummyTokenizer:
+    truncation_side: str = "left"
+    max_chars_per_token: int = 1
+
+    def __post_init__(self) -> None:
+        self._captured_encode_kwargs: dict = {}
+
+    def decode(self, tokens: list[int]):
+        return str(tokens)
+
+    def encode(self, text: str, **kwargs):
+        self._captured_encode_kwargs = kwargs
+
+        in_length = len(text)
+        truncation = kwargs.get("truncation")
+        max_length = kwargs.get("max_length")
+        if truncation and max_length is not None:
+            return list(range(min(in_length, max_length)))
+
+        return list(range(in_length))
 
 
-@pytest.fixture
-def mock_async_tokenizer():
-    return AsyncMock()
+def _build_renderer(
+    model_config: MockModelConfig,
+    *,
+    truncation_side: str = "left",
+    max_chars_per_token: int = 1,
+):
+    _, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
 
-
-@pytest.fixture
-def renderer(mock_model_config):
-    _, tokenizer_name, _, kwargs = tokenizer_args_from_config(mock_model_config)
-
-    return HfRenderer(
-        mock_model_config,
+    renderer = HfRenderer(
+        model_config,
         tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
     )
 
+    if not model_config.skip_tokenizer_init:
+        renderer._tokenizer = DummyTokenizer(
+            truncation_side=truncation_side,
+            max_chars_per_token=max_chars_per_token,
+        )
+
+    return renderer
 
 
 class TestValidatePrompt:
     STRING_INPUTS = [
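The DummyTokenizer above replaces the AsyncMock-based fixtures: encode() maps each character to one token id (optionally truncated) and records the kwargs it was called with, so token counts equal character counts and the length assertions in the tests below are deterministic. Roughly, under the definition above:

    # Assumed behavior, derived from the DummyTokenizer definition above.
    tok = DummyTokenizer()
    assert tok.encode("abcde") == [0, 1, 2, 3, 4]  # one token per character
    assert tok.encode("abcde", truncation=True, max_length=3) == [0, 1, 2]
    assert tok._captured_encode_kwargs == {"truncation": True, "max_length": 3}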
@@ -81,39 +104,50 @@ class TestValidatePrompt:
     ]
 
     # Test that a nested mixed-type list of lists raises a TypeError.
-    def test_empty_input(self, renderer):
+    def test_empty_input(self):
+        renderer = _build_renderer(MockModelConfig())
+
         with pytest.raises(ValueError, match="at least one prompt"):
             renderer.render_completions([])
 
-    def test_invalid_type(self, renderer):
+    def test_invalid_type(self):
+        renderer = _build_renderer(MockModelConfig())
+
         with pytest.raises(TypeError, match="string or an array of tokens"):
             renderer.render_completions([[1, 2], ["foo", "bar"]])
 
     @pytest.mark.parametrize("string_input", STRING_INPUTS)
-    def test_string_consistent(self, renderer, string_input: str):
+    def test_string_consistent(self, string_input: str):
+        renderer = _build_renderer(MockModelConfig())
+
         assert renderer.render_completions(string_input) == renderer.render_completions(
             [string_input]
         )
 
     @pytest.mark.parametrize("token_input", TOKEN_INPUTS)
-    def test_token_consistent(self, renderer, token_input: list[int]):
+    def test_token_consistent(self, token_input: list[int]):
+        renderer = _build_renderer(MockModelConfig())
+
         assert renderer.render_completions(token_input) == renderer.render_completions(
             [token_input]
         )
 
     @pytest.mark.parametrize("inputs_slice", INPUTS_SLICES)
-    def test_string_slice(self, renderer, inputs_slice: slice):
+    def test_string_slice(self, inputs_slice: slice):
+        renderer = _build_renderer(MockModelConfig())
+
         assert renderer.render_completions(self.STRING_INPUTS)[
             inputs_slice
         ] == renderer.render_completions(self.STRING_INPUTS[inputs_slice])
 
 
 class TestRenderPrompt:
-    @pytest.mark.asyncio
-    async def test_token_input(self, renderer):
+    def test_token_input(self):
+        renderer = _build_renderer(MockModelConfig())
+
         tokens = [101, 7592, 2088]
-        prompts = await renderer.render_completions_async(tokens)
-        results = await renderer.tokenize_prompts_async(
+        prompts = renderer.render_completions(tokens)
+        results = renderer.tokenize_prompts(
             prompts,
             TokenizeParams(max_total_tokens=100),
         )
@@ -121,11 +155,12 @@ class TestRenderPrompt:
         assert len(results) == 1
         assert results[0]["prompt_token_ids"] == tokens
 
-    @pytest.mark.asyncio
-    async def test_token_list_input(self, renderer):
+    def test_token_list_input(self):
+        renderer = _build_renderer(MockModelConfig())
+
         token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
-        prompts = await renderer.render_completions_async(token_lists)
-        results = await renderer.tokenize_prompts_async(
+        prompts = renderer.render_completions(token_lists)
+        results = renderer.tokenize_prompts(
             prompts,
             TokenizeParams(max_total_tokens=100),
         )
@@ -135,167 +170,178 @@ class TestRenderPrompt:
         assert results[1]["prompt_token_ids"] == [102, 1234, 5678, 9012]
         assert results[2]["prompt_token_ids"] == [103, 4567]
 
-    @pytest.mark.asyncio
-    async def test_text_input(self, renderer, mock_async_tokenizer):
-        mock_async_tokenizer.encode.return_value = [101, 7592, 2088]
-        renderer._async_tokenizer = mock_async_tokenizer
+    def test_text_input(self):
+        renderer = _build_renderer(MockModelConfig())
 
-        prompts = await renderer.render_completions_async("Hello world")
-        results = await renderer.tokenize_prompts_async(
+        text_input = "x" * 10
+        prompts = renderer.render_completions(text_input)
+        results = renderer.tokenize_prompts(
             prompts,
             TokenizeParams(max_total_tokens=100),
         )
 
         assert len(results) == 1
-        assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
-        mock_async_tokenizer.encode.assert_called_once()
+        assert len(results[0]["prompt_token_ids"]) == 10
 
-    @pytest.mark.asyncio
-    async def test_text_list_input(self, renderer, mock_async_tokenizer):
-        mock_async_tokenizer.encode.return_value = [101, 7592, 2088]
-        renderer._async_tokenizer = mock_async_tokenizer
+    def test_text_list_input(self):
+        renderer = _build_renderer(MockModelConfig())
 
-        text_list_input = ["Hello world", "How are you?", "Good morning"]
-        prompts = await renderer.render_completions_async(text_list_input)
-        results = await renderer.tokenize_prompts_async(
+        text_list_input = ["x" * 10, "x" * 12, "x" * 14]
+        prompts = renderer.render_completions(text_list_input)
+        results = renderer.tokenize_prompts(
             prompts,
             TokenizeParams(max_total_tokens=100),
         )
 
         assert len(results) == 3
-        for result in results:
-            assert result["prompt_token_ids"] == [101, 7592, 2088]
-        assert mock_async_tokenizer.encode.call_count == 3
+        for text_input, result in zip(text_list_input, results):
+            assert len(result["prompt_token_ids"]) == len(text_input)
 
-    @pytest.mark.asyncio
-    async def test_no_truncation(self, renderer, mock_async_tokenizer):
-        mock_async_tokenizer.encode.return_value = [101, 7592, 2088]
-        renderer._async_tokenizer = mock_async_tokenizer
+    def test_zero_truncation(self):
+        renderer = _build_renderer(MockModelConfig())
 
-        prompts = await renderer.render_completions_async("Hello world")
-        results = await renderer.tokenize_prompts_async(
+        prompts = renderer.render_completions("x" * 200)
+        results = renderer.tokenize_prompts(
             prompts,
-            TokenizeParams(max_total_tokens=100),
+            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=0),
         )
 
         assert len(results) == 1
-        call_args = mock_async_tokenizer.encode.call_args
-        assert (
-            "truncation" not in call_args.kwargs
-            or call_args.kwargs["truncation"] is False
-        )
+        assert len(results[0]["prompt_token_ids"]) == 0
 
-    @pytest.mark.asyncio
-    async def test_truncation_positive(self, renderer, mock_async_tokenizer):
-        mock_async_tokenizer.encode.return_value = [101, 7592, 2088]  # Truncated
-        renderer._async_tokenizer = mock_async_tokenizer
+    def test_pos_truncation(self):
+        renderer = _build_renderer(MockModelConfig())
 
-        prompts = await renderer.render_completions_async("Hello world")
-        results = await renderer.tokenize_prompts_async(
+        prompts = renderer.render_completions("x" * 200)
+        results = renderer.tokenize_prompts(
             prompts,
-            TokenizeParams(
-                max_total_tokens=200,
-                truncate_prompt_tokens=50,
-            ),
+            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=50),
         )
 
         assert len(results) == 1
-        call_args = mock_async_tokenizer.encode.call_args
-        assert call_args.kwargs["truncation"] is True
-        assert call_args.kwargs["max_length"] == 50
+        assert len(results[0]["prompt_token_ids"]) == 50
 
-    @pytest.mark.asyncio
-    async def test_truncation_negative(self, renderer, mock_async_tokenizer):
-        # Test that negative truncation uses model's max_model_len
-        mock_async_tokenizer.encode.return_value = [
-            101,
-            7592,
-            2088,
-        ]  # Truncated to max_model_len
-        renderer._async_tokenizer = mock_async_tokenizer
+    def test_neg_truncation(self):
+        renderer = _build_renderer(MockModelConfig())
 
-        prompts = await renderer.render_completions_async("Hello world")
-        results = await renderer.tokenize_prompts_async(
+        prompts = renderer.render_completions("x" * 200)
+        results = renderer.tokenize_prompts(
             prompts,
-            TokenizeParams(
-                max_total_tokens=200,
-                truncate_prompt_tokens=-1,
-            ),
+            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=-1),
         )
 
         assert len(results) == 1
-        call_args = mock_async_tokenizer.encode.call_args
-        assert call_args.kwargs["truncation"] is True
-        assert call_args.kwargs["max_length"] == 200
+        assert len(results[0]["prompt_token_ids"]) == 100  # max_total_tokens
 
-    @pytest.mark.asyncio
-    async def test_token_truncation_last_elements(self, renderer):
-        # Test that token truncation keeps the last N elements
+    def test_truncation_left(self):
+        renderer = _build_renderer(MockModelConfig(), truncation_side="left")
+
         long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]  # 10 tokens
-        prompts = await renderer.render_completions_async(long_tokens)
-        results = await renderer.tokenize_prompts_async(
+        prompts = renderer.render_completions(long_tokens)
+        results = renderer.tokenize_prompts(
             prompts,
-            TokenizeParams(
-                max_total_tokens=100,
-                truncate_prompt_tokens=5,
-            ),
+            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5),
         )
 
         assert len(results) == 1
         # Should keep the last 5 tokens: [105, 106, 107, 108, 109]
         assert results[0]["prompt_token_ids"] == [105, 106, 107, 108, 109]
 
-    @pytest.mark.asyncio
-    async def test_max_length_exceeded(self, renderer):
-        long_tokens = list(range(150))  # Exceeds max_model_len=100
-
-        prompts = await renderer.render_completions_async(long_tokens)
-
-        with pytest.raises(ValueError, match="context length is only"):
-            await renderer.tokenize_prompts_async(
-                prompts,
-                TokenizeParams(max_total_tokens=100),
-            )
+    def test_truncation_right(self):
+        renderer = _build_renderer(MockModelConfig(), truncation_side="right")
+
+        long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]  # 10 tokens
+        prompts = renderer.render_completions(long_tokens)
+        results = renderer.tokenize_prompts(
+            prompts,
+            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5),
+        )
+
+        assert len(results) == 1
+        # Should keep the first 5 tokens: [100, 101, 102, 103, 104]
+        assert results[0]["prompt_token_ids"] == [100, 101, 102, 103, 104]
 
-    @pytest.mark.asyncio
-    async def test_no_tokenizer_for_text(self, renderer):
-        renderer_no_tokenizer = HfRenderer.from_config(
-            MockModelConfig(skip_tokenizer_init=True),
-            tokenizer_kwargs={},
-        )
-
-        prompts = await renderer_no_tokenizer.render_completions_async("Hello world")
-
-        with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"):
-            await renderer_no_tokenizer.tokenize_prompts_async(
+    def test_text_max_length_exceeded_obvious(self):
+        renderer = _build_renderer(MockModelConfig(), max_chars_per_token=1)
+
+        # Exceeds max_total_tokens and max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
+        long_tokens = "x" * 150
+        prompts = renderer.render_completions(long_tokens)
+
+        with pytest.raises(
+            ValueError,
+            match="input characters and requested .* context length is only",
+        ):
+            renderer.tokenize_prompts(
                 prompts,
                 TokenizeParams(max_total_tokens=100),
             )
 
-    @pytest.mark.asyncio
-    async def test_token_input_with_needs_detokenization(
-        self, renderer, mock_async_tokenizer
-    ):
-        # When needs_detokenization=True for token inputs, renderer should
-        # use the async tokenizer to decode and include the original text
-        # in the returned prompt object.
-        mock_async_tokenizer.decode = AsyncMock(return_value="decoded text")
-        renderer._async_tokenizer = mock_async_tokenizer
+        # Should not even attempt tokenization
+        assert renderer._tokenizer._captured_encode_kwargs == {}
+
+    def test_text_max_length_exceeded_nonobvious(self):
+        renderer = _build_renderer(MockModelConfig(), max_chars_per_token=2)
+
+        # Exceeds max_total_tokens but not max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
+        long_tokens = "x" * 150
+        prompts = renderer.render_completions(long_tokens)
+
+        with pytest.raises(
+            ValueError,
+            match="input tokens and requested .* context length is only",
+        ):
+            renderer.tokenize_prompts(
+                prompts,
+                TokenizeParams(max_total_tokens=100),
+            )
+
+        # Should only tokenize the first max_total_tokens + 1 tokens
+        assert renderer._tokenizer._captured_encode_kwargs["truncation"] is True
+        assert renderer._tokenizer._captured_encode_kwargs["max_length"] == 101
+
+    def test_token_max_length_exceeded(self):
+        renderer = _build_renderer(MockModelConfig())
+
+        long_tokens = list(range(150))  # Exceeds max_total_tokens=100
+        prompts = renderer.render_completions(long_tokens)
+
+        with pytest.raises(
+            ValueError,
+            match="input tokens and requested .* context length is only",
+        ):
+            renderer.tokenize_prompts(
+                prompts,
+                TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=None),
+            )
+
+    def test_no_tokenizer_for_text(self):
+        renderer = _build_renderer(MockModelConfig(skip_tokenizer_init=True))
+
+        prompts = renderer.render_completions("Hello world")
+
+        with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"):
+            renderer.tokenize_prompts(
+                prompts,
+                TokenizeParams(max_total_tokens=100),
+            )
+
+    def test_token_input_with_needs_detokenization(self):
+        renderer = _build_renderer(MockModelConfig())
 
         tokens = [1, 2, 3, 4]
-        prompts = await renderer.render_completions_async(tokens)
-        results = await renderer.tokenize_prompts_async(
+        prompts = renderer.render_completions(tokens)
+        results = renderer.tokenize_prompts(
             prompts,
             TokenizeParams(
-                max_total_tokens=renderer.config.max_model_len,
+                max_total_tokens=100,
                 needs_detokenization=True,
             ),
         )
 
         assert len(results) == 1
         assert results[0]["prompt_token_ids"] == tokens
-        assert results[0]["prompt"] == "decoded text"
-        mock_async_tokenizer.decode.assert_awaited_once()
+        assert results[0]["prompt"] == "[1, 2, 3, 4]"
 
 
 class TestRenderEmbedPrompt:
@@ -306,118 +352,121 @@ class TestRenderEmbedPrompt:
         buffer.seek(0)
         return pybase64.b64encode(buffer.read())
 
-    @pytest.mark.asyncio
-    async def test_single_prompt_embed(self, renderer):
-        # Create a test tensor
-        test_tensor = torch.randn(10, 768, dtype=torch.float32)
-        embed_bytes = self._create_test_embed_bytes(test_tensor)
+    def test_single_prompt_embed(self):
+        renderer = _build_renderer(MockModelConfig())
 
-        prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes)
-        results = await renderer.tokenize_prompts_async(
+        # Create a test tensor
+        tensor_input = torch.randn(10, 768, dtype=torch.float32)
+        embed_bytes = self._create_test_embed_bytes(tensor_input)
+
+        prompts = renderer.render_completions(prompt_embeds=embed_bytes)
+        results = renderer.tokenize_prompts(
             prompts,
-            TokenizeParams(max_total_tokens=renderer.config.max_model_len),
+            TokenizeParams(max_total_tokens=100),
         )
 
         assert len(results) == 1
-        assert torch.allclose(results[0]["prompt_embeds"], test_tensor)
+        assert torch.equal(results[0]["prompt_embeds"], tensor_input)
 
-    @pytest.mark.asyncio
-    async def test_multiple_prompt_embeds(self, renderer):
+    def test_multiple_prompt_embeds(self):
+        renderer = _build_renderer(MockModelConfig())
+
         # Create multiple test tensors
-        test_tensors = [
+        tensor_inputs = [
             torch.randn(8, 512, dtype=torch.float32),
             torch.randn(12, 512, dtype=torch.float32),
         ]
-        embed_bytes_list = [self._create_test_embed_bytes(t) for t in test_tensors]
 
-        prompts = await renderer.render_completions_async(
-            prompt_embeds=embed_bytes_list
+        prompts = renderer.render_completions(
+            prompt_embeds=[self._create_test_embed_bytes(t) for t in tensor_inputs],
         )
-        results = await renderer.tokenize_prompts_async(
+        results = renderer.tokenize_prompts(
             prompts,
-            TokenizeParams(max_total_tokens=renderer.config.max_model_len),
+            TokenizeParams(max_total_tokens=100),
         )
 
         assert len(results) == 2
         for i, result in enumerate(results):
-            assert torch.allclose(result["prompt_embeds"], test_tensors[i])
+            assert torch.allclose(result["prompt_embeds"], tensor_inputs[i])
 
-    @pytest.mark.asyncio
-    async def test_prompt_embed_truncation(self, renderer):
+    def test_prompt_embed_truncation(self):
+        renderer = _build_renderer(MockModelConfig())
+
         # Create tensor with more tokens than truncation limit
-        test_tensor = torch.randn(20, 768, dtype=torch.float32)
-        embed_bytes = self._create_test_embed_bytes(test_tensor)
+        tensor_input = torch.randn(20, 768, dtype=torch.float32)
 
-        prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes)
-        results = await renderer.tokenize_prompts_async(
+        prompts = renderer.render_completions(
+            prompt_embeds=self._create_test_embed_bytes(tensor_input),
+        )
+        results = renderer.tokenize_prompts(
             prompts,
             TokenizeParams(
-                max_total_tokens=renderer.config.max_model_len,
+                max_total_tokens=100,
                 truncate_prompt_tokens=10,
             ),
         )
 
         assert len(results) == 1
         # Should keep last 10 tokens
-        expected = test_tensor[-10:]
-        assert torch.allclose(results[0]["prompt_embeds"], expected)
+        expected = tensor_input[-10:]
+        assert torch.equal(results[0]["prompt_embeds"], expected)
 
-    @pytest.mark.asyncio
-    async def test_prompt_embed_different_dtypes(self, renderer):
+    def test_prompt_embed_different_dtypes(self):
+        renderer = _build_renderer(MockModelConfig())
+
         # Test different supported dtypes
         dtypes = [torch.float32, torch.float16, torch.bfloat16]
 
         for dtype in dtypes:
-            test_tensor = torch.randn(5, 256, dtype=dtype)
-            embed_bytes = self._create_test_embed_bytes(test_tensor)
+            tensor_input = torch.randn(5, 256, dtype=dtype)
 
-            prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes)
-            results = await renderer.tokenize_prompts_async(
+            prompts = renderer.render_completions(
+                prompt_embeds=self._create_test_embed_bytes(tensor_input),
+            )
+            results = renderer.tokenize_prompts(
                 prompts,
-                TokenizeParams(max_total_tokens=renderer.config.max_model_len),
+                TokenizeParams(max_total_tokens=100),
             )
 
             assert len(results) == 1
             assert results[0]["prompt_embeds"].dtype == dtype
 
-    @pytest.mark.asyncio
-    async def test_prompt_embed_squeeze_batch_dim(self, renderer):
-        # Test tensor with batch dimension gets squeezed
-        test_tensor = torch.randn(1, 10, 768, dtype=torch.float32)
-        embed_bytes = self._create_test_embed_bytes(test_tensor)
+    def test_prompt_embed_squeeze_batch_dim(self):
+        renderer = _build_renderer(MockModelConfig())
 
-        prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes)
-        results = await renderer.tokenize_prompts_async(
+        # Test tensor with batch dimension gets squeezed
+        tensor_input = torch.randn(1, 10, 768, dtype=torch.float32)
+
+        prompts = renderer.render_completions(
+            prompt_embeds=self._create_test_embed_bytes(tensor_input),
+        )
+        results = renderer.tokenize_prompts(
             prompts,
-            TokenizeParams(max_total_tokens=renderer.config.max_model_len),
+            TokenizeParams(max_total_tokens=100),
         )
 
         assert len(results) == 1
         # Should be squeezed to 2D
         assert results[0]["prompt_embeds"].shape == (10, 768)
 
-    @pytest.mark.asyncio
-    async def test_both_prompts_and_embeds(self, renderer, mock_async_tokenizer):
-        # Set up text tokenization
-        mock_async_tokenizer.encode.return_value = [101, 102, 103]
-        renderer._async_tokenizer = mock_async_tokenizer
+    def test_both_prompts_and_embeds(self):
+        renderer = _build_renderer(MockModelConfig())
 
-        # Create embed
-        test_tensor = torch.randn(5, 256, dtype=torch.float32)
-        embed_bytes = self._create_test_embed_bytes(test_tensor)
+        text_input = "Hello world"
+        tensor_input = torch.randn(5, 256, dtype=torch.float32)
 
-        prompts = await renderer.render_completions_async(
-            "Hello world",
-            prompt_embeds=embed_bytes,
+        prompts = renderer.render_completions(
+            text_input,
+            prompt_embeds=self._create_test_embed_bytes(tensor_input),
         )
-        results = await renderer.tokenize_prompts_async(
+        results = renderer.tokenize_prompts(
             prompts,
-            TokenizeParams(max_total_tokens=renderer.config.max_model_len),
+            TokenizeParams(max_total_tokens=100),
         )
 
         assert len(results) == 2
         # First should be embed prompt
-        assert torch.allclose(results[0]["prompt_embeds"], test_tensor)
+        assert torch.equal(results[0]["prompt_embeds"], tensor_input)
         # Second should be tokens prompt
         assert "prompt_token_ids" in results[1]
-        assert results[1]["prompt_token_ids"] == [101, 102, 103]
+        assert len(results[1]["prompt_token_ids"]) == len(text_input)
@@ -229,23 +229,53 @@ class TokenizeParams:
         max_length = self.truncate_prompt_tokens
         if max_length is not None and max_length < 0:
             max_length = self.max_input_tokens
+        elif max_length is None and self.max_input_tokens is not None:
+            # This prevents tokenization from taking up more resources than necessary
+            # while still failing `self._token_len_check` as expected by users
+            max_length = self.max_input_tokens + 1
 
         return dict(
-            truncation=self.truncate_prompt_tokens is not None,
+            truncation=max_length is not None,
            max_length=max_length,
            add_special_tokens=self.add_special_tokens,
        )
 
-    def _apply_lowercase(self, tokenizer: TokenizerLike | None, text: str) -> str:
-        if self.do_lower_case:
-            text = text.lower()
+    def _text_len_check(self, tokenizer: TokenizerLike | None, text: str) -> str:
+        """Apply length checks to prompt text if necessary."""
+        max_input_tokens = self.max_input_tokens
+        if max_input_tokens is None:
+            return text
+
+        if self.truncate_prompt_tokens is None and tokenizer is not None:
+            max_input_chars = max_input_tokens * tokenizer.max_chars_per_token
+
+            if len(text) > max_input_chars:
+                # To save resources, fail the request outright without even
+                # attempting tokenization
+                raise VLLMValidationError(
+                    f"You passed {len(text)} input characters "
+                    f"and requested {self.max_output_tokens} output tokens. "
+                    f"However, the model's context length is only "
+                    f"{self.max_total_tokens} tokens, resulting in a maximum "
+                    f"input length of {max_input_tokens} tokens "
+                    f"(at most {max_input_chars} characters). "
+                    f"Please reduce the length of the input prompt.",
+                    parameter="input_text",
+                    value=len(text),
+                )
 
         return text
 
+    def _text_lowercase(self, tokenizer: TokenizerLike | None, text: str) -> str:
+        """Apply lowercase to prompt text if necessary."""
+        return text.lower() if self.do_lower_case else text
+
     def _validate_text(self, tokenizer: TokenizerLike | None, text: str) -> str:
         """Apply all validators to prompt text."""
-        # TODO: Implement https://github.com/vllm-project/vllm/pull/31366
-        for validator in (self._apply_lowercase,):
+        for validator in (
+            self._text_len_check,
+            self._text_lowercase,
+        ):
             text = validator(tokenizer, text)
 
         return text
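The max_input_tokens + 1 cap in the encode kwargs above deserves a note: when the character-count guard cannot rule the prompt out (its length fits within max_input_tokens * max_chars_per_token), the text is tokenized anyway, but encode() is asked for at most one token beyond the budget. One extra token is enough for _token_len_check to detect the overflow, while the tokenizer never wastes time on the rest of a huge input. With the numbers from test_text_max_length_exceeded_nonobvious above:

    # max_input_tokens = 100, max_chars_per_token = 2, prompt = "x" * 150:
    # 150 <= 100 * 2, so the request cannot be failed on character count alone.
    # encode() is then called with truncation=True, max_length=101; the 101
    # tokens it returns exceed the 100-token budget, so _token_len_check raises.
    max_input_tokens = 100
    max_length = max_input_tokens + 1     # 101
    num_tokens = min(150, max_length)     # the DummyTokenizer: 1 token per char
    assert num_tokens > max_input_tokens  # -> VLLMValidationError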
@@ -265,8 +295,8 @@ class TokenizeParams:
         return prompt
 
-    def _apply_padding(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
-        """Apply padding to a token sequence."""
+    def _token_padding(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
+        """Apply padding to prompt tokens if necessary."""
         pad_length = self.pad_prompt_tokens
         if pad_length is not None and pad_length < 0:
             pad_length = self.max_input_tokens
@@ -281,8 +311,8 @@ class TokenizeParams:
         return tokens + [tokenizer.pad_token_id] * (pad_length - len(tokens))
 
-    def _apply_truncation(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
-        """Apply truncation to a token sequence."""
+    def _token_truncation(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
+        """Apply truncation to prompt tokens if necessary."""
         max_length = self.truncate_prompt_tokens
         if max_length is not None and max_length < 0:
             max_length = self.max_input_tokens
@@ -297,18 +327,20 @@ class TokenizeParams:
         return tokens[:max_length]
 
-    def _apply_length_check(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
-        """Apply length checks to a token sequence."""
+    def _token_len_check(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
+        """Apply length checks to prompt tokens if necessary."""
         max_input_tokens = self.max_input_tokens
+        if max_input_tokens is None:
+            return tokens
 
-        if max_input_tokens is not None and len(tokens) > max_input_tokens:
+        if len(tokens) > max_input_tokens:
             raise VLLMValidationError(
-                f"You passed {len(tokens)} input tokens and "
-                f"requested {self.max_output_tokens} output tokens. "
+                f"You passed {len(tokens)} input tokens "
+                f"and requested {self.max_output_tokens} output tokens. "
                 f"However, the model's context length is only "
-                f"{self.max_total_tokens}, resulting in a maximum "
-                f"input length of {max_input_tokens}. "
-                f"Please reduce the length of the input messages.",
+                f"{self.max_total_tokens} tokens, resulting in a maximum "
+                f"input length of {max_input_tokens} tokens. "
+                f"Please reduce the length of the input prompt.",
                 parameter="input_tokens",
                 value=len(tokens),
             )
@@ -318,9 +350,9 @@ class TokenizeParams:
     def _validate_tokens(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
         """Apply all validators to a token sequence."""
         for validator in (
-            self._apply_padding,
-            self._apply_truncation,
-            self._apply_length_check,
+            self._token_padding,
+            self._token_truncation,
+            self._token_len_check,
         ):
             tokens = validator(tokenizer, tokens)
 
@@ -115,6 +115,10 @@ class DeepseekV32Tokenizer(CachedHfTokenizer):
     def max_token_id(self) -> int:
         return self.tokenizer.max_token_id
 
+    @property
+    def max_chars_per_token(self) -> int:
+        return self.tokenizer.max_chars_per_token
+
     @property
     def truncation_side(self) -> str:
         return self.tokenizer.truncation_side
@@ -277,6 +277,8 @@ class Grok2Tokenizer(TokenizerLike):
         self._pad_token_id = self._special_tokens.get(PAD, self._eos_token_id)
         self._unk_token_id = self._pad_token_id
 
+        self._max_chars_per_token = max(len(tok) for tok in self._token_to_id)
+
     def num_special_tokens_to_add(self) -> int:
         return 0
@@ -312,6 +314,10 @@ class Grok2Tokenizer(TokenizerLike):
     def max_token_id(self) -> int:
         return self._tokenizer.n_vocab - 1
 
+    @property
+    def max_chars_per_token(self) -> int:
+        return self._max_chars_per_token
+
     @property
     def truncation_side(self) -> str:
         return self._truncation_side
@@ -28,6 +28,8 @@ def get_cached_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer:
     tokenizer_len = len(tokenizer)
 
     max_token_id = max(tokenizer_vocab.values())
+    max_chars_per_token = max(len(tok) for tok in tokenizer_vocab)
+
     # Some tokenizers (e.g., QwenTokenizer) have special tokens that
     # are added and included in the implementation of the vocab_size
     # property, but not in get_vocab(); if there is an implementation
@@ -49,6 +51,10 @@ def get_cached_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer:
         def max_token_id(self) -> int:
             return max_token_id
 
+        @property
+        def max_chars_per_token(self) -> int:
+            return max_chars_per_token
+
         def get_vocab(self) -> dict[str, int]:
             return tokenizer_vocab
@@ -272,6 +272,7 @@ class MistralTokenizer(TokenizerLike):
         # Vocab sorted by token id.
         self._vocab = self.tokenizer.vocab()
         self._max_token_id = self.vocab_size - 1
+        self._max_chars_per_token = max(len(tok) for tok in self._vocab)
 
         # Cache special tokens for faster access.
         self._special_token_ids = self._get_special_token_ids()
@@ -325,6 +326,10 @@ class MistralTokenizer(TokenizerLike):
     def max_token_id(self) -> int:
         return self._max_token_id
 
+    @property
+    def max_chars_per_token(self) -> int:
+        return self._max_chars_per_token
+
     @property
     def truncation_side(self) -> str:
         return self.transformers_tokenizer.truncation_side
@@ -57,6 +57,10 @@ class TokenizerLike(Protocol):
    def max_token_id(self) -> int:
        raise NotImplementedError

+    @property
+    def max_chars_per_token(self) -> int:
+        raise NotImplementedError
+
    @property
    def truncation_side(self) -> str:
        raise NotImplementedError
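Each tokenizer backend derives the new property the same way: max_chars_per_token is the length of the longest string in its vocabulary, computed once at construction (Grok2Tokenizer, MistralTokenizer) or when the HF tokenizer is wrapped (get_cached_tokenizer). A toy restatement of the invariant the early-fail guard relies on:

    # With a toy vocab, no encoding of `text` can use fewer than
    # ceil(len(text) / max_chars_per_token) tokens.
    vocab = {"a": 0, "ab": 1, "abcd": 2}
    max_chars_per_token = max(len(tok) for tok in vocab)  # 4
    text = "abcdabcdab"  # 10 characters -> at least 3 tokens under this vocab
    assert len(text) <= 3 * max_chars_per_token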