[Frontend] generation_config.json for maximum tokens (#12242)
Signed-off-by: Matthew Hendrey <matthew.hendrey@gmail.com>
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: shangmingc <caishangming@linux.alibaba.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Yuan Tang <terrytangyuan@gmail.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
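This change makes the OpenAI-compatible server honor a max-token default read from the model's generation_config.json, in addition to the existing context-window cap. The test diff below pins down the resulting clamping behavior. The helper here is a minimal sketch of that rule for reference, not vLLM's actual implementation; the 100-token context window and 7-token prompt used in the later examples are assumptions inferred from the asserted values.

def effective_max_tokens(requested, server_default, context_window,
                         prompt_tokens):
    """Sketch of the clamp the test exercises (not vLLM's real code)."""
    # Tokens left in the context window after the prompt.
    remaining = context_window - prompt_tokens
    # The server-side default from generation_config.json lowers the
    # ceiling when it is stricter than the context window.
    limit = min(remaining, server_default) if server_default else remaining
    # An explicit request value wins only if it stays within the ceiling.
    return min(requested, limit) if requested is not None else limit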
@@ -103,6 +103,116 @@ def test_serving_chat_should_set_correct_max_tokens():
    assert mock_engine.generate.call_args.args[1].max_tokens == 10

    # Setting server's max_tokens in the generation_config.json
    # lower than context_window - prompt_tokens
    mock_model_config = MockModelConfig()
    mock_model_config.diff_sampling_param = {
        "max_tokens": 10  # Setting server-side max_tokens limit
    }

    # Reinitialize the engine with new settings
    mock_engine = MagicMock(spec=MQLLMEngineClient)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False

    # Initialize the serving chat
    models = OpenAIServingModels(engine_client=mock_engine,
                                 base_model_paths=BASE_MODEL_PATHS,
                                 model_config=mock_model_config)
    serving_chat = OpenAIServingChat(mock_engine,
                                     mock_model_config,
                                     models,
                                     response_role="assistant",
                                     chat_template=CHAT_TEMPLATE,
                                     chat_template_content_format="auto",
                                     request_logger=None)

    # Test Case 1: No max_tokens specified in request
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{
            "role": "user",
            "content": "what is 1+1?"
        }],
        guided_decoding_backend="outlines",
    )

    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    assert mock_engine.generate.call_args.args[1].max_tokens == 10

    # Test Case 2: Request's max_tokens set higher than server accepts
    req.max_tokens = 15

    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    assert mock_engine.generate.call_args.args[1].max_tokens == 10

    # Test Case 3: Request's max_tokens set lower than server accepts
    req.max_tokens = 5

    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    assert mock_engine.generate.call_args.args[1].max_tokens == 5
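With the server default of 10 active, the three assertions above agree with the sketch from the top of the page (100 and 7 are the assumed context window and prompt length):

assert effective_max_tokens(None, 10, 100, 7) == 10  # server default applies
assert effective_max_tokens(15, 10, 100, 7) == 10    # clamped down to limit
assert effective_max_tokens(5, 10, 100, 7) == 5      # honored as-is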
    # Setting server's max_tokens in the generation_config.json
    # higher than context_window - prompt_tokens
    mock_model_config = MockModelConfig()
    mock_model_config.diff_sampling_param = {
        "max_tokens": 200  # Setting server-side max_tokens limit
    }

    # Reinitialize the engine with new settings
    mock_engine = MagicMock(spec=MQLLMEngineClient)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False

    # Initialize the serving chat
    models = OpenAIServingModels(engine_client=mock_engine,
                                 base_model_paths=BASE_MODEL_PATHS,
                                 model_config=mock_model_config)
    serving_chat = OpenAIServingChat(mock_engine,
                                     mock_model_config,
                                     models,
                                     response_role="assistant",
                                     chat_template=CHAT_TEMPLATE,
                                     chat_template_content_format="auto",
                                     request_logger=None)

    # Test Case 1: No max_tokens specified, defaults to context_window
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{
            "role": "user",
            "content": "what is 1+1?"
        }],
        guided_decoding_backend="outlines",
    )

    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    assert mock_engine.generate.call_args.args[1].max_tokens == 93

    # Test Case 2: Request's max_tokens set higher than server accepts
    req.max_tokens = 100

    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    assert mock_engine.generate.call_args.args[1].max_tokens == 93

    # Test Case 3: Request's max_tokens set lower than server accepts
    req.max_tokens = 5

    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    assert mock_engine.generate.call_args.args[1].max_tokens == 5
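Here the server default (200) is looser than the context window, so the context window wins: assuming MockModelConfig uses a 100-token max_model_len and the chat prompt renders to 7 tokens (assumptions inferred from the asserted values), 100 - 7 = 93 explains the assertions above:

assert effective_max_tokens(None, 200, 100, 7) == 93  # capped by context room
assert effective_max_tokens(100, 200, 100, 7) == 93   # clamped down to 93
assert effective_max_tokens(5, 200, 100, 7) == 5      # honored as-is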

def test_serving_chat_could_load_correct_generation_config():