[Fix] Add model sequence length into model config (#575)

This commit is contained in:
Zhuohan Li
2023-07-25 23:46:30 -07:00
committed by GitHub
parent 82ad323dee
commit 58a072be15
3 changed files with 27 additions and 18 deletions

View File

@@ -107,25 +107,14 @@ async def get_gen_prompt(request) -> str:
return prompt
async def check_length(request, prompt, model_config):
if hasattr(model_config.hf_config, "max_sequence_length"):
context_len = model_config.hf_config.max_sequence_length
elif hasattr(model_config.hf_config, "seq_length"):
context_len = model_config.hf_config.seq_length
elif hasattr(model_config.hf_config, "max_position_embeddings"):
context_len = model_config.hf_config.max_position_embeddings
elif hasattr(model_config.hf_config, "seq_length"):
context_len = model_config.hf_config.seq_length
else:
context_len = 2048
async def check_length(request, prompt):
input_ids = tokenizer(prompt).input_ids
token_num = len(input_ids)
if token_num + request.max_tokens > context_len:
if token_num + request.max_tokens > max_model_len:
return create_error_response(
HTTPStatus.BAD_REQUEST,
f"This model's maximum context length is {context_len} tokens. "
f"This model's maximum context length is {max_model_len} tokens. "
f"However, you requested {request.max_tokens + token_num} tokens "
f"({token_num} in the messages, "
f"{request.max_tokens} in the completion). "
@@ -194,7 +183,7 @@ async def create_chat_completion(raw_request: Request):
"logit_bias is not currently supported")
prompt = await get_gen_prompt(request)
error_check_ret = await check_length(request, prompt, engine_model_config)
error_check_ret = await check_length(request, prompt)
if error_check_ret is not None:
return error_check_ret
@@ -591,6 +580,7 @@ if __name__ == "__main__":
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
engine_model_config = asyncio.run(engine.get_model_config())
max_model_len = engine_model_config.get_max_model_len()
# A separate tokenizer to map token IDs to strings.
tokenizer = get_tokenizer(engine_args.tokenizer,