Add option to completion API to truncate prompt tokens (#3144)
@@ -4,6 +4,8 @@ from dataclasses import dataclass
 from http import HTTPStatus
 from typing import Dict, List, Optional, Union
 
+from pydantic import conint
+
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               CompletionRequest, ErrorResponse,
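For reference, conint(ge=1), the newly imported pydantic helper, builds a constrained-integer type that is validated when the request is parsed, so a non-positive truncate_prompt_tokens is rejected before any tokenization happens. A minimal sketch with an illustrative model (not the actual request class in this commit):

    from typing import Optional
    from pydantic import BaseModel, ValidationError, conint

    class ExampleRequest(BaseModel):
        # Same constraint as the new parameter below: None, or an int >= 1.
        truncate_prompt_tokens: Optional[conint(ge=1)] = None

    ExampleRequest(truncate_prompt_tokens=128)   # accepted
    try:
        ExampleRequest(truncate_prompt_tokens=0)  # rejected: must be >= 1
    except ValidationError as exc:
        print(exc)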
@@ -66,7 +68,8 @@ class OpenAIServing:
         self.tokenizer = get_tokenizer(
             engine_model_config.tokenizer,
             tokenizer_mode=engine_model_config.tokenizer_mode,
-            trust_remote_code=engine_model_config.trust_remote_code)
+            trust_remote_code=engine_model_config.trust_remote_code,
+            truncation_side="left")
 
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
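The truncation_side="left" setting matters for completions: when the tokenizer truncates, it should drop the oldest tokens and keep the end of the prompt that the model continues from (the default is right-side truncation). A quick sketch of the behavior with a Hugging Face tokenizer; "gpt2" is only an example model:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2", truncation_side="left")
    ids = tok("one two three four five",
              truncation=True, max_length=3).input_ids
    # ids holds the *last* 3 tokens of the prompt, not the first 3.
    print(tok.decode(ids))  # -> " three four five"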
@@ -164,15 +167,26 @@ class OpenAIServing:
             self,
             request: Union[ChatCompletionRequest, CompletionRequest],
             prompt: Optional[str] = None,
-            prompt_ids: Optional[List[int]] = None) -> List[int]:
+            prompt_ids: Optional[List[int]] = None,
+            truncate_prompt_tokens: Optional[conint(ge=1)] = None
+    ) -> List[int]:
         if not (prompt or prompt_ids):
             raise ValueError("Either prompt or prompt_ids should be provided.")
         if (prompt and prompt_ids):
             raise ValueError(
                 "Only one of prompt or prompt_ids should be provided.")
 
-        input_ids = prompt_ids if prompt_ids is not None else self.tokenizer(
-            prompt).input_ids
+        if prompt_ids is None:
+            tokenizer_kwargs = {} if truncate_prompt_tokens is None else {
+                "truncation": True,
+                "max_length": truncate_prompt_tokens,
+            }
+            input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
+        elif truncate_prompt_tokens is not None:
+            input_ids = prompt_ids[-truncate_prompt_tokens:]
+        else:
+            input_ids = prompt_ids
 
         token_num = len(input_ids)
 
         if request.max_tokens is None:
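On the pre-tokenized path, prompt_ids[-truncate_prompt_tokens:] keeps the last N token ids, mirroring the tokenizer's left-side truncation. The ge=1 bound also sidesteps a Python slicing quirk: a value of 0 would not empty the prompt but return it unchanged, since -0 == 0. For illustration:

    prompt_ids = [101, 102, 103, 104, 105]

    print(prompt_ids[-3:])  # [103, 104, 105]: the tail is kept
    print(prompt_ids[-0:])  # [101, 102, 103, 104, 105]: why 0 is disallowed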