[HTTP Server] Make model param optional in request (#13568)
This commit is contained in:
@@ -213,7 +213,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/chat/create
|
||||
messages: List[ChatCompletionMessageParam]
|
||||
model: str
|
||||
model: Optional[str] = None
|
||||
frequency_penalty: Optional[float] = 0.0
|
||||
logit_bias: Optional[Dict[str, float]] = None
|
||||
logprobs: Optional[bool] = False
|
||||
@@ -642,7 +642,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
class CompletionRequest(OpenAIBaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/completions/create
|
||||
model: str
|
||||
model: Optional[str] = None
|
||||
prompt: Union[List[int], List[List[int]], str, List[str]]
|
||||
best_of: Optional[int] = None
|
||||
echo: Optional[bool] = False
|
||||
@@ -907,7 +907,7 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
class EmbeddingCompletionRequest(OpenAIBaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/embeddings
|
||||
model: str
|
||||
model: Optional[str] = None
|
||||
input: Union[List[int], List[List[int]], str, List[str]]
|
||||
encoding_format: Literal["float", "base64"] = "float"
|
||||
dimensions: Optional[int] = None
|
||||
@@ -939,7 +939,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
|
||||
|
||||
|
||||
class EmbeddingChatRequest(OpenAIBaseModel):
|
||||
model: str
|
||||
model: Optional[str] = None
|
||||
messages: List[ChatCompletionMessageParam]
|
||||
|
||||
encoding_format: Literal["float", "base64"] = "float"
|
||||
@@ -1007,7 +1007,7 @@ PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest]
|
||||
|
||||
|
||||
class ScoreRequest(OpenAIBaseModel):
|
||||
model: str
|
||||
model: Optional[str] = None
|
||||
text_1: Union[List[str], str]
|
||||
text_2: Union[List[str], str]
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
|
||||
@@ -1031,7 +1031,7 @@ class ScoreRequest(OpenAIBaseModel):
|
||||
|
||||
|
||||
class RerankRequest(OpenAIBaseModel):
|
||||
model: str
|
||||
model: Optional[str] = None
|
||||
query: str
|
||||
documents: List[str]
|
||||
top_n: int = Field(default_factory=lambda: 0)
|
||||
@@ -1345,7 +1345,7 @@ class BatchRequestOutput(OpenAIBaseModel):
|
||||
|
||||
|
||||
class TokenizeCompletionRequest(OpenAIBaseModel):
|
||||
model: str
|
||||
model: Optional[str] = None
|
||||
prompt: str
|
||||
|
||||
add_special_tokens: bool = Field(
|
||||
@@ -1357,7 +1357,7 @@ class TokenizeCompletionRequest(OpenAIBaseModel):
|
||||
|
||||
|
||||
class TokenizeChatRequest(OpenAIBaseModel):
|
||||
model: str
|
||||
model: Optional[str] = None
|
||||
messages: List[ChatCompletionMessageParam]
|
||||
|
||||
add_generation_prompt: bool = Field(
|
||||
@@ -1423,7 +1423,7 @@ class TokenizeResponse(OpenAIBaseModel):
|
||||
|
||||
|
||||
class DetokenizeRequest(OpenAIBaseModel):
|
||||
model: str
|
||||
model: Optional[str] = None
|
||||
tokens: List[int]
|
||||
|
||||
|
||||
@@ -1456,7 +1456,7 @@ class TranscriptionRequest(OpenAIBaseModel):
|
||||
formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
|
||||
"""
|
||||
|
||||
model: str
|
||||
model: Optional[str] = None
|
||||
"""ID of the model to use.
|
||||
"""
|
||||
|
||||
|
||||
@@ -141,7 +141,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
|
||||
model_name = self.models.model_name(lora_request)
|
||||
model_name = self._get_model_name(request.model, lora_request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
|
||||
@@ -166,7 +166,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
model_name = self.models.model_name(lora_request)
|
||||
model_name = self._get_model_name(request.model, lora_request)
|
||||
num_prompts = len(engine_prompts)
|
||||
|
||||
# Similar to the OpenAI API, when n != best_of, we do not stream the
|
||||
|
||||
@@ -83,7 +83,7 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
return self.create_error_response(
|
||||
"dimensions is currently not supported")
|
||||
|
||||
model_name = request.model
|
||||
model_name = self._get_model_name(request.model)
|
||||
request_id = f"embd-{self._base_request_id(raw_request)}"
|
||||
created_time = int(time.time())
|
||||
|
||||
|
||||
@@ -523,5 +523,16 @@ class OpenAIServing:
|
||||
return logprob.decoded_token
|
||||
return tokenizer.decode(token_id)
|
||||
|
||||
def _is_model_supported(self, model_name: Optional[str]) -> bool:
    """Return whether a request naming ``model_name`` can be served.

    Args:
        model_name: The model name taken from the request; may be ``None``
            or empty when the client did not specify one.

    Returns:
        ``True`` when no name was given, otherwise the result of
        ``self.models.is_base_model(model_name)``.
    """
    # A missing/empty name is accepted rather than rejected.
    if not model_name:
        return True
    return self.models.is_base_model(model_name)
|
||||
|
||||
def _get_model_name(self,
                    model_name: Optional[str] = None,
                    lora_request: Optional[LoRARequest] = None) -> str:
    """Pick the model name to report for a request.

    Resolution order:
      1. the LoRA adapter name, when ``lora_request`` is given;
      2. the explicitly requested ``model_name``;
      3. the name of the first entry in ``self.models.base_model_paths``.

    Args:
        model_name: Model name from the request, or ``None``/empty when the
            client omitted it.
        lora_request: LoRA request associated with the call, if any.

    Returns:
        The resolved model name.
    """
    if lora_request:
        return lora_request.lora_name
    # Treat an empty string like a missing name: _is_model_supported
    # accepts any falsy name, so both must fall back to the default model
    # instead of echoing "" back to the client.
    if not model_name:
        return self.models.base_model_paths[0].name
    return model_name
|
||||
|
||||
@@ -95,7 +95,7 @@ class OpenAIServingModels:
|
||||
if isinstance(load_result, ErrorResponse):
|
||||
raise ValueError(load_result.message)
|
||||
|
||||
def is_base_model(self, model_name) -> bool:
    """Return ``True`` if ``model_name`` equals the ``name`` of any entry
    in ``self.base_model_paths``, and ``False`` otherwise (including when
    the list is empty)."""
    return any(model.name == model_name for model in self.base_model_paths)
|
||||
|
||||
def model_name(self, lora_request: Optional[LoRARequest] = None) -> str:
|
||||
|
||||
@@ -79,7 +79,7 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
return self.create_error_response(
|
||||
"dimensions is currently not supported")
|
||||
|
||||
model_name = request.model
|
||||
model_name = self._get_model_name(request.model)
|
||||
request_id = f"pool-{self._base_request_id(raw_request)}"
|
||||
created_time = int(time.time())
|
||||
|
||||
|
||||
@@ -318,7 +318,7 @@ class ServingScores(OpenAIServing):
|
||||
final_res_batch,
|
||||
request_id,
|
||||
created_time,
|
||||
request.model,
|
||||
self._get_model_name(request.model),
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
@@ -358,7 +358,7 @@ class ServingScores(OpenAIServing):
|
||||
request.truncate_prompt_tokens,
|
||||
)
|
||||
return self.request_output_to_rerank_response(
|
||||
final_res_batch, request_id, request.model, documents, top_n)
|
||||
final_res_batch, request_id, self._get_model_name(request.model), documents, top_n)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
|
||||
Reference in New Issue
Block a user