[BugFix] Overhaul async request cancellation (#7111)

This commit is contained in:
Nick Hill
2024-08-06 22:21:41 -07:00
committed by GitHub
parent f9a5600649
commit 9a3f49ae07
11 changed files with 222 additions and 222 deletions

View File

@@ -1,3 +1,4 @@
import asyncio
import time
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
from typing import Sequence as GenericSequence
@@ -29,7 +30,7 @@ from vllm.outputs import RequestOutput
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
from vllm.utils import random_uuid
from vllm.utils import iterate_with_cancellation, random_uuid
logger = init_logger(__name__)
@@ -176,18 +177,20 @@ class OpenAIServingChat(OpenAIServing):
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
if raw_request:
result_generator = iterate_with_cancellation(
result_generator, raw_request.is_disconnected)
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
request, result_generator, request_id, conversation, tokenizer)
else:
try:
return await self.chat_completion_full_generator(
request, raw_request, result_generator, request_id,
conversation, tokenizer)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
try:
return await self.chat_completion_full_generator(
request, result_generator, request_id, conversation, tokenizer)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt:
@@ -422,7 +425,6 @@ class OpenAIServingChat(OpenAIServing):
async def chat_completion_full_generator(
self,
request: ChatCompletionRequest,
raw_request: Optional[Request],
result_generator: AsyncIterator[RequestOutput],
request_id: str,
conversation: List[ConversationMessage],
@@ -433,12 +435,12 @@ class OpenAIServingChat(OpenAIServing):
created_time = int(time.time())
final_res: Optional[RequestOutput] = None
async for res in result_generator:
if raw_request is not None and await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.async_engine_client.abort(request_id)
return self.create_error_response("Client disconnected")
final_res = res
try:
async for res in result_generator:
final_res = res
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
assert final_res is not None
choices: List[ChatCompletionResponseChoice] = []