[Feature] Add load generation config from model (#11164)
Signed-off-by: liuyanyi <wolfsonliu@163.com> Signed-off-by: Yanyi Liu <wolfsonliu@163.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from vllm.config import MultiModalConfig
|
||||
@@ -31,6 +32,10 @@ class MockModelConfig:
|
||||
multimodal_config = MultiModalConfig()
|
||||
hf_config = MockHFConfig()
|
||||
logits_processor_pattern = None
|
||||
diff_sampling_param: Optional[dict] = None
|
||||
|
||||
def get_diff_sampling_param(self):
|
||||
return self.diff_sampling_param or {}
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -94,3 +99,59 @@ def test_serving_chat_should_set_correct_max_tokens():
|
||||
asyncio.run(serving_chat.create_chat_completion(req))
|
||||
|
||||
assert mock_engine.generate.call_args.args[1].max_tokens == 10
|
||||
|
||||
|
||||
def test_serving_chat_could_load_correct_generation_config():
|
||||
|
||||
mock_model_config = MockModelConfig()
|
||||
mock_model_config.diff_sampling_param = {
|
||||
"temperature": 0.5,
|
||||
"repetition_penalty": 1.05
|
||||
}
|
||||
|
||||
mock_engine = MagicMock(spec=MQLLMEngineClient)
|
||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||
mock_engine.errored = False
|
||||
|
||||
# Initialize the serving chat
|
||||
serving_chat = OpenAIServingChat(mock_engine,
|
||||
mock_model_config,
|
||||
BASE_MODEL_PATHS,
|
||||
response_role="assistant",
|
||||
chat_template=CHAT_TEMPLATE,
|
||||
chat_template_content_format="auto",
|
||||
lora_modules=None,
|
||||
prompt_adapters=None,
|
||||
request_logger=None)
|
||||
req = ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "what is 1+1?"
|
||||
}],
|
||||
guided_decoding_backend="outlines",
|
||||
)
|
||||
|
||||
with suppress(Exception):
|
||||
asyncio.run(serving_chat.create_chat_completion(req))
|
||||
|
||||
assert mock_engine.generate.call_args.args[1].temperature == 0.5
|
||||
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
|
||||
|
||||
# Test the param when user set it
|
||||
req.temperature = 0.1
|
||||
|
||||
with suppress(Exception):
|
||||
asyncio.run(serving_chat.create_chat_completion(req))
|
||||
|
||||
assert mock_engine.generate.call_args.args[1].temperature == 0.1
|
||||
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
|
||||
|
||||
# Test When temperature==0.0
|
||||
req.temperature = 0.0
|
||||
|
||||
with suppress(Exception):
|
||||
asyncio.run(serving_chat.create_chat_completion(req))
|
||||
|
||||
assert mock_engine.generate.call_args.args[1].temperature == 0.0
|
||||
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
|
||||
|
||||
Reference in New Issue
Block a user