[Feature] add session based streaming input support to v1 (#28973)

Signed-off-by: Joshua Deng <joshuakdeng@gmail.com>
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Joshua Deng
2026-01-24 13:06:28 -07:00
committed by GitHub
parent d4dbb7af63
commit 91601ff478
16 changed files with 2151 additions and 63 deletions

View File

@@ -75,6 +75,7 @@ class EngineCoreRequest(
priority: int = 0
trace_headers: Mapping[str, str] | None = None
resumable: bool = False
# The user-provided request ID. This field is set internally,
# copied from the provided request_id that's originally assigned

View File

@@ -7,11 +7,13 @@ import time
import warnings
from collections.abc import AsyncGenerator, Iterable, Mapping
from copy import copy
from typing import Any, cast
from dataclasses import dataclass
from typing import Any
import torch
import vllm.envs as envs
from vllm import TokensPrompt
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
@@ -20,11 +22,11 @@ from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.renderers import RendererLike
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike
from vllm.tracing import init_tracer
@@ -38,6 +40,7 @@ from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm.v1.engine.input_processor import InputProcessor
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.utils import get_prompt_text
from vllm.v1.executor import Executor
from vllm.v1.metrics.loggers import (
StatLoggerFactory,
@@ -50,6 +53,30 @@ from vllm.v1.metrics.stats import IterationStats
logger = init_logger(__name__)
@dataclass
class StreamingInput:
    """Input data for a streaming generation request.

    This is used with generate() to support multi-turn streaming sessions
    where inputs are provided via an async generator.
    """

    # The next prompt chunk to feed into the session.
    prompt: PromptType
    # Optional per-chunk sampling params; when None, the session-level
    # params passed to generate() are used for this chunk.
    sampling_params: SamplingParams | None = None
class InputStreamError(Exception):
    """Wrapper for errors raised by the user's input stream generator.

    Lets generate() surface the original error directly instead of
    wrapping it in EngineGenerateError.
    """

    def __init__(self, cause: Exception):
        # Initialize the Exception message from the wrapped error first,
        # then keep a reference to the original for re-raising.
        super().__init__(str(cause))
        self.cause = cause
class AsyncLLM(EngineClient):
def __init__(
self,
@@ -261,7 +288,7 @@ class AsyncLLM(EngineClient):
async def add_request(
self,
request_id: str,
prompt: EngineCoreRequest | PromptType,
prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None],
params: SamplingParams | PoolingParams,
arrival_time: float | None = None,
lora_request: LoRARequest | None = None,
@@ -297,6 +324,20 @@ class AsyncLLM(EngineClient):
tokenization_kwargs,
)
if isinstance(prompt, AsyncGenerator):
# Streaming input case.
return await self._add_streaming_input_request(
request_id,
prompt,
params,
arrival_time,
lora_request,
tokenization_kwargs,
trace_headers,
priority,
data_parallel_rank,
)
# Convert Input --> Request.
if isinstance(prompt, EngineCoreRequest):
request = prompt
@@ -322,10 +363,7 @@ class AsyncLLM(EngineClient):
priority,
data_parallel_rank,
)
if isinstance(prompt, str):
prompt_text = prompt
elif isinstance(prompt, Mapping):
prompt_text = cast(str | None, prompt.get("prompt"))
prompt_text = get_prompt_text(prompt)
self.input_processor.assign_request_id(request)
@@ -380,6 +418,104 @@ class AsyncLLM(EngineClient):
if self.log_requests:
logger.info("Added request %s.", request.request_id)
async def _add_streaming_input_request(
    self,
    request_id: str,
    input_stream: AsyncGenerator[StreamingInput, None],
    sampling_params: SamplingParams | PoolingParams,
    arrival_time: float | None = None,
    lora_request: LoRARequest | None = None,
    tokenization_kwargs: dict[str, Any] | None = None,
    trace_headers: Mapping[str, str] | None = None,
    priority: int = 0,
    data_parallel_rank: int | None = None,
) -> RequestOutputCollector:
    """Register a streaming-input (multi-turn) generation session.

    Consumes ``input_stream`` in a background task, submitting each chunk
    as a resumable sub-request under a single internal request id, and
    returns the collector that generate() reads outputs from.

    Raises:
        ValueError: if ``sampling_params`` uses features unsupported by
            input streaming (pooling params, n > 1, FINAL_ONLY output
            kind, or stop strings).
    """
    self._validate_streaming_input_sampling_params(sampling_params)

    # Keyword arguments shared by every per-chunk process_inputs() call.
    inputs = dict(
        arrival_time=arrival_time,
        lora_request=lora_request,
        tokenization_kwargs=tokenization_kwargs,
        trace_headers=trace_headers,
        priority=priority,
        data_parallel_rank=data_parallel_rank,
    )

    # Clone once up front and mark skip_clone so the same params object
    # can be reused for every chunk without repeated cloning.
    if not sampling_params.skip_clone:
        sampling_params = sampling_params.clone()
        sampling_params.skip_clone = True

    # Create request for validation, also used as the finished signal
    # once the input stream is closed.
    final_req = self.input_processor.process_inputs(
        request_id=request_id,
        prompt=TokensPrompt(prompt_token_ids=[0]),
        params=sampling_params,
        **inputs,  # type: ignore[arg-type]
    )
    self.input_processor.assign_request_id(final_req)
    internal_req_id = final_req.request_id

    queue = RequestOutputCollector(sampling_params.output_kind, internal_req_id)

    async def handle_inputs():
        # Background task: drain the user's async generator and submit
        # each chunk to the engine as a resumable sub-request.
        cancelled = False
        try:
            async for input_chunk in input_stream:
                # Per-chunk sampling params (if provided) override the
                # session-level ones, after the same validation.
                sp = input_chunk.sampling_params
                if sp:
                    self._validate_streaming_input_sampling_params(sp)
                else:
                    sp = sampling_params
                req = self.input_processor.process_inputs(
                    request_id=internal_req_id,
                    prompt=input_chunk.prompt,
                    params=sp,
                    resumable=True,
                    **inputs,  # type: ignore[arg-type]
                )
                req.external_req_id = request_id
                if req.prompt_embeds is not None:
                    raise ValueError(
                        "prompt_embeds not supported for streaming inputs"
                    )
                prompt_text = get_prompt_text(input_chunk.prompt)
                await self._add_request(req, prompt_text, None, 0, queue)
        except (asyncio.CancelledError, GeneratorExit):
            cancelled = True
        except Exception as error:
            # Wrap in InputStreamError so generate() can propagate it
            # without wrapping in EngineGenerateError.
            queue.put(InputStreamError(error))
        finally:
            queue._input_stream_task = None
            if not cancelled:
                # Send empty final request to indicate that inputs have
                # finished. Don't send if cancelled (session was aborted).
                await self._add_request(final_req, None, None, 0, queue)

    # Ensure output handler is running.
    self._run_output_handler()

    queue._input_stream_task = asyncio.create_task(handle_inputs())
    return queue
@staticmethod
def _validate_streaming_input_sampling_params(
    params: SamplingParams | PoolingParams,
):
    """Reject parameter configurations that input streaming cannot handle."""
    # Streaming sessions require plain sampling params with a single
    # sequence, incremental output kind, and no stop strings.
    supported = (
        isinstance(params, SamplingParams)
        and params.n <= 1
        and params.output_kind != RequestOutputKind.FINAL_ONLY
        and not params.stop
    )
    if not supported:
        raise ValueError(
            "Input streaming not currently supported "
            "for pooling models, n > 1, request_kind = FINAL_ONLY "
            "or with stop strings."
        )
# TODO: we should support multiple prompts in one call, as you
# can do with LLM.generate. So that for multi-prompt completion
# requests we don't need to send multiple messages to core proc,
@@ -387,7 +523,7 @@ class AsyncLLM(EngineClient):
# re-multiplexed in the API server anyhow.
async def generate(
self,
prompt: EngineCoreRequest | PromptType,
prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None],
sampling_params: SamplingParams,
request_id: str,
*,
@@ -437,9 +573,10 @@ class AsyncLLM(EngineClient):
# Note: both OutputProcessor and EngineCore handle their
# own request cleanup based on finished.
finished = out.finished
assert isinstance(out, RequestOutput)
yield out
finished = out.finished
if out is not STREAM_FINISHED:
yield out
# If the request is disconnected by the client, generate()
# is cancelled or the generator is garbage collected. So,
@@ -463,6 +600,14 @@ class AsyncLLM(EngineClient):
logger.info("Request %s failed (bad request): %s.", request_id, e)
raise
# Error from input stream generator - propagate directly.
except InputStreamError as e:
if q is not None:
await self.abort(q.request_id, internal=True)
if self.log_requests:
logger.info("Request %s failed (input error): %s.", request_id, e)
raise e.cause from e
# Unexpected error in the generate() task (possibly recoverable).
except Exception as e:
if q is not None:
@@ -478,6 +623,9 @@ class AsyncLLM(EngineClient):
)
logger.info("Request %s failed due to %s.", request_id, s)
raise EngineGenerateError() from e
finally:
if q is not None:
q.close()
def _run_output_handler(self):
"""Background loop: pulls from EngineCore and pushes to AsyncStreams."""
@@ -703,6 +851,9 @@ class AsyncLLM(EngineClient):
if self.log_requests:
logger.info("Request %s failed.", request_id)
raise EngineGenerateError() from e
finally:
if q is not None:
q.close()
@property
def tokenizer(self) -> TokenizerLike | None:

View File

@@ -459,6 +459,7 @@ class InputProcessor:
trace_headers: Mapping[str, str] | None = None,
priority: int = 0,
data_parallel_rank: int | None = None,
resumable: bool = False,
) -> EngineCoreRequest:
self._validate_lora(lora_request)
self._validate_params(params)
@@ -603,6 +604,7 @@ class InputProcessor:
priority=priority,
data_parallel_rank=data_parallel_rank,
trace_headers=trace_headers,
resumable=resumable,
)
def _validate_model_inputs(

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from collections import defaultdict
from collections import defaultdict, deque
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, cast
@@ -12,6 +12,7 @@ import torch
from vllm.lora.request import LoRARequest
from vllm.outputs import (
STREAM_FINISHED,
CompletionOutput,
PoolingOutput,
PoolingRequestOutput,
@@ -51,6 +52,8 @@ class RequestOutputCollector:
self.output: RequestOutput | PoolingRequestOutput | Exception | None = None
self.ready = asyncio.Event()
self._input_stream_task: asyncio.Task | None = None
def put(self, output: RequestOutput | PoolingRequestOutput | Exception) -> None:
"""Non-blocking put operation."""
if self.output is None or isinstance(output, Exception):
@@ -87,6 +90,16 @@ class RequestOutputCollector:
raise output
return output
def close(self):
    """Cancel the background input-stream task, if any, and release it."""
    if self._input_stream_task is not None:
        self._input_stream_task.cancel()
        self._input_stream_task = None
def __del__(self):
    # Safety net for collectors garbage-collected without close():
    # schedule the cancel via call_soon_threadsafe since finalization
    # is not guaranteed to run on the task's event loop thread.
    if (task := self._input_stream_task) is not None:
        task.get_loop().call_soon_threadsafe(task.cancel)
        self._input_stream_task = None
@dataclass
class OutputProcessorOutput:
@@ -94,6 +107,20 @@ class OutputProcessorOutput:
reqs_to_abort: list[str]
@dataclass
class StreamingUpdate:
    """Streaming input update data for output processor.

    Contains the incremental prompt data to be applied to a request state
    when the current sub-request completes.
    """

    # Incremental prompt text to append, if any.
    prompt: str | None
    # Incremental prompt token ids to append, if any.
    prompt_token_ids: list[int] | None
    # Arrival time of this chunk; used to reset per-request stats.
    arrival_time: float
    # True when this is the last chunk of the session's input stream.
    final: bool = False
class RequestState:
def __init__(
self,
@@ -116,6 +143,7 @@ class RequestState:
top_p: float | None = None,
n: int | None = None,
temperature: float | None = None,
stream_input: bool = False,
):
self.request_id = request_id
self.external_req_id = external_req_id
@@ -146,6 +174,31 @@ class RequestState:
self.stream_interval = stream_interval
self.sent_tokens_offset = 0 # Offset of sent tokens
# Streaming input queue
self.streaming_input = stream_input
self.input_chunk_queue: deque[StreamingUpdate] | None = (
deque() if stream_input else None
)
def apply_streaming_update(self, update: StreamingUpdate) -> None:
    """Fold a queued streaming input chunk into this request state.

    Extends the accumulated prompt text/tokens, resets timing stats to
    the chunk's arrival time, and puts the state back into prefill mode
    for the next engine pass.
    """
    # Apply the update to the request state.
    # Stay in streaming mode unless this was the final chunk.
    self.streaming_input = not update.final
    # TODO also include relevant output tokens in new prompt here
    # (match scheduler behavior).
    if update.prompt:
        self.prompt = (
            (self.prompt + update.prompt) if self.prompt else update.prompt
        )
    if self.prompt_token_ids:
        self.prompt_token_ids.extend(update.prompt_token_ids or ())
    else:
        self.prompt_token_ids = update.prompt_token_ids or []
    assert self.prompt_token_ids is not None
    self.prompt_len = len(self.prompt_token_ids)
    if self.stats is not None:
        # Restart timing stats for the new turn.
        self.stats.arrival_time = update.arrival_time
    self.is_prefilling = True
@classmethod
def from_new_request(
cls,
@@ -205,6 +258,7 @@ class RequestState:
queue=queue,
log_stats=log_stats,
stream_interval=stream_interval,
stream_input=request.resumable,
)
def make_request_output(
@@ -405,7 +459,6 @@ class OutputProcessor:
a parent request, in which case the associated child requests are aborted
also.
"""
internal_req_ids = []
for request_id in request_ids:
if internal:
@@ -464,8 +517,10 @@ class OutputProcessor:
queue: RequestOutputCollector | None = None,
) -> None:
request_id = request.request_id
if request_id in self.request_states:
raise ValueError(f"Request id {request_id} already running.")
req_state = self.request_states.get(request_id)
if req_state is not None:
self._update_streaming_request_state(req_state, request, prompt)
return
req_state = RequestState.from_new_request(
tokenizer=self.tokenizer,
@@ -486,6 +541,39 @@ class OutputProcessor:
# Track the external_req_id -> [internal_req_id, ...] mapping
self.external_req_ids[req_state.external_req_id].append(request_id)
def _update_streaming_request_state(
    self, req_state: RequestState, request: EngineCoreRequest, prompt: str | None
) -> None:
    """Queue a streaming update instead of immediately applying it.

    Resumable requests become StreamingUpdate entries; the non-resumable
    final request only marks the end of the input stream.
    """
    if not request.resumable:
        # Final request - just mark completion, don't add its dummy tokens.
        if req_state.input_chunk_queue is None:
            # Engine already finished - emit final output and clean up.
            self._finish_request(req_state)
            if req_state.queue is not None:
                # Emit a final output with finished=True
                # to unblock the generate() loop.
                req_state.queue.put(STREAM_FINISHED)
        elif req_state.input_chunk_queue:
            # Chunks still pending: flag the last one as final so the
            # session ends after it is applied.
            req_state.input_chunk_queue[-1].final = True
        else:
            # Queue exists but is empty: just drop out of streaming mode.
            req_state.streaming_input = False
        return

    update = StreamingUpdate(
        prompt=prompt,
        prompt_token_ids=request.prompt_token_ids,
        arrival_time=request.arrival_time,
    )
    # Apply request updates now if the last input already completed.
    if req_state.input_chunk_queue is None:
        req_state.apply_streaming_update(update)
        req_state.input_chunk_queue = deque()
    else:
        # Queue the streaming update otherwise.
        req_state.input_chunk_queue.append(update)
def process_outputs(
self,
engine_core_outputs: list[EngineCoreOutput],
@@ -561,6 +649,9 @@ class OutputProcessor:
kv_transfer_params,
routed_experts,
):
if req_state.streaming_input:
request_output.finished = False
if req_state.queue is not None:
# AsyncLLM: put into queue for handling by generate().
req_state.queue.put(request_output)
@@ -570,36 +661,48 @@ class OutputProcessor:
# Free completed requests.
if finish_reason is not None:
self.request_states.pop(req_id)
if req_state.streaming_input:
if req_state.input_chunk_queue:
update = req_state.input_chunk_queue.popleft()
req_state.apply_streaming_update(update)
else:
req_state.input_chunk_queue = None
else:
self._finish_request(req_state)
if not engine_core_output.finished:
# If req not finished in EngineCore, but Detokenizer
# detected stop string, abort needed in EngineCore.
reqs_to_abort.append(req_id)
internal_ids = self.external_req_ids[req_state.external_req_id]
internal_ids.remove(req_id)
if not internal_ids:
del self.external_req_ids[req_state.external_req_id]
# Remove parent request if applicable.
parent_req = req_state.parent_req
if parent_req and not parent_req.child_requests:
self.parent_requests.pop(parent_req.request_id, None)
if not self.request_states:
self._requests_drained.set()
if not engine_core_output.finished:
# If req not finished in EngineCore, but Detokenizer
# detected stop string, abort needed in EngineCore.
reqs_to_abort.append(req_id)
# Track per-request stats
self._update_stats_from_finished(
req_state, finish_reason, iteration_stats
)
if self.tracer:
self.do_tracing(engine_core_output, req_state, iteration_stats)
# Track per-request stats
self._update_stats_from_finished(
req_state, finish_reason, iteration_stats
)
if self.tracer:
self.do_tracing(engine_core_output, req_state, iteration_stats)
return OutputProcessorOutput(
request_outputs=request_outputs,
reqs_to_abort=reqs_to_abort,
)
def _finish_request(self, req_state: RequestState) -> None:
    """Remove a completed request's state and all bookkeeping entries."""
    req_id = req_state.request_id
    self.request_states.pop(req_id)

    # Drop this internal id from the external -> internal id mapping,
    # deleting the whole entry once no internal ids remain.
    internal_ids = self.external_req_ids[req_state.external_req_id]
    internal_ids.remove(req_id)
    if not internal_ids:
        del self.external_req_ids[req_state.external_req_id]

    # Remove parent request if applicable.
    parent_req = req_state.parent_req
    if parent_req and not parent_req.child_requests:
        self.parent_requests.pop(parent_req.request_id, None)

    if not self.request_states:
        # Wake up anyone waiting for all requests to drain.
        self._requests_drained.set()
def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
self.lora_states.update_scheduler_stats(scheduler_stats)

View File

@@ -4,12 +4,12 @@
import contextlib
import os
import weakref
from collections.abc import Callable, Iterator
from collections.abc import Callable, Iterator, Mapping
from dataclasses import dataclass
from enum import Enum, auto
from multiprocessing import Process, connection
from multiprocessing.process import BaseProcess
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, cast
from unittest.mock import patch
import msgspec
@@ -224,6 +224,14 @@ def get_device_indices(
return value
def get_prompt_text(prompt: Any) -> str | None:
if isinstance(prompt, str):
return prompt
if isinstance(prompt, Mapping):
return cast(str | None, prompt.get("prompt"))
return None
class CoreEngineActorManager:
"""
Utility class to handle creation, readiness, and shutdown