[Feature] Add load generation config from model (#11164)

Signed-off-by: liuyanyi <wolfsonliu@163.com>
Signed-off-by: Yanyi Liu <wolfsonliu@163.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Yanyi Liu
2024-12-19 18:50:38 +08:00
committed by GitHub
parent 98356735ac
commit 5aef49806d
10 changed files with 307 additions and 74 deletions

View File

@@ -1,6 +1,7 @@
import asyncio
from contextlib import suppress
from dataclasses import dataclass
from typing import Optional
from unittest.mock import MagicMock
from vllm.config import MultiModalConfig
@@ -31,6 +32,10 @@ class MockModelConfig:
multimodal_config = MultiModalConfig()
hf_config = MockHFConfig()
logits_processor_pattern = None
diff_sampling_param: Optional[dict] = None
def get_diff_sampling_param(self):
    """Return the stored sampling-parameter overrides.

    Falls back to a fresh empty dict when ``diff_sampling_param`` is
    ``None`` (or otherwise falsy), so callers always receive a mapping.
    """
    overrides = self.diff_sampling_param
    return overrides if overrides else {}
@dataclass
@@ -94,3 +99,59 @@ def test_serving_chat_should_set_correct_max_tokens():
asyncio.run(serving_chat.create_chat_completion(req))
assert mock_engine.generate.call_args.args[1].max_tokens == 10
def test_serving_chat_could_load_correct_generation_config():
    """Check that model-provided sampling defaults reach the engine.

    The mock model config supplies ``temperature`` and
    ``repetition_penalty`` overrides. The test inspects the second
    positional argument passed to ``engine.generate`` and verifies that:

    * the overrides are used when the request leaves them unset;
    * a temperature set on the request takes precedence;
    * an explicit temperature of ``0.0`` is honoured as well.
    """
    model_config = MockModelConfig()
    model_config.diff_sampling_param = {
        "temperature": 0.5,
        "repetition_penalty": 1.05
    }

    engine = MagicMock(spec=MQLLMEngineClient)
    engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    engine.errored = False

    # Initialize the serving chat
    serving_chat = OpenAIServingChat(
        engine,
        model_config,
        BASE_MODEL_PATHS,
        response_role="assistant",
        chat_template=CHAT_TEMPLATE,
        chat_template_content_format="auto",
        lora_modules=None,
        prompt_adapters=None,
        request_logger=None,
    )

    user_message = {"role": "user", "content": "what is 1+1?"}
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[user_message],
        guided_decoding_backend="outlines",
    )

    def submit_and_capture():
        """Send ``req`` and return what was passed to ``generate``."""
        # NOTE(review): errors are suppressed, presumably because the
        # mocked engine cannot yield real output once ``generate`` has
        # been called -- confirm before narrowing the exception type.
        with suppress(Exception):
            asyncio.run(serving_chat.create_chat_completion(req))
        return engine.generate.call_args.args[1]

    # Nothing set on the request: the model config values apply.
    sampling_params = submit_and_capture()
    assert sampling_params.temperature == 0.5
    assert sampling_params.repetition_penalty == 1.05

    # A temperature set on the request wins, including an explicit 0.0;
    # the repetition penalty still comes from the model config.
    for user_temperature in (0.1, 0.0):
        req.temperature = user_temperature
        sampling_params = submit_and_capture()
        assert sampling_params.temperature == user_temperature
        assert sampling_params.repetition_penalty == 1.05