[Fix] Add model sequence length into model config (#575)
This commit is contained in:
@@ -107,25 +107,14 @@ async def get_gen_prompt(request) -> str:
|
||||
return prompt
|
||||
|
||||
|
||||
async def check_length(request, prompt, model_config):
|
||||
if hasattr(model_config.hf_config, "max_sequence_length"):
|
||||
context_len = model_config.hf_config.max_sequence_length
|
||||
elif hasattr(model_config.hf_config, "seq_length"):
|
||||
context_len = model_config.hf_config.seq_length
|
||||
elif hasattr(model_config.hf_config, "max_position_embeddings"):
|
||||
context_len = model_config.hf_config.max_position_embeddings
|
||||
elif hasattr(model_config.hf_config, "seq_length"):
|
||||
context_len = model_config.hf_config.seq_length
|
||||
else:
|
||||
context_len = 2048
|
||||
|
||||
async def check_length(request, prompt):
|
||||
input_ids = tokenizer(prompt).input_ids
|
||||
token_num = len(input_ids)
|
||||
|
||||
if token_num + request.max_tokens > context_len:
|
||||
if token_num + request.max_tokens > max_model_len:
|
||||
return create_error_response(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
f"This model's maximum context length is {context_len} tokens. "
|
||||
f"This model's maximum context length is {max_model_len} tokens. "
|
||||
f"However, you requested {request.max_tokens + token_num} tokens "
|
||||
f"({token_num} in the messages, "
|
||||
f"{request.max_tokens} in the completion). "
|
||||
@@ -194,7 +183,7 @@ async def create_chat_completion(raw_request: Request):
|
||||
"logit_bias is not currently supported")
|
||||
|
||||
prompt = await get_gen_prompt(request)
|
||||
error_check_ret = await check_length(request, prompt, engine_model_config)
|
||||
error_check_ret = await check_length(request, prompt)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
@@ -591,6 +580,7 @@ if __name__ == "__main__":
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
engine_model_config = asyncio.run(engine.get_model_config())
|
||||
max_model_len = engine_model_config.get_max_model_len()
|
||||
|
||||
# A separate tokenizer to map token IDs to strings.
|
||||
tokenizer = get_tokenizer(engine_args.tokenizer,
|
||||
|
||||
Reference in New Issue
Block a user