[Feature] add session based streaming input support to v1 (#28973)

Signed-off-by: Joshua Deng <joshuakdeng@gmail.com>
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Joshua Deng
2026-01-24 13:06:28 -07:00
committed by GitHub
parent d4dbb7af63
commit 91601ff478
16 changed files with 2151 additions and 63 deletions

View File

@@ -192,6 +192,16 @@ class RequestOutput:
)
# Sentinel to indicate a request is finished; used with streaming inputs.
STREAM_FINISHED = RequestOutput(
request_id="",
prompt=None,
prompt_token_ids=None,
prompt_logprobs=None,
outputs=[],
finished=True,
)
_O = TypeVar("_O", default=PoolingOutput)
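
Because STREAM_FINISHED is a module-level singleton, consumers distinguish it by identity rather than by value. A minimal sketch of that pattern, assuming the constant is importable from vllm.outputs (as the async_llm changes below do):

    from vllm.outputs import STREAM_FINISHED, RequestOutput

    def filter_output(out: RequestOutput) -> RequestOutput | None:
        # Identity check, mirroring the `out is STREAM_FINISHED` test that
        # generate() performs before yielding.
        return None if out is STREAM_FINISHED else out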

View File

@@ -2,8 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import time
from collections import defaultdict
from collections import defaultdict, deque
from collections.abc import Iterable
from dataclasses import replace
from typing import Any
import numpy as np
@@ -49,12 +50,9 @@ from vllm.v1.core.sched.utils import check_stop, remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
from vllm.v1.metrics.perf import ModelMetrics, PerfStats
from vllm.v1.metrics.stats import (
PrefixCacheStats,
SchedulerStats,
)
from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.request import Request, RequestStatus, StreamingUpdate
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.utils import record_function_or_nullcontext
@@ -166,6 +164,10 @@ class Scheduler(SchedulerInterface):
# This is flushed at the end of each scheduling step.
self.finished_req_ids: set[str] = set()
# Counter for requests waiting for streaming input. Used to calculate
# the number of unfinished requests.
self.num_waiting_for_streaming_input: int = 0
# KV Connector: requests in process of async KV loading or recving
self.finished_recving_kv_req_ids: set[str] = set()
self.failed_recving_kv_req_ids: set[str] = set()
@@ -569,6 +571,13 @@ class Scheduler(SchedulerInterface):
skipped_waiting_requests.prepend_request(request)
continue
# Streaming: skip request if still waiting for next streaming req.
if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
assert not request.streaming_queue
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Check that adding the request still respects the max_loras
# constraint.
if (
@@ -929,6 +938,51 @@ class Scheduler(SchedulerInterface):
# it will also affect the scheduler output.
self.finished_req_ids = set()
def _update_request_as_session(
self, session: Request, update: StreamingUpdate
) -> None:
"""
Updates the waiting session with the next streaming update.
Discards the last sampled output token from the prior input chunk.
"""
# Current streaming input behavior: keep only computed output tokens
# (discard the final sampled output token).
num_computed_tokens = session.num_computed_tokens
kept_output_tokens = session._all_token_ids[
session.num_prompt_tokens : num_computed_tokens
]
del session._all_token_ids[num_computed_tokens:]
session._output_token_ids.clear()
assert session.prompt_token_ids is not None
# Extend prompt with kept output tokens.
session.prompt_token_ids.extend(kept_output_tokens)
if update.mm_features:
base = session.num_tokens
for mm_feature in update.mm_features:
mm_feature.mm_position = replace(
mm_feature.mm_position, offset=mm_feature.mm_position.offset + base
)
session.mm_features.extend(update.mm_features)
session._all_token_ids.extend(update.prompt_token_ids or ())
session.prompt_token_ids.extend(update.prompt_token_ids or ())
# Update block hashes for the new tokens
# (mirrors Request.append_output_token_ids)
if session.get_hash_new_full_blocks is not None:
session.block_hashes.extend(session.get_hash_new_full_blocks())
session.num_prompt_tokens = len(session.prompt_token_ids)
session.arrival_time = update.arrival_time
session.sampling_params = update.sampling_params
if session.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
self.num_waiting_for_streaming_input -= 1
session.status = RequestStatus.WAITING
if self.log_stats:
session.record_event(EngineCoreEventType.QUEUED)
def _make_cached_request_data(
self,
running_reqs: list[Request],
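
In other words, _update_request_as_session folds the already-computed output tokens into the prompt, drops the last sampled (not yet computed) token, and appends the next chunk. A self-contained toy of that bookkeeping on plain Python lists (illustration only; the real method mutates the Request's internal buffers and block hashes):

    # Toy illustration of the session token bookkeeping (plain lists).
    prompt = [1, 2, 3]                   # original prompt tokens
    all_tokens = [1, 2, 3, 10, 11, 12]   # prompt + sampled output so far
    num_computed = 5                     # last sampled token (12) not yet computed

    kept_outputs = all_tokens[len(prompt):num_computed]  # [10, 11]
    all_tokens = all_tokens[:num_computed]               # discard 12
    prompt = prompt + kept_outputs                       # outputs become prompt

    next_chunk = [20, 21]                # tokens of the next streaming input
    all_tokens += next_chunk
    prompt += next_chunk
    assert prompt == all_tokens == [1, 2, 3, 10, 11, 20, 21]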
@@ -1271,9 +1325,17 @@ class Scheduler(SchedulerInterface):
stopped = True
routed_experts = None
finish_reason = None
if stopped:
routed_experts = self._get_routed_experts(request)
kv_transfer_params = self._free_request(request)
# Capture finish_reason BEFORE _handle_stopped_request, which may
# reset the status to WAITING for streaming requests that continue.
finish_reason = request.get_finished_reason()
finished = self._handle_stopped_request(request)
if finished:
kv_transfer_params = self._free_request(request)
if status_before_stop == RequestStatus.RUNNING:
stopped_running_reqs.add(request)
else:
@@ -1315,7 +1377,7 @@ class Scheduler(SchedulerInterface):
EngineCoreOutput(
request_id=req_id,
new_token_ids=new_token_ids,
finish_reason=request.get_finished_reason(),
finish_reason=finish_reason,
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
pooling_output=pooler_output,
@@ -1410,6 +1472,24 @@ class Scheduler(SchedulerInterface):
return engine_core_outputs
def _handle_stopped_request(self, request: Request) -> bool:
"""Return True if finished (can be False for resumable requests)."""
if not request.resumable:
return True
if request.streaming_queue:
update = request.streaming_queue.popleft()
if update is None:
# Streaming request finished.
return True
self._update_request_as_session(request, update)
else:
request.status = RequestStatus.WAITING_FOR_STREAMING_REQ
self.num_waiting_for_streaming_input += 1
self.waiting.add_request(request)
return False
def _get_routed_experts(self, request: Request) -> np.ndarray | None:
if not self.vllm_config.model_config.enable_return_routed_experts:
return None
@@ -1535,10 +1615,26 @@ class Scheduler(SchedulerInterface):
return len(self.running), len(self.waiting)
def add_request(self, request: Request) -> None:
self.waiting.add_request(request)
self.requests[request.request_id] = request
if self.log_stats:
request.record_event(EngineCoreEventType.QUEUED)
existing = self.requests.get(request.request_id)
if existing is not None:
update = StreamingUpdate.from_request(request)
if existing.status != RequestStatus.WAITING_FOR_STREAMING_REQ:
assert existing.streaming_queue is not None, "duplicate request id"
# Queue next input chunk (or finished sentinel).
existing.streaming_queue.append(update)
elif update is not None:
# Commence next input chunk.
self._update_request_as_session(existing, update)
else:
# Streaming-input session finished.
self.finish_requests(request.request_id, RequestStatus.FINISHED_ABORTED)
else:
if request.resumable:
request.streaming_queue = deque()
self.waiting.add_request(request)
self.requests[request.request_id] = request
if self.log_stats:
request.record_event(EngineCoreEventType.QUEUED)
def finish_requests(
self, request_ids: str | Iterable[str], finished_status: RequestStatus
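
add_request is now also the producer side of the per-session queue: the first request for an id creates the session (with an empty streaming_queue when resumable), later requests with the same id are either applied immediately or queued as StreamingUpdates, and a non-resumable request for an existing id closes the session. Together with _handle_stopped_request above this is a plain deque protocol in which None is the finished sentinel; a stand-alone sketch with toy types (not the scheduler classes):

    from collections import deque

    # Toy model: producer = add_request(), consumer = _handle_stopped_request().
    session_queue: deque[str | None] = deque()

    def producer(update: str | None) -> None:
        # None means "input stream closed" (finished sentinel).
        session_queue.append(update)

    def on_chunk_stopped() -> bool:
        """Return True if the session is finished."""
        if session_queue:
            update = session_queue.popleft()
            if update is None:
                return True           # finished sentinel
            # apply update; session goes back to WAITING
            return False
        # nothing queued yet; session becomes WAITING_FOR_STREAMING_REQ
        return False

    producer("next chunk")
    producer(None)
    assert on_chunk_stopped() is False   # chunk applied, keep going
    assert on_chunk_stopped() is True    # sentinel reached, session done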
@@ -1569,6 +1665,8 @@ class Scheduler(SchedulerInterface):
if request.status == RequestStatus.RUNNING:
running_requests_to_remove.add(request)
else:
if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
self.num_waiting_for_streaming_input -= 1
waiting_requests_to_remove.append(request)
# Remove all requests from queues at once for better efficiency
@@ -1603,7 +1701,8 @@ class Scheduler(SchedulerInterface):
del self.requests[request.request_id]
def get_num_unfinished_requests(self) -> int:
return len(self.waiting) + len(self.running)
num_waiting = len(self.waiting) - self.num_waiting_for_streaming_input
return num_waiting + len(self.running)
def has_finished_requests(self) -> bool:
return len(self.finished_req_ids) > 0

View File

@@ -75,6 +75,7 @@ class EngineCoreRequest(
priority: int = 0
trace_headers: Mapping[str, str] | None = None
resumable: bool = False
# The user-provided request ID. This field is set internally,
# copied from the provided request_id that's originally assigned

View File

@@ -7,11 +7,13 @@ import time
import warnings
from collections.abc import AsyncGenerator, Iterable, Mapping
from copy import copy
from typing import Any, cast
from dataclasses import dataclass
from typing import Any
import torch
import vllm.envs as envs
from vllm import TokensPrompt
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
@@ -20,11 +22,11 @@ from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.renderers import RendererLike
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike
from vllm.tracing import init_tracer
@@ -38,6 +40,7 @@ from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm.v1.engine.input_processor import InputProcessor
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.utils import get_prompt_text
from vllm.v1.executor import Executor
from vllm.v1.metrics.loggers import (
StatLoggerFactory,
@@ -50,6 +53,30 @@ from vllm.v1.metrics.stats import IterationStats
logger = init_logger(__name__)
@dataclass
class StreamingInput:
"""Input data for a streaming generation request.
This is used with generate() to support multi-turn streaming sessions
where inputs are provided via an async generator.
"""
prompt: PromptType
sampling_params: SamplingParams | None = None
class InputStreamError(Exception):
"""Wrapper for errors from the input stream generator.
This is used to propagate errors from the user's input generator
without wrapping them in EngineGenerateError.
"""
def __init__(self, cause: Exception):
self.cause = cause
super().__init__(str(cause))
class AsyncLLM(EngineClient):
def __init__(
self,
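
StreamingInput plus the AsyncGenerator-aware generate()/add_request() below form the user-facing entry point for streaming sessions. A minimal usage sketch, assuming an already-constructed AsyncLLM instance (prompts and the request id are placeholders; the import path assumes this module is vllm/v1/engine/async_llm.py):

    from vllm import SamplingParams
    from vllm.sampling_params import RequestOutputKind
    from vllm.v1.engine.async_llm import AsyncLLM, StreamingInput

    async def input_chunks():
        # Each yielded chunk continues the same session; closing the
        # generator signals that the session's inputs are finished.
        yield StreamingInput(prompt="The capital of France is")
        yield StreamingInput(prompt=" The capital of Germany is")

    async def run(engine: AsyncLLM) -> None:
        params = SamplingParams(
            max_tokens=16,
            # FINAL_ONLY is rejected for streaming inputs (see the
            # validation below); so are n > 1 and stop strings.
            output_kind=RequestOutputKind.DELTA,
        )
        async for out in engine.generate(
            prompt=input_chunks(),
            sampling_params=params,
            request_id="stream-session-0",
        ):
            for completion in out.outputs:
                print(completion.text, end="", flush=True)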
@@ -261,7 +288,7 @@ class AsyncLLM(EngineClient):
async def add_request(
self,
request_id: str,
prompt: EngineCoreRequest | PromptType,
prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None],
params: SamplingParams | PoolingParams,
arrival_time: float | None = None,
lora_request: LoRARequest | None = None,
@@ -297,6 +324,20 @@ class AsyncLLM(EngineClient):
tokenization_kwargs,
)
if isinstance(prompt, AsyncGenerator):
# Streaming input case.
return await self._add_streaming_input_request(
request_id,
prompt,
params,
arrival_time,
lora_request,
tokenization_kwargs,
trace_headers,
priority,
data_parallel_rank,
)
# Convert Input --> Request.
if isinstance(prompt, EngineCoreRequest):
request = prompt
@@ -322,10 +363,7 @@ class AsyncLLM(EngineClient):
priority,
data_parallel_rank,
)
if isinstance(prompt, str):
prompt_text = prompt
elif isinstance(prompt, Mapping):
prompt_text = cast(str | None, prompt.get("prompt"))
prompt_text = get_prompt_text(prompt)
self.input_processor.assign_request_id(request)
@@ -380,6 +418,104 @@ class AsyncLLM(EngineClient):
if self.log_requests:
logger.info("Added request %s.", request.request_id)
async def _add_streaming_input_request(
self,
request_id: str,
input_stream: AsyncGenerator[StreamingInput, None],
sampling_params: SamplingParams | PoolingParams,
arrival_time: float | None = None,
lora_request: LoRARequest | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
trace_headers: Mapping[str, str] | None = None,
priority: int = 0,
data_parallel_rank: int | None = None,
) -> RequestOutputCollector:
self._validate_streaming_input_sampling_params(sampling_params)
inputs = dict(
arrival_time=arrival_time,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=priority,
data_parallel_rank=data_parallel_rank,
)
if not sampling_params.skip_clone:
sampling_params = sampling_params.clone()
sampling_params.skip_clone = True
# Create request for validation, also used as the finished signal
# once the input stream is closed.
final_req = self.input_processor.process_inputs(
request_id=request_id,
prompt=TokensPrompt(prompt_token_ids=[0]),
params=sampling_params,
**inputs, # type: ignore[arg-type]
)
self.input_processor.assign_request_id(final_req)
internal_req_id = final_req.request_id
queue = RequestOutputCollector(sampling_params.output_kind, internal_req_id)
async def handle_inputs():
cancelled = False
try:
async for input_chunk in input_stream:
sp = input_chunk.sampling_params
if sp:
self._validate_streaming_input_sampling_params(sp)
else:
sp = sampling_params
req = self.input_processor.process_inputs(
request_id=internal_req_id,
prompt=input_chunk.prompt,
params=sp,
resumable=True,
**inputs, # type: ignore[arg-type]
)
req.external_req_id = request_id
if req.prompt_embeds is not None:
raise ValueError(
"prompt_embeds not supported for streaming inputs"
)
prompt_text = get_prompt_text(input_chunk.prompt)
await self._add_request(req, prompt_text, None, 0, queue)
except (asyncio.CancelledError, GeneratorExit):
cancelled = True
except Exception as error:
# Wrap in InputStreamError so generate() can propagate it
# without wrapping in EngineGenerateError.
queue.put(InputStreamError(error))
finally:
queue._input_stream_task = None
if not cancelled:
# Send empty final request to indicate that inputs have
# finished. Don't send if cancelled (session was aborted).
await self._add_request(final_req, None, None, 0, queue)
# Ensure output handler is running.
self._run_output_handler()
queue._input_stream_task = asyncio.create_task(handle_inputs())
return queue
@staticmethod
def _validate_streaming_input_sampling_params(
params: SamplingParams | PoolingParams,
):
if (
not isinstance(params, SamplingParams)
or params.n > 1
or params.output_kind == RequestOutputKind.FINAL_ONLY
or params.stop
):
raise ValueError(
"Input streaming not currently supported "
"for pooling models, n > 1, request_kind = FINAL_ONLY "
"or with stop strings."
)
# TODO: we should support multiple prompts in one call, as you
# can do with LLM.generate. So that for multi-prompt completion
# requests we don't need to send multiple messages to core proc,
@@ -387,7 +523,7 @@ class AsyncLLM(EngineClient):
# re-multiplexed in the API server anyhow.
async def generate(
self,
prompt: EngineCoreRequest | PromptType,
prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None],
sampling_params: SamplingParams,
request_id: str,
*,
@@ -437,9 +573,10 @@ class AsyncLLM(EngineClient):
# Note: both OutputProcessor and EngineCore handle their
# own request cleanup based on finished.
finished = out.finished
assert isinstance(out, RequestOutput)
yield out
finished = out.finished
if out is not STREAM_FINISHED:
yield out
# If the request is disconnected by the client, generate()
# is cancelled or the generator is garbage collected. So,
@@ -463,6 +600,14 @@ class AsyncLLM(EngineClient):
logger.info("Request %s failed (bad request): %s.", request_id, e)
raise
# Error from input stream generator - propagate directly.
except InputStreamError as e:
if q is not None:
await self.abort(q.request_id, internal=True)
if self.log_requests:
logger.info("Request %s failed (input error): %s.", request_id, e)
raise e.cause from e
# Unexpected error in the generate() task (possibly recoverable).
except Exception as e:
if q is not None:
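
Since errors raised by the user's input generator are wrapped in InputStreamError and unwrapped here, the caller sees the original exception instead of an EngineGenerateError. A short sketch under the same assumptions as the earlier usage example:

    async def bad_inputs():
        yield StreamingInput(prompt="first chunk")
        raise ValueError("upstream input source failed")

    async def consume(engine: AsyncLLM) -> None:
        params = SamplingParams(max_tokens=8, output_kind=RequestOutputKind.DELTA)
        try:
            async for _ in engine.generate(
                prompt=bad_inputs(), sampling_params=params, request_id="s1"
            ):
                pass
        except ValueError as err:
            # The original error is re-raised (`raise e.cause from e`) after
            # the session has been aborted internally.
            print("input stream failed:", err)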
@@ -478,6 +623,9 @@ class AsyncLLM(EngineClient):
)
logger.info("Request %s failed due to %s.", request_id, s)
raise EngineGenerateError() from e
finally:
if q is not None:
q.close()
def _run_output_handler(self):
"""Background loop: pulls from EngineCore and pushes to AsyncStreams."""
@@ -703,6 +851,9 @@ class AsyncLLM(EngineClient):
if self.log_requests:
logger.info("Request %s failed.", request_id)
raise EngineGenerateError() from e
finally:
if q is not None:
q.close()
@property
def tokenizer(self) -> TokenizerLike | None:

View File

@@ -459,6 +459,7 @@ class InputProcessor:
trace_headers: Mapping[str, str] | None = None,
priority: int = 0,
data_parallel_rank: int | None = None,
resumable: bool = False,
) -> EngineCoreRequest:
self._validate_lora(lora_request)
self._validate_params(params)
@@ -603,6 +604,7 @@ class InputProcessor:
priority=priority,
data_parallel_rank=data_parallel_rank,
trace_headers=trace_headers,
resumable=resumable,
)
def _validate_model_inputs(

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from collections import defaultdict
from collections import defaultdict, deque
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, cast
@@ -12,6 +12,7 @@ import torch
from vllm.lora.request import LoRARequest
from vllm.outputs import (
STREAM_FINISHED,
CompletionOutput,
PoolingOutput,
PoolingRequestOutput,
@@ -51,6 +52,8 @@ class RequestOutputCollector:
self.output: RequestOutput | PoolingRequestOutput | Exception | None = None
self.ready = asyncio.Event()
self._input_stream_task: asyncio.Task | None = None
def put(self, output: RequestOutput | PoolingRequestOutput | Exception) -> None:
"""Non-blocking put operation."""
if self.output is None or isinstance(output, Exception):
@@ -87,6 +90,16 @@ class RequestOutputCollector:
raise output
return output
def close(self):
if self._input_stream_task is not None:
self._input_stream_task.cancel()
self._input_stream_task = None
def __del__(self):
if (task := self._input_stream_task) is not None:
task.get_loop().call_soon_threadsafe(task.cancel)
self._input_stream_task = None
@dataclass
class OutputProcessorOutput:
@@ -94,6 +107,20 @@ class OutputProcessorOutput:
reqs_to_abort: list[str]
@dataclass
class StreamingUpdate:
"""Streaming input update data for output processor.
Contains the incremental prompt data to be applied to a request state
when the current sub-request completes.
"""
prompt: str | None
prompt_token_ids: list[int] | None
arrival_time: float
final: bool = False
class RequestState:
def __init__(
self,
@@ -116,6 +143,7 @@ class RequestState:
top_p: float | None = None,
n: int | None = None,
temperature: float | None = None,
stream_input: bool = False,
):
self.request_id = request_id
self.external_req_id = external_req_id
@@ -146,6 +174,31 @@ class RequestState:
self.stream_interval = stream_interval
self.sent_tokens_offset = 0 # Offset of sent tokens
# Streaming input queue
self.streaming_input = stream_input
self.input_chunk_queue: deque[StreamingUpdate] | None = (
deque() if stream_input else None
)
def apply_streaming_update(self, update: StreamingUpdate) -> None:
# Apply the update to the request state.
self.streaming_input = not update.final
# TODO also include relevant output tokens in new prompt here
# (match scheduler behavior).
if update.prompt:
self.prompt = (
(self.prompt + update.prompt) if self.prompt else update.prompt
)
if self.prompt_token_ids:
self.prompt_token_ids.extend(update.prompt_token_ids or ())
else:
self.prompt_token_ids = update.prompt_token_ids or []
assert self.prompt_token_ids is not None
self.prompt_len = len(self.prompt_token_ids)
if self.stats is not None:
self.stats.arrival_time = update.arrival_time
self.is_prefilling = True
@classmethod
def from_new_request(
cls,
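
On the output-processor side, apply_streaming_update merges the new chunk into detokenizer-facing state: text is concatenated onto the prompt, token ids are appended, and the request flips back to prefilling. A toy mirror of that merge on plain fields (illustration only, not the RequestState class; the token ids are made up):

    prompt: str | None = "Hello"
    prompt_token_ids: list[int] | None = [101]       # made-up ids

    chunk_text, chunk_token_ids = " world", [102]    # next streaming chunk

    prompt = (prompt + chunk_text) if prompt else chunk_text
    if prompt_token_ids:
        prompt_token_ids.extend(chunk_token_ids)
    else:
        prompt_token_ids = chunk_token_ids
    prompt_len = len(prompt_token_ids)
    is_prefilling = True    # next engine output is treated as a fresh prefill

    assert prompt == "Hello world" and prompt_len == 2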
@@ -205,6 +258,7 @@ class RequestState:
queue=queue,
log_stats=log_stats,
stream_interval=stream_interval,
stream_input=request.resumable,
)
def make_request_output(
@@ -405,7 +459,6 @@ class OutputProcessor:
a parent request, in which case the associated child requests are aborted
also.
"""
internal_req_ids = []
for request_id in request_ids:
if internal:
@@ -464,8 +517,10 @@ class OutputProcessor:
queue: RequestOutputCollector | None = None,
) -> None:
request_id = request.request_id
if request_id in self.request_states:
raise ValueError(f"Request id {request_id} already running.")
req_state = self.request_states.get(request_id)
if req_state is not None:
self._update_streaming_request_state(req_state, request, prompt)
return
req_state = RequestState.from_new_request(
tokenizer=self.tokenizer,
@@ -486,6 +541,39 @@ class OutputProcessor:
# Track the external_req_id -> [internal_req_id, ...] mapping
self.external_req_ids[req_state.external_req_id].append(request_id)
def _update_streaming_request_state(
self, req_state: RequestState, request: EngineCoreRequest, prompt: str | None
) -> None:
"""Queue a streaming update instead of immediately applying it."""
if not request.resumable:
# Final request - just mark completion, don't add its dummy tokens.
if req_state.input_chunk_queue is None:
# Engine already finished - emit final output and clean up.
self._finish_request(req_state)
if req_state.queue is not None:
# Emit a final output with finished=True
# to unblock the generate() loop.
req_state.queue.put(STREAM_FINISHED)
elif req_state.input_chunk_queue:
req_state.input_chunk_queue[-1].final = True
else:
req_state.streaming_input = False
return
update = StreamingUpdate(
prompt=prompt,
prompt_token_ids=request.prompt_token_ids,
arrival_time=request.arrival_time,
)
# Apply request updates now if the last input already completed.
if req_state.input_chunk_queue is None:
req_state.apply_streaming_update(update)
req_state.input_chunk_queue = deque()
else:
# Queue the streaming update otherwise.
req_state.input_chunk_queue.append(update)
def process_outputs(
self,
engine_core_outputs: list[EngineCoreOutput],
@@ -561,6 +649,9 @@ class OutputProcessor:
kv_transfer_params,
routed_experts,
):
if req_state.streaming_input:
request_output.finished = False
if req_state.queue is not None:
# AsyncLLM: put into queue for handling by generate().
req_state.queue.put(request_output)
@@ -570,36 +661,48 @@ class OutputProcessor:
# Free completed requests.
if finish_reason is not None:
self.request_states.pop(req_id)
if req_state.streaming_input:
if req_state.input_chunk_queue:
update = req_state.input_chunk_queue.popleft()
req_state.apply_streaming_update(update)
else:
req_state.input_chunk_queue = None
else:
self._finish_request(req_state)
if not engine_core_output.finished:
# If req not finished in EngineCore, but Detokenizer
# detected stop string, abort needed in EngineCore.
reqs_to_abort.append(req_id)
internal_ids = self.external_req_ids[req_state.external_req_id]
internal_ids.remove(req_id)
if not internal_ids:
del self.external_req_ids[req_state.external_req_id]
# Remove parent request if applicable.
parent_req = req_state.parent_req
if parent_req and not parent_req.child_requests:
self.parent_requests.pop(parent_req.request_id, None)
if not self.request_states:
self._requests_drained.set()
if not engine_core_output.finished:
# If req not finished in EngineCore, but Detokenizer
# detected stop string, abort needed in EngineCore.
reqs_to_abort.append(req_id)
# Track per-request stats
self._update_stats_from_finished(
req_state, finish_reason, iteration_stats
)
if self.tracer:
self.do_tracing(engine_core_output, req_state, iteration_stats)
# Track per-request stats
self._update_stats_from_finished(
req_state, finish_reason, iteration_stats
)
if self.tracer:
self.do_tracing(engine_core_output, req_state, iteration_stats)
return OutputProcessorOutput(
request_outputs=request_outputs,
reqs_to_abort=reqs_to_abort,
)
def _finish_request(self, req_state: RequestState) -> None:
req_id = req_state.request_id
self.request_states.pop(req_id)
internal_ids = self.external_req_ids[req_state.external_req_id]
internal_ids.remove(req_id)
if not internal_ids:
del self.external_req_ids[req_state.external_req_id]
# Remove parent request if applicable.
parent_req = req_state.parent_req
if parent_req and not parent_req.child_requests:
self.parent_requests.pop(parent_req.request_id, None)
if not self.request_states:
self._requests_drained.set()
def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
self.lora_states.update_scheduler_stats(scheduler_stats)

View File

@@ -4,12 +4,12 @@
import contextlib
import os
import weakref
from collections.abc import Callable, Iterator
from collections.abc import Callable, Iterator, Mapping
from dataclasses import dataclass
from enum import Enum, auto
from multiprocessing import Process, connection
from multiprocessing.process import BaseProcess
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, cast
from unittest.mock import patch
import msgspec
@@ -224,6 +224,14 @@ def get_device_indices(
return value
def get_prompt_text(prompt: Any) -> str | None:
if isinstance(prompt, str):
return prompt
if isinstance(prompt, Mapping):
return cast(str | None, prompt.get("prompt"))
return None
class CoreEngineActorManager:
"""
Utility class to handle creation, readiness, and shutdown

View File

@@ -3,7 +3,9 @@
import enum
import time
from collections import deque
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Any, Optional
@@ -27,6 +29,33 @@ if TYPE_CHECKING:
from vllm.v1.core.kv_cache_utils import BlockHash
@dataclass
class StreamingUpdate:
"""Lightweight data for streaming session continuation.
Contains only the fields needed to update an existing streaming session
with new input data.
"""
mm_features: list[MultiModalFeatureSpec] | None
prompt_token_ids: list[int] | None
max_tokens: int
arrival_time: float
sampling_params: SamplingParams | None
@classmethod
def from_request(cls, request: "Request") -> "StreamingUpdate | None":
if not request.resumable:
return None
return cls(
mm_features=request.mm_features,
prompt_token_ids=request.prompt_token_ids,
max_tokens=request.max_tokens,
arrival_time=request.arrival_time,
sampling_params=request.sampling_params,
)
class Request:
def __init__(
self,
@@ -44,6 +73,7 @@ class Request:
priority: int = 0,
trace_headers: Mapping[str, str] | None = None,
block_hasher: Callable[["Request"], list["BlockHash"]] | None = None,
resumable: bool = False,
) -> None:
self.request_id = request_id
self.client_index = client_index
@@ -105,8 +135,6 @@ class Request:
# Multi-modal related
self.mm_features = mm_features or []
self.num_encoder_inputs = len(self.mm_features)
self.has_encoder_inputs = self.num_encoder_inputs > 0
# Read-only views
# Prevent directly appending to these lists since
@@ -137,6 +165,11 @@ class Request:
self.skip_reading_prefix_cache = self.get_skip_reading_prefix_cache()
# Used for streaming
self.resumable = resumable
# None entry in the queue means finished.
self.streaming_queue: deque[StreamingUpdate | None] | None = None
@classmethod
def from_engine_core_request(
cls,
@@ -158,6 +191,7 @@ class Request:
priority=request.priority,
trace_headers=request.trace_headers,
block_hasher=block_hasher,
resumable=request.resumable,
)
def append_output_token_ids(
@@ -190,6 +224,14 @@ class Request:
def num_output_tokens(self) -> int:
return len(self._output_token_ids)
@property
def num_encoder_inputs(self) -> int:
return len(self.mm_features)
@property
def has_encoder_inputs(self) -> bool:
return self.num_encoder_inputs > 0
def get_skip_reading_prefix_cache(self) -> bool:
if (
self.sampling_params is not None
@@ -246,6 +288,7 @@ class RequestStatus(enum.IntEnum):
WAITING = enum.auto()
WAITING_FOR_FSM = enum.auto()
WAITING_FOR_REMOTE_KVS = enum.auto()
WAITING_FOR_STREAMING_REQ = enum.auto()
RUNNING = enum.auto()
PREEMPTED = enum.auto()
# Note: anything after PREEMPTED will be considered
@@ -256,7 +299,7 @@ class RequestStatus(enum.IntEnum):
FINISHED_IGNORED = enum.auto()
FINISHED_ERROR = enum.auto()
def __str__(self):
def __str__(self) -> str:
return self.name
@staticmethod
@@ -278,4 +321,5 @@ _FINISHED_REASON_MAP = {
RequestStatus.FINISHED_ABORTED: FinishReason.ABORT,
RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH,
RequestStatus.FINISHED_ERROR: FinishReason.ERROR,
RequestStatus.WAITING_FOR_STREAMING_REQ: FinishReason.STOP,
}
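
The new status and its entry in _FINISHED_REASON_MAP define the session lifecycle as the scheduler sees it (a sketch inferred from the scheduler changes above):

    WAITING -> RUNNING -> chunk stops
        -> WAITING                     next StreamingUpdate already queued
        -> WAITING_FOR_STREAMING_REQ   waiting on the input stream (excluded
                                       from the unfinished-request count)
        -> FINISHED_*                  once the input stream is closed

Mapping WAITING_FOR_STREAMING_REQ to FinishReason.STOP means that, if a finish reason is looked up while a session is waiting for its next input, it reads as an ordinary stop.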

View File

@@ -112,6 +112,7 @@ from vllm.v1.attention.backends.utils import (
get_dcp_local_seq_lens,
reorder_batch_to_split_decodes_and_prefills,
)
from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import (
AttentionSpec,
@@ -903,6 +904,12 @@ class GPUModelRunner(
# Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
if req_id in self.requests:
# For streaming case only.
req_state = self._update_streaming_request(req_id, new_req_data)
reqs_to_add.append(req_state)
continue
sampling_params = new_req_data.sampling_params
pooling_params = new_req_data.pooling_params
@@ -1133,6 +1140,40 @@ class GPUModelRunner(
self.model.get_mamba_state_copy_func(),
)
def _update_streaming_request(
self, req_id: str, new_req_data: NewRequestData
) -> CachedRequestState:
"""Updates streaming session request from `scheduled_new_reqs`.
Removes the request from InputBatch (if present), updates the cached
state, and prepares it for re-addition to the batch.
NOTE: prompt_token_ids includes intermediate output tokens - tokens that
were previously generated but are now input context (part of the prompt).
"""
self.input_batch.remove_request(req_id)
req_state = self.requests[req_id]
req_state.prompt_token_ids = new_req_data.prompt_token_ids
req_state.mm_features = new_req_data.mm_features
req_state.prompt_embeds = new_req_data.prompt_embeds
req_state.sampling_params = new_req_data.sampling_params
req_state.pooling_params = new_req_data.pooling_params
req_state.block_ids = new_req_data.block_ids
req_state.num_computed_tokens = new_req_data.num_computed_tokens
req_state.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
req_state.prompt_token_ids, req_state.prompt_embeds
)
# Clear `output_token_ids` as previous output tokens are now part of
# `prompt_token_ids`.
req_state.output_token_ids.clear()
if self.uses_mrope:
self._init_mrope_positions(req_state)
return req_state
def _init_mrope_positions(self, req_state: CachedRequestState):
model = self.get_model()
assert supports_mrope(model), "M-RoPE support is not implemented."