[Front-end] microbatch tokenization (#19334)
Signed-off-by: zt2370 <ztang2370@gmail.com>
This commit is contained in:
@@ -7,6 +7,8 @@ from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import MultiModalConfig
|
||||
from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
@@ -73,7 +75,8 @@ def test_async_serving_chat_init():
|
||||
assert serving_completion.chat_template == CHAT_TEMPLATE
|
||||
|
||||
|
||||
def test_serving_chat_should_set_correct_max_tokens():
|
||||
@pytest.mark.asyncio
|
||||
async def test_serving_chat_should_set_correct_max_tokens():
|
||||
mock_engine = MagicMock(spec=MQLLMEngineClient)
|
||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||
mock_engine.errored = False
|
||||
@@ -88,6 +91,7 @@ def test_serving_chat_should_set_correct_max_tokens():
|
||||
chat_template=CHAT_TEMPLATE,
|
||||
chat_template_content_format="auto",
|
||||
request_logger=None)
|
||||
|
||||
req = ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
messages=[{
|
||||
@@ -98,13 +102,13 @@ def test_serving_chat_should_set_correct_max_tokens():
|
||||
)
|
||||
|
||||
with suppress(Exception):
|
||||
asyncio.run(serving_chat.create_chat_completion(req))
|
||||
await serving_chat.create_chat_completion(req)
|
||||
|
||||
assert mock_engine.generate.call_args.args[1].max_tokens == 93
|
||||
|
||||
req.max_tokens = 10
|
||||
with suppress(Exception):
|
||||
asyncio.run(serving_chat.create_chat_completion(req))
|
||||
await serving_chat.create_chat_completion(req)
|
||||
|
||||
assert mock_engine.generate.call_args.args[1].max_tokens == 10
|
||||
|
||||
@@ -143,7 +147,7 @@ def test_serving_chat_should_set_correct_max_tokens():
|
||||
)
|
||||
|
||||
with suppress(Exception):
|
||||
asyncio.run(serving_chat.create_chat_completion(req))
|
||||
await serving_chat.create_chat_completion(req)
|
||||
|
||||
assert mock_engine.generate.call_args.args[1].max_tokens == 10
|
||||
|
||||
@@ -151,7 +155,7 @@ def test_serving_chat_should_set_correct_max_tokens():
|
||||
req.max_tokens = 15
|
||||
|
||||
with suppress(Exception):
|
||||
asyncio.run(serving_chat.create_chat_completion(req))
|
||||
await serving_chat.create_chat_completion(req)
|
||||
|
||||
assert mock_engine.generate.call_args.args[1].max_tokens == 10
|
||||
|
||||
@@ -159,7 +163,7 @@ def test_serving_chat_should_set_correct_max_tokens():
|
||||
req.max_tokens = 5
|
||||
|
||||
with suppress(Exception):
|
||||
asyncio.run(serving_chat.create_chat_completion(req))
|
||||
await serving_chat.create_chat_completion(req)
|
||||
|
||||
assert mock_engine.generate.call_args.args[1].max_tokens == 5
|
||||
|
||||
@@ -198,7 +202,7 @@ def test_serving_chat_should_set_correct_max_tokens():
|
||||
)
|
||||
|
||||
with suppress(Exception):
|
||||
asyncio.run(serving_chat.create_chat_completion(req))
|
||||
await serving_chat.create_chat_completion(req)
|
||||
|
||||
assert mock_engine.generate.call_args.args[1].max_tokens == 93
|
||||
|
||||
@@ -206,7 +210,7 @@ def test_serving_chat_should_set_correct_max_tokens():
|
||||
req.max_tokens = 100
|
||||
|
||||
with suppress(Exception):
|
||||
asyncio.run(serving_chat.create_chat_completion(req))
|
||||
await serving_chat.create_chat_completion(req)
|
||||
|
||||
assert mock_engine.generate.call_args.args[1].max_tokens == 93
|
||||
|
||||
@@ -214,12 +218,13 @@ def test_serving_chat_should_set_correct_max_tokens():
|
||||
req.max_tokens = 5
|
||||
|
||||
with suppress(Exception):
|
||||
asyncio.run(serving_chat.create_chat_completion(req))
|
||||
await serving_chat.create_chat_completion(req)
|
||||
|
||||
assert mock_engine.generate.call_args.args[1].max_tokens == 5
|
||||
|
||||
|
||||
def test_serving_chat_could_load_correct_generation_config():
|
||||
@pytest.mark.asyncio
|
||||
async def test_serving_chat_could_load_correct_generation_config():
|
||||
|
||||
mock_model_config = MockModelConfig()
|
||||
mock_model_config.diff_sampling_param = {
|
||||
@@ -242,6 +247,7 @@ def test_serving_chat_could_load_correct_generation_config():
|
||||
chat_template=CHAT_TEMPLATE,
|
||||
chat_template_content_format="auto",
|
||||
request_logger=None)
|
||||
|
||||
req = ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
messages=[{
|
||||
@@ -252,7 +258,7 @@ def test_serving_chat_could_load_correct_generation_config():
|
||||
)
|
||||
|
||||
with suppress(Exception):
|
||||
asyncio.run(serving_chat.create_chat_completion(req))
|
||||
await serving_chat.create_chat_completion(req)
|
||||
|
||||
assert mock_engine.generate.call_args.args[1].temperature == 0.5
|
||||
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
|
||||
@@ -261,7 +267,7 @@ def test_serving_chat_could_load_correct_generation_config():
|
||||
req.temperature = 0.1
|
||||
|
||||
with suppress(Exception):
|
||||
asyncio.run(serving_chat.create_chat_completion(req))
|
||||
await serving_chat.create_chat_completion(req)
|
||||
|
||||
assert mock_engine.generate.call_args.args[1].temperature == 0.1
|
||||
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
|
||||
@@ -270,13 +276,14 @@ def test_serving_chat_could_load_correct_generation_config():
|
||||
req.temperature = 0.0
|
||||
|
||||
with suppress(Exception):
|
||||
asyncio.run(serving_chat.create_chat_completion(req))
|
||||
await serving_chat.create_chat_completion(req)
|
||||
|
||||
assert mock_engine.generate.call_args.args[1].temperature == 0.0
|
||||
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
|
||||
|
||||
|
||||
def test_serving_chat_did_set_correct_cache_salt():
|
||||
@pytest.mark.asyncio
|
||||
async def test_serving_chat_did_set_correct_cache_salt():
|
||||
mock_model_config = MockModelConfig()
|
||||
|
||||
mock_engine = MagicMock(spec=MQLLMEngineClient)
|
||||
@@ -306,11 +313,11 @@ def test_serving_chat_did_set_correct_cache_salt():
|
||||
|
||||
# By default cache_salt in the engine prompt is not set
|
||||
with suppress(Exception):
|
||||
asyncio.run(serving_chat.create_chat_completion(req))
|
||||
await serving_chat.create_chat_completion(req)
|
||||
assert "cache_salt" not in mock_engine.generate.call_args.args[0]
|
||||
|
||||
# Test with certain cache_salt
|
||||
req.cache_salt = "test_salt"
|
||||
with suppress(Exception):
|
||||
asyncio.run(serving_chat.create_chat_completion(req))
|
||||
await serving_chat.create_chat_completion(req)
|
||||
assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt"
|
||||
|
||||
Reference in New Issue
Block a user