[Front-end] microbatch tokenization (#19334)

Signed-off-by: zt2370 <ztang2370@gmail.com>
This commit is contained in:
ztang2370
2025-07-08 00:54:10 +08:00
committed by GitHub
parent edd270bc78
commit a37d75bbec
3 changed files with 288 additions and 64 deletions

View File

@@ -7,6 +7,8 @@ from dataclasses import dataclass, field
from typing import Any, Optional
from unittest.mock import MagicMock
import pytest
from vllm.config import MultiModalConfig
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
@@ -73,7 +75,8 @@ def test_async_serving_chat_init():
assert serving_completion.chat_template == CHAT_TEMPLATE
def test_serving_chat_should_set_correct_max_tokens():
@pytest.mark.asyncio
async def test_serving_chat_should_set_correct_max_tokens():
mock_engine = MagicMock(spec=MQLLMEngineClient)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
@@ -88,6 +91,7 @@ def test_serving_chat_should_set_correct_max_tokens():
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)
req = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{
@@ -98,13 +102,13 @@ def test_serving_chat_should_set_correct_max_tokens():
)
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 93
req.max_tokens = 10
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 10
@@ -143,7 +147,7 @@ def test_serving_chat_should_set_correct_max_tokens():
)
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 10
@@ -151,7 +155,7 @@ def test_serving_chat_should_set_correct_max_tokens():
req.max_tokens = 15
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 10
@@ -159,7 +163,7 @@ def test_serving_chat_should_set_correct_max_tokens():
req.max_tokens = 5
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 5
@@ -198,7 +202,7 @@ def test_serving_chat_should_set_correct_max_tokens():
)
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 93
@@ -206,7 +210,7 @@ def test_serving_chat_should_set_correct_max_tokens():
req.max_tokens = 100
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 93
@@ -214,12 +218,13 @@ def test_serving_chat_should_set_correct_max_tokens():
req.max_tokens = 5
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 5
def test_serving_chat_could_load_correct_generation_config():
@pytest.mark.asyncio
async def test_serving_chat_could_load_correct_generation_config():
mock_model_config = MockModelConfig()
mock_model_config.diff_sampling_param = {
@@ -242,6 +247,7 @@ def test_serving_chat_could_load_correct_generation_config():
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)
req = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{
@@ -252,7 +258,7 @@ def test_serving_chat_could_load_correct_generation_config():
)
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].temperature == 0.5
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
@@ -261,7 +267,7 @@ def test_serving_chat_could_load_correct_generation_config():
req.temperature = 0.1
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].temperature == 0.1
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
@@ -270,13 +276,14 @@ def test_serving_chat_could_load_correct_generation_config():
req.temperature = 0.0
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].temperature == 0.0
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
def test_serving_chat_did_set_correct_cache_salt():
@pytest.mark.asyncio
async def test_serving_chat_did_set_correct_cache_salt():
mock_model_config = MockModelConfig()
mock_engine = MagicMock(spec=MQLLMEngineClient)
@@ -306,11 +313,11 @@ def test_serving_chat_did_set_correct_cache_salt():
# By default cache_salt in the engine prompt is not set
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert "cache_salt" not in mock_engine.generate.call_args.args[0]
# Test with certain cache_salt
req.cache_salt = "test_salt"
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt"