[BugFix] Overhaul async request cancellation (#7111)

Nick Hill authored 2024-08-06 22:21:41 -07:00, committed by GitHub
parent f9a5600649
commit 9a3f49ae07
11 changed files with 222 additions and 222 deletions


@@ -1,3 +1,4 @@
+import asyncio
 import time
 from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
                     Optional)
@@ -84,7 +85,7 @@ class OpenAIServingCompletion(OpenAIServing):
         created_time = int(time.time())

         # Schedule the request and get the result generator.
-        generators: List[AsyncIterator[RequestOutput]] = []
+        generators: List[AsyncGenerator[RequestOutput, None]] = []
         try:
             (
                 lora_request,
@@ -144,7 +145,8 @@ class OpenAIServingCompletion(OpenAIServing):
             return self.create_error_response(str(e))

         result_generator: AsyncIterator[Tuple[
-            int, RequestOutput]] = merge_async_iterators(*generators)
+            int, RequestOutput]] = merge_async_iterators(
+                *generators, is_cancelled=raw_request.is_disconnected)

         # Similar to the OpenAI API, when n != best_of, we do not stream the
         # results. In addition, we do not stream the results when use
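
The key change in this hunk is that the client-disconnect check now travels with the merged result generator instead of being polled inside each handler: merge_async_iterators receives raw_request.is_disconnected as an is_cancelled callback. A minimal sketch of the idea follows; the helper name iterate_with_cancellation and its body are illustrative only, not vLLM's actual merge_async_iterators implementation (the real utility also interleaves the per-prompt generators into (index, output) pairs).

import asyncio
from typing import AsyncGenerator, AsyncIterator, Awaitable, Callable, TypeVar

T = TypeVar("T")


async def iterate_with_cancellation(
    iterator: AsyncIterator[T],
    is_cancelled: Callable[[], Awaitable[bool]],
) -> AsyncGenerator[T, None]:
    # Consult the disconnect callback between items and raise CancelledError,
    # which the consuming handler maps to a "Client disconnected" response.
    async for item in iterator:
        if await is_cancelled():
            raise asyncio.CancelledError("client disconnected")
        yield item
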
@@ -156,7 +158,6 @@ class OpenAIServingCompletion(OpenAIServing):
         # Streaming response
         if stream:
             return self.completion_stream_generator(request,
-                                                    raw_request,
                                                     result_generator,
                                                     request_id,
                                                     created_time,
@@ -168,10 +169,6 @@ class OpenAIServingCompletion(OpenAIServing):
         final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
         try:
             async for i, res in result_generator:
-                if await raw_request.is_disconnected():
-                    # Abort the request if the client disconnects.
-                    await self.async_engine_client.abort(f"{request_id}-{i}")
-                    return self.create_error_response("Client disconnected")
                 final_res_batch[i] = res

             for i, final_res in enumerate(final_res_batch):
@@ -194,6 +191,8 @@ class OpenAIServingCompletion(OpenAIServing):
                 model_name,
                 tokenizer,
             )
+        except asyncio.CancelledError:
+            return self.create_error_response("Client disconnected")
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
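
With the per-iteration polling and explicit abort() calls removed, a disconnect now surfaces in the non-streaming path as asyncio.CancelledError, which this new except clause converts into the same "Client disconnected" error response. A runnable toy sketch of that flow is below; the names generate, handler, and main are made up, and the finally block merely stands in for wherever the engine client actually aborts the request.

import asyncio
from typing import AsyncGenerator


async def generate(request_id: str) -> AsyncGenerator[str, None]:
    try:
        for step in range(1000):
            await asyncio.sleep(0.01)  # stand-in for an engine step
            yield f"{request_id}: token {step}"
    finally:
        # Runs when the consumer is cancelled or closes the generator early.
        print(f"abort({request_id})")


async def handler() -> str:
    try:
        async for _chunk in generate("cmpl-123"):
            pass
    except asyncio.CancelledError:
        return "Client disconnected"
    return "done"


async def main() -> None:
    task = asyncio.create_task(handler())
    await asyncio.sleep(0.05)
    task.cancel()  # simulate the client going away mid-generation
    print(await task)  # prints "Client disconnected" after abort() has run


asyncio.run(main())
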
@@ -214,7 +213,6 @@ class OpenAIServingCompletion(OpenAIServing):
     async def completion_stream_generator(
         self,
         request: CompletionRequest,
-        raw_request: Request,
         result_generator: AsyncIterator[Tuple[int, RequestOutput]],
         request_id: str,
         created_time: int,
@@ -230,12 +228,6 @@ class OpenAIServingCompletion(OpenAIServing):
         try:
             async for prompt_idx, res in result_generator:

-                # Abort the request if the client disconnects.
-                if await raw_request.is_disconnected():
-                    await self.async_engine_client.abort(
-                        f"{request_id}-{prompt_idx}")
-                    raise StopAsyncIteration()
-
                 for output in res.outputs:
                     i = output.index + prompt_idx * num_choices
                     # TODO(simon): optimize the performance by avoiding full
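
The streaming path follows the same pattern, which is why completion_stream_generator no longer needs the raw Request at all: the cancellation check already lives in the merged result generator it consumes. A self-contained sketch of that wiring, assuming FastAPI/Starlette; the endpoint and the fake_results generator are invented for illustration and stand in for merge_async_iterators(*generators, is_cancelled=...).

import asyncio
from typing import AsyncGenerator, Awaitable, Callable

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse

app = FastAPI()


async def fake_results(
    prompt: str,
    is_cancelled: Callable[[], Awaitable[bool]],
) -> AsyncGenerator[str, None]:
    # Stand-in for the merged result generator: the disconnect check lives
    # in the producer, not in the streaming loop.
    for i in range(100):
        if await is_cancelled():
            raise asyncio.CancelledError("client disconnected")
        await asyncio.sleep(0.05)
        yield f"token-{i}"


@app.post("/v1/completions")
async def create_completion(request: Request):
    results = fake_results("hello", is_cancelled=request.is_disconnected)

    async def stream() -> AsyncGenerator[str, None]:
        # No Request needed here: a disconnect surfaces through `results`
        # and the stream simply stops producing chunks.
        async for token in results:
            yield f"data: {token}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(stream(), media_type="text/event-stream")
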