diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py
index a65fc35e0..7dff937c0 100644
--- a/tests/v1/engine/test_async_llm.py
+++ b/tests/v1/engine/test_async_llm.py
@@ -22,9 +22,11 @@ if not current_platform.is_cuda():
     pytest.skip(reason="V1 currently only supported on CUDA.",
                 allow_module_level=True)
 
-TEXT_ENGINE_ARGS = AsyncEngineArgs(model="meta-llama/Llama-3.2-1B-Instruct",
-                                   enforce_eager=True,
-                                   disable_log_requests=True)
+TEXT_ENGINE_ARGS = AsyncEngineArgs(
+    model="meta-llama/Llama-3.2-1B-Instruct",
+    enforce_eager=True,
+    disable_log_requests=True,
+)
 
 VISION_ENGINE_ARGS = AsyncEngineArgs(model="Qwen/Qwen2-VL-2B-Instruct",
                                      enforce_eager=True,
@@ -41,28 +43,33 @@ VISION_PROMPT = {
     "prompt": VISION_PROMPT_TEMPLATE,
     "multi_modal_data": {
         "image": ImageAsset("stop_sign").pil_image
-    }
+    },
 }
 
 
-async def generate(engine: AsyncLLM,
-                   request_id: str,
-                   prompt: PromptType,
-                   output_kind: RequestOutputKind,
-                   max_tokens: int,
-                   n: int = 1,
-                   prompt_logprobs: Optional[int] = None) -> tuple[int, str]:
+async def generate(
+    engine: AsyncLLM,
+    request_id: str,
+    prompt: PromptType,
+    output_kind: RequestOutputKind,
+    max_tokens: int,
+    n: int = 1,
+    prompt_logprobs: Optional[int] = None,
+    cancel_after: Optional[int] = None,
+) -> tuple[int, str]:
     # Ensure generate doesn't complete too fast for cancellation test.
     await asyncio.sleep(0.2)
 
     count = 0
-    sampling_params = SamplingParams(max_tokens=max_tokens,
-                                     ignore_eos=True,
-                                     output_kind=output_kind,
-                                     temperature=0.5,
-                                     seed=33,
-                                     n=n,
-                                     prompt_logprobs=prompt_logprobs)
+    sampling_params = SamplingParams(
+        max_tokens=max_tokens,
+        ignore_eos=True,
+        output_kind=output_kind,
+        temperature=0.5,
+        seed=33,
+        n=n,
+        prompt_logprobs=prompt_logprobs,
+    )
     async for out in engine.generate(request_id=request_id,
                                      prompt=prompt,
                                      sampling_params=sampling_params):
@@ -73,20 +80,27 @@ async def generate(engine: AsyncLLM,
         else:
             count = num_tokens
 
-        await asyncio.sleep(0.)
+        if cancel_after is not None and count >= cancel_after:
+            return count, request_id
+
+        await asyncio.sleep(0.0)
 
     return count, request_id
 
 
 @pytest.mark.parametrize(
     "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
-@pytest.mark.parametrize("engine_args,prompt",
-                         [(TEXT_ENGINE_ARGS, TEXT_PROMPT),
-                          (VISION_ENGINE_ARGS, VISION_PROMPT)])
+@pytest.mark.parametrize(
+    "engine_args,prompt",
+    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
+)
 @pytest.mark.asyncio
-async def test_load(monkeypatch: pytest.MonkeyPatch,
-                    output_kind: RequestOutputKind,
-                    engine_args: AsyncEngineArgs, prompt: PromptType):
+async def test_load(
+    monkeypatch: pytest.MonkeyPatch,
+    output_kind: RequestOutputKind,
+    engine_args: AsyncEngineArgs,
+    prompt: PromptType,
+):
     # TODO(rickyx): Remove monkeypatch once we have a better way to test V1
     # so that in the future when we switch, we don't have to change all the
     # tests.
@@ -125,13 +139,17 @@ async def test_load(monkeypatch: pytest.MonkeyPatch,
 
 @pytest.mark.parametrize(
     "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
-@pytest.mark.parametrize("engine_args,prompt",
-                         [(TEXT_ENGINE_ARGS, TEXT_PROMPT),
-                          (VISION_ENGINE_ARGS, VISION_PROMPT)])
+@pytest.mark.parametrize(
+    "engine_args,prompt",
+    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
+)
 @pytest.mark.asyncio
-async def test_abort(monkeypatch: pytest.MonkeyPatch,
-                     output_kind: RequestOutputKind,
-                     engine_args: AsyncEngineArgs, prompt: PromptType):
+async def test_abort(
+    monkeypatch: pytest.MonkeyPatch,
+    output_kind: RequestOutputKind,
+    engine_args: AsyncEngineArgs,
+    prompt: PromptType,
+):
 
     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
@@ -150,8 +168,9 @@ async def test_abort(monkeypatch: pytest.MonkeyPatch,
         # Create concurrent requests.
         tasks: list[asyncio.Task] = []
         for idx, request_id in enumerate(request_ids):
-            max_tokens = NUM_EXPECTED_TOKENS_LONG if (
-                idx in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS
+            max_tokens = (NUM_EXPECTED_TOKENS_LONG if
+                          (idx
+                           in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS)
             n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
             tasks.append(
                 asyncio.create_task(
@@ -192,12 +211,17 @@ async def test_abort(monkeypatch: pytest.MonkeyPatch,
 
 
 @pytest.mark.parametrize("n", [1, 3])
-@pytest.mark.parametrize("engine_args,prompt",
-                         [(TEXT_ENGINE_ARGS, TEXT_PROMPT),
-                          (VISION_ENGINE_ARGS, VISION_PROMPT)])
+@pytest.mark.parametrize(
+    "engine_args,prompt",
+    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
+)
 @pytest.mark.asyncio
-async def test_finished_flag(monkeypatch: pytest.MonkeyPatch, n: int,
-                             engine_args: AsyncEngineArgs, prompt: PromptType):
+async def test_finished_flag(
+    monkeypatch: pytest.MonkeyPatch,
+    n: int,
+    engine_args: AsyncEngineArgs,
+    prompt: PromptType,
+):
 
     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
@@ -205,11 +229,13 @@ async def test_finished_flag(monkeypatch: pytest.MonkeyPatch, n: int,
         engine = AsyncLLM.from_engine_args(engine_args)
         after.callback(engine.shutdown)
 
-        sampling_params = SamplingParams(max_tokens=100,
-                                         output_kind=RequestOutputKind.DELTA,
-                                         temperature=1.0,
-                                         seed=33,
-                                         n=n)
+        sampling_params = SamplingParams(
+            max_tokens=100,
+            output_kind=RequestOutputKind.DELTA,
+            temperature=1.0,
+            seed=33,
+            n=n,
+        )
         outputs = [
             out
             async for out in engine.generate(request_id="request-33",
@@ -222,6 +248,63 @@ async def test_finished_flag(monkeypatch: pytest.MonkeyPatch, n: int,
         assert outputs[-1].finished
 
 
+@pytest.mark.parametrize(
+    "engine_args,prompt",
+    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
+)
+@pytest.mark.asyncio
+async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch,
+                                       engine_args: AsyncEngineArgs,
+                                       prompt: PromptType):
+    """Test that requests can be cancelled mid-stream."""
+    with monkeypatch.context() as m, ExitStack() as after:
+        m.setenv("VLLM_USE_V1", "1")
+
+        engine = AsyncLLM.from_engine_args(engine_args)
+        after.callback(engine.shutdown)
+
+        NUM_REQUESTS = 100
+        NUM_TOKENS = 1000
+        NUM_EXPECTED_TOKENS = 20
+
+        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]
+
+        # Create concurrent requests that will be cancelled mid-stream
+        tasks = []
+        for request_id in request_ids:
+            tasks.append(
+                asyncio.create_task(
+                    generate(
+                        engine,
+                        request_id,
+                        prompt,
+                        RequestOutputKind.DELTA,
+                        NUM_TOKENS,
+                        cancel_after=NUM_EXPECTED_TOKENS,
+                    )))
+
+        # Wait for all tasks to complete
+        results = await asyncio.gather(*tasks)
+
+        # Verify all tasks were cancelled at the expected point
+        for num_generated_tokens, request_id in results:
+            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
+                f"{request_id} generated {num_generated_tokens} tokens but "
+                f"expected to cancel after {NUM_EXPECTED_TOKENS}")
+
+        # Make sure no requests are left hanging
+        assert not engine.output_processor.has_unfinished_requests()
+
+        # Confirm we can reuse the request id after the cancellations.
+        request_id = request_ids[0]
+        task = asyncio.create_task(
+            generate(engine, request_id, prompt, RequestOutputKind.DELTA,
+                     NUM_EXPECTED_TOKENS))
+        num_generated_tokens, request_id = await task
+        assert num_generated_tokens == NUM_EXPECTED_TOKENS
+        assert not engine.output_processor.has_unfinished_requests()
+
+
 class MockLoggingStatLogger(LoggingStatLogger):
 
     def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index 089f15aee..7fb36cf59 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -332,8 +332,9 @@ class AsyncLLM(EngineClient):
                 yield out
 
         # If the request is disconnected by the client, generate()
-        # is cancelled. So, we abort the request if we end up here.
-        except asyncio.CancelledError:
+        # is cancelled or the generator is garbage collected. So,
+        # we abort the request if we end up here.
+        except (asyncio.CancelledError, GeneratorExit):
             await self.abort(request_id)
             if self.log_requests:
                 logger.info("Request %s aborted.", request_id)