[Frontend] OpenAI server: propagate usage accounting to FastAPI middleware layer (#8672)

This commit is contained in:
Adam Tilghman
2024-09-25 00:49:26 -07:00
committed by GitHub
parent 3e073e66f1
commit 1ac3de09cd
3 changed files with 57 additions and 11 deletions

View File

@@ -18,7 +18,9 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
ErrorResponse, UsageInfo)
ErrorResponse,
RequestResponseMetadata,
UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath,
@@ -94,6 +96,10 @@ class OpenAIServingCompletion(OpenAIServing):
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time())
request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
raw_request.state.request_metadata = request_metadata
# Schedule the request and get the result generator.
generators: List[AsyncGenerator[RequestOutput, None]] = []
try:
@@ -165,13 +171,15 @@ class OpenAIServingCompletion(OpenAIServing):
# Streaming response
if stream:
return self.completion_stream_generator(request,
result_generator,
request_id,
created_time,
model_name,
num_prompts=len(prompts),
tokenizer=tokenizer)
return self.completion_stream_generator(
request,
result_generator,
request_id,
created_time,
model_name,
num_prompts=len(prompts),
tokenizer=tokenizer,
request_metadata=request_metadata)
# Non-streaming response
final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
@@ -198,6 +206,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time,
model_name,
tokenizer,
request_metadata,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
@@ -227,6 +236,7 @@ class OpenAIServingCompletion(OpenAIServing):
model_name: str,
num_prompts: int,
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n
previous_text_lens = [0] * num_choices * num_prompts
@@ -346,6 +356,14 @@ class OpenAIServingCompletion(OpenAIServing):
exclude_unset=False, exclude_none=True))
yield f"data: {final_usage_data}\n\n"
# report to FastAPI middleware aggregate usage across all choices
total_prompt_tokens = sum(num_prompt_tokens)
total_completion_tokens = sum(previous_num_tokens)
request_metadata.final_usage_info = UsageInfo(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
data = self.create_streaming_error_response(str(e))
@@ -360,6 +378,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time: int,
model_name: str,
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
) -> CompletionResponse:
choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0
@@ -433,6 +452,8 @@ class OpenAIServingCompletion(OpenAIServing):
total_tokens=num_prompt_tokens + num_generated_tokens,
)
request_metadata.final_usage_info = usage
return CompletionResponse(
id=request_id,
created=created_time,