[Feature] add session based streaming input support to v1 (#28973)
Signed-off-by: Joshua Deng <joshuakdeng@gmail.com>
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
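Before the diff, a minimal usage sketch of what this commit enables (hedged: the module path, engine construction, and prompts below are illustrative assumptions, not taken from the diff). Inputs are supplied as an async generator of StreamingInput chunks; every chunk continues the same session under one request_id, and exhausting the generator ends the session.

import asyncio

from vllm import SamplingParams
from vllm.v1.engine.async_llm import AsyncLLM, StreamingInput  # path assumed


async def chunks():
    # Each yielded chunk extends the same session; closing the generator
    # signals that the streaming input is finished.
    yield StreamingInput(prompt="The capital of France is")
    yield StreamingInput(prompt=" and the capital of Italy is")


async def consume(llm: AsyncLLM):
    # Per the validation added below: n == 1, no stop strings, and an
    # output_kind other than FINAL_ONLY.
    params = SamplingParams(max_tokens=32)
    async for out in llm.generate(chunks(), params, request_id="session-0"):
        print(out.outputs[0].text)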
@@ -192,6 +192,16 @@ class RequestOutput:
     )


+# Sentinel to indicate request is finished, used with streaming inputs.
+STREAM_FINISHED = RequestOutput(
+    request_id="",
+    prompt=None,
+    prompt_token_ids=None,
+    prompt_logprobs=None,
+    outputs=[],
+    finished=True,
+)
+
 _O = TypeVar("_O", default=PoolingOutput)

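A note on the sentinel above (sketch, not part of the diff): STREAM_FINISHED is a shared module-level instance, so consumers distinguish it by identity rather than by the finished flag, which ordinary final outputs also set.

from vllm.outputs import STREAM_FINISHED, RequestOutput


def is_stream_sentinel(out: RequestOutput) -> bool:
    # Identity, not equality: a real final output also has finished=True,
    # but only the sentinel is this exact object.
    return out is STREAM_FINISHED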
@@ -2,8 +2,9 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import itertools
 import time
-from collections import defaultdict
+from collections import defaultdict, deque
 from collections.abc import Iterable
+from dataclasses import replace
 from typing import Any

 import numpy as np
@@ -49,12 +50,9 @@ from vllm.v1.core.sched.utils import check_stop, remove_all
 from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
 from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
 from vllm.v1.metrics.perf import ModelMetrics, PerfStats
-from vllm.v1.metrics.stats import (
-    PrefixCacheStats,
-    SchedulerStats,
-)
+from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
 from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
-from vllm.v1.request import Request, RequestStatus
+from vllm.v1.request import Request, RequestStatus, StreamingUpdate
 from vllm.v1.spec_decode.metrics import SpecDecodingStats
 from vllm.v1.structured_output import StructuredOutputManager
 from vllm.v1.utils import record_function_or_nullcontext
@@ -166,6 +164,10 @@ class Scheduler(SchedulerInterface):
         # This is flushed at the end of each scheduling step.
         self.finished_req_ids: set[str] = set()

+        # Counter for requests waiting for streaming input. Used to calculate
+        # the number of unfinished requests.
+        self.num_waiting_for_streaming_input: int = 0
+
         # KV Connector: requests in process of async KV loading or recving
         self.finished_recving_kv_req_ids: set[str] = set()
         self.failed_recving_kv_req_ids: set[str] = set()
@@ -569,6 +571,13 @@ class Scheduler(SchedulerInterface):
                 skipped_waiting_requests.prepend_request(request)
                 continue

+            # Streaming: skip request if still waiting for next streaming req.
+            if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
+                assert not request.streaming_queue
+                self.waiting.pop_request()
+                skipped_waiting_requests.prepend_request(request)
+                continue
+
             # Check that adding the request still respects the max_loras
             # constraint.
             if (
@@ -929,6 +938,51 @@ class Scheduler(SchedulerInterface):
         # it will also affect the scheduler output.
         self.finished_req_ids = set()

+    def _update_request_as_session(
+        self, session: Request, update: StreamingUpdate
+    ) -> None:
+        """
+        Updates the waiting session with the next streaming update.
+
+        Discards the last sampled output token from the prior input chunk.
+        """
+
+        # Current streaming input behaviour: Keep only computed output tokens
+        # (discard final sampled output token).
+        num_computed_tokens = session.num_computed_tokens
+        kept_output_tokens = session._all_token_ids[
+            session.num_prompt_tokens : num_computed_tokens
+        ]
+        del session._all_token_ids[num_computed_tokens:]
+        session._output_token_ids.clear()
+        assert session.prompt_token_ids is not None
+        # Extend prompt with kept output tokens.
+        session.prompt_token_ids.extend(kept_output_tokens)
+
+        if update.mm_features:
+            base = session.num_tokens
+            for mm_feature in update.mm_features:
+                mm_feature.mm_position = replace(
+                    mm_feature.mm_position, offset=mm_feature.mm_position.offset + base
+                )
+            session.mm_features.extend(update.mm_features)
+
+        session._all_token_ids.extend(update.prompt_token_ids or ())
+        session.prompt_token_ids.extend(update.prompt_token_ids or ())
+        # Update block hashes for the new tokens
+        # (mirrors Request.append_output_token_ids)
+        if session.get_hash_new_full_blocks is not None:
+            session.block_hashes.extend(session.get_hash_new_full_blocks())
+        session.num_prompt_tokens = len(session.prompt_token_ids)
+        session.arrival_time = update.arrival_time
+        session.sampling_params = update.sampling_params
+        if session.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
+            self.num_waiting_for_streaming_input -= 1
+        session.status = RequestStatus.WAITING
+
+        if self.log_stats:
+            session.record_event(EngineCoreEventType.QUEUED)
+
     def _make_cached_request_data(
         self,
         running_reqs: list[Request],
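A toy illustration (hypothetical numbers, plain Python) of the mm_position rebasing in _update_request_as_session above: a new chunk's multimodal offsets are relative to the chunk, so they are shifted by the session's current token count before the features are appended.

base = 7                   # session.num_tokens before the new chunk arrives
chunk_offsets = [0, 3]     # mm_position offsets within the incoming chunk
session_offsets = [offset + base for offset in chunk_offsets]
assert session_offsets == [7, 10]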
@@ -1271,9 +1325,17 @@ class Scheduler(SchedulerInterface):
                 stopped = True

         routed_experts = None
+        finish_reason = None
         if stopped:
             routed_experts = self._get_routed_experts(request)
-            kv_transfer_params = self._free_request(request)
+
+            # Capture finish_reason BEFORE _handle_stopped_request, which may
+            # reset the status to WAITING for streaming requests that continue.
+            finish_reason = request.get_finished_reason()
+            finished = self._handle_stopped_request(request)
+            if finished:
+                kv_transfer_params = self._free_request(request)
+
         if status_before_stop == RequestStatus.RUNNING:
             stopped_running_reqs.add(request)
         else:
@@ -1315,7 +1377,7 @@ class Scheduler(SchedulerInterface):
                 EngineCoreOutput(
                     request_id=req_id,
                     new_token_ids=new_token_ids,
-                    finish_reason=request.get_finished_reason(),
+                    finish_reason=finish_reason,
                     new_logprobs=new_logprobs,
                     new_prompt_logprobs_tensors=prompt_logprobs_tensors,
                     pooling_output=pooler_output,
@@ -1410,6 +1472,24 @@ class Scheduler(SchedulerInterface):

         return engine_core_outputs

+    def _handle_stopped_request(self, request: Request) -> bool:
+        """Return True if finished (can be False for resumable requests)."""
+        if not request.resumable:
+            return True
+
+        if request.streaming_queue:
+            update = request.streaming_queue.popleft()
+            if update is None:
+                # Streaming request finished.
+                return True
+            self._update_request_as_session(request, update)
+        else:
+            request.status = RequestStatus.WAITING_FOR_STREAMING_REQ
+            self.num_waiting_for_streaming_input += 1
+
+        self.waiting.add_request(request)
+        return False
+
     def _get_routed_experts(self, request: Request) -> np.ndarray | None:
         if not self.vllm_config.model_config.enable_return_routed_experts:
             return None
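The pause-or-continue decision in _handle_stopped_request, restated as a self-contained sketch (simplified types, not vLLM code): a resumable request either consumes a queued update (where None is the finished sentinel) or parks in WAITING_FOR_STREAMING_REQ until the next chunk arrives.

from collections import deque


def on_stopped(resumable: bool, queue: deque | None) -> str:
    if not resumable:
        return "finished"
    if queue:
        update = queue.popleft()
        return "finished" if update is None else "continue-with-update"
    return "wait-for-next-chunk"


assert on_stopped(False, None) == "finished"
assert on_stopped(True, deque([None])) == "finished"
assert on_stopped(True, deque(["chunk"])) == "continue-with-update"
assert on_stopped(True, deque()) == "wait-for-next-chunk"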
@@ -1535,10 +1615,26 @@ class Scheduler(SchedulerInterface):
         return len(self.running), len(self.waiting)

     def add_request(self, request: Request) -> None:
-        self.waiting.add_request(request)
-        self.requests[request.request_id] = request
-        if self.log_stats:
-            request.record_event(EngineCoreEventType.QUEUED)
+        existing = self.requests.get(request.request_id)
+        if existing is not None:
+            update = StreamingUpdate.from_request(request)
+            if existing.status != RequestStatus.WAITING_FOR_STREAMING_REQ:
+                assert existing.streaming_queue is not None, "duplicate request id"
+                # Queue next input chunk (or finished sentinel).
+                existing.streaming_queue.append(update)
+            elif update is not None:
+                # Commence next input chunk.
+                self._update_request_as_session(existing, update)
+            else:
+                # Streaming-input session finished.
+                self.finish_requests(request.request_id, RequestStatus.FINISHED_ABORTED)
+        else:
+            if request.resumable:
+                request.streaming_queue = deque()
+            self.waiting.add_request(request)
+            self.requests[request.request_id] = request
+            if self.log_stats:
+                request.record_event(EngineCoreEventType.QUEUED)

     def finish_requests(
         self, request_ids: str | Iterable[str], finished_status: RequestStatus
@@ -1569,6 +1665,8 @@ class Scheduler(SchedulerInterface):
             if request.status == RequestStatus.RUNNING:
                 running_requests_to_remove.add(request)
             else:
+                if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
+                    self.num_waiting_for_streaming_input -= 1
                 waiting_requests_to_remove.append(request)

         # Remove all requests from queues at once for better efficiency
@@ -1603,7 +1701,8 @@ class Scheduler(SchedulerInterface):
             del self.requests[request.request_id]

     def get_num_unfinished_requests(self) -> int:
-        return len(self.waiting) + len(self.running)
+        num_waiting = len(self.waiting) - self.num_waiting_for_streaming_input
+        return num_waiting + len(self.running)

     def has_finished_requests(self) -> bool:
         return len(self.finished_req_ids) > 0
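A worked example (plain lists, not vLLM types) of the token bookkeeping that _update_request_as_session performs when a session resumes: computed output tokens are folded into the prompt, the final sampled-but-uncomputed token is dropped, and the next chunk's tokens are appended.

def merge_chunk(prompt, outputs, num_computed, next_chunk):
    all_tokens = prompt + outputs
    # Keep only output tokens whose KV entries were computed; the token at
    # index num_computed was sampled but has no KV yet, so it is discarded.
    kept = all_tokens[len(prompt):num_computed]
    return prompt + kept + next_chunk


# 3 prompt tokens, outputs [10, 11, 12] with 5 tokens computed (so token 12
# was sampled but not computed), and a new chunk [20, 21]:
assert merge_chunk([1, 2, 3], [10, 11, 12], 5, [20, 21]) == [1, 2, 3, 10, 11, 20, 21]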
@@ -75,6 +75,7 @@ class EngineCoreRequest(
     priority: int = 0

     trace_headers: Mapping[str, str] | None = None
+    resumable: bool = False

     # The user-provided request ID. This field is set internally,
     # copied from the provided request_id that's originally assigned
@@ -7,11 +7,13 @@ import time
 import warnings
 from collections.abc import AsyncGenerator, Iterable, Mapping
-from copy import copy
-from typing import Any, cast
+from dataclasses import dataclass
+from typing import Any

 import torch

 import vllm.envs as envs
+from vllm import TokensPrompt
 from vllm.config import VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.protocol import EngineClient
@@ -20,11 +22,11 @@ from vllm.inputs import PromptType
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
-from vllm.outputs import PoolingRequestOutput, RequestOutput
+from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
 from vllm.plugins.io_processors import get_io_processor
 from vllm.pooling_params import PoolingParams
 from vllm.renderers import RendererLike
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.tasks import SupportedTask
 from vllm.tokenizers import TokenizerLike
 from vllm.tracing import init_tracer
@@ -38,6 +40,7 @@ from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
 from vllm.v1.engine.input_processor import InputProcessor
 from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
 from vllm.v1.engine.parallel_sampling import ParentRequest
+from vllm.v1.engine.utils import get_prompt_text
 from vllm.v1.executor import Executor
 from vllm.v1.metrics.loggers import (
     StatLoggerFactory,
@@ -50,6 +53,30 @@ from vllm.v1.metrics.stats import IterationStats

 logger = init_logger(__name__)


+@dataclass
+class StreamingInput:
+    """Input data for a streaming generation request.
+
+    This is used with generate() to support multi-turn streaming sessions
+    where inputs are provided via an async generator.
+    """
+
+    prompt: PromptType
+    sampling_params: SamplingParams | None = None
+
+
+class InputStreamError(Exception):
+    """Wrapper for errors from the input stream generator.
+
+    This is used to propagate errors from the user's input generator
+    without wrapping them in EngineGenerateError.
+    """
+
+    def __init__(self, cause: Exception):
+        self.cause = cause
+        super().__init__(str(cause))
+
+
 class AsyncLLM(EngineClient):
     def __init__(
         self,
@@ -261,7 +288,7 @@ class AsyncLLM(EngineClient):
     async def add_request(
         self,
         request_id: str,
-        prompt: EngineCoreRequest | PromptType,
+        prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None],
        params: SamplingParams | PoolingParams,
         arrival_time: float | None = None,
         lora_request: LoRARequest | None = None,
@@ -297,6 +324,20 @@ class AsyncLLM(EngineClient):
             tokenization_kwargs,
         )

+        if isinstance(prompt, AsyncGenerator):
+            # Streaming input case.
+            return await self._add_streaming_input_request(
+                request_id,
+                prompt,
+                params,
+                arrival_time,
+                lora_request,
+                tokenization_kwargs,
+                trace_headers,
+                priority,
+                data_parallel_rank,
+            )
+
         # Convert Input --> Request.
         if isinstance(prompt, EngineCoreRequest):
             request = prompt
@@ -322,10 +363,7 @@ class AsyncLLM(EngineClient):
             priority,
             data_parallel_rank,
         )
-        if isinstance(prompt, str):
-            prompt_text = prompt
-        elif isinstance(prompt, Mapping):
-            prompt_text = cast(str | None, prompt.get("prompt"))
+        prompt_text = get_prompt_text(prompt)

         self.input_processor.assign_request_id(request)
@@ -380,6 +418,104 @@ class AsyncLLM(EngineClient):
         if self.log_requests:
             logger.info("Added request %s.", request.request_id)

+    async def _add_streaming_input_request(
+        self,
+        request_id: str,
+        input_stream: AsyncGenerator[StreamingInput, None],
+        sampling_params: SamplingParams | PoolingParams,
+        arrival_time: float | None = None,
+        lora_request: LoRARequest | None = None,
+        tokenization_kwargs: dict[str, Any] | None = None,
+        trace_headers: Mapping[str, str] | None = None,
+        priority: int = 0,
+        data_parallel_rank: int | None = None,
+    ) -> RequestOutputCollector:
+        self._validate_streaming_input_sampling_params(sampling_params)
+
+        inputs = dict(
+            arrival_time=arrival_time,
+            lora_request=lora_request,
+            tokenization_kwargs=tokenization_kwargs,
+            trace_headers=trace_headers,
+            priority=priority,
+            data_parallel_rank=data_parallel_rank,
+        )
+
+        if not sampling_params.skip_clone:
+            sampling_params = sampling_params.clone()
+            sampling_params.skip_clone = True
+
+        # Create request for validation, also used as the finished signal
+        # once the input stream is closed.
+        final_req = self.input_processor.process_inputs(
+            request_id=request_id,
+            prompt=TokensPrompt(prompt_token_ids=[0]),
+            params=sampling_params,
+            **inputs,  # type: ignore[arg-type]
+        )
+        self.input_processor.assign_request_id(final_req)
+        internal_req_id = final_req.request_id
+
+        queue = RequestOutputCollector(sampling_params.output_kind, internal_req_id)
+
+        async def handle_inputs():
+            cancelled = False
+            try:
+                async for input_chunk in input_stream:
+                    sp = input_chunk.sampling_params
+                    if sp:
+                        self._validate_streaming_input_sampling_params(sp)
+                    else:
+                        sp = sampling_params
+                    req = self.input_processor.process_inputs(
+                        request_id=internal_req_id,
+                        prompt=input_chunk.prompt,
+                        params=sp,
+                        resumable=True,
+                        **inputs,  # type: ignore[arg-type]
+                    )
+                    req.external_req_id = request_id
+                    if req.prompt_embeds is not None:
+                        raise ValueError(
+                            "prompt_embeds not supported for streaming inputs"
+                        )
+                    prompt_text = get_prompt_text(input_chunk.prompt)
+                    await self._add_request(req, prompt_text, None, 0, queue)
+            except (asyncio.CancelledError, GeneratorExit):
+                cancelled = True
+            except Exception as error:
+                # Wrap in InputStreamError so generate() can propagate it
+                # without wrapping in EngineGenerateError.
+                queue.put(InputStreamError(error))
+            finally:
+                queue._input_stream_task = None
+                if not cancelled:
+                    # Send empty final request to indicate that inputs have
+                    # finished. Don't send if cancelled (session was aborted).
+                    await self._add_request(final_req, None, None, 0, queue)
+
+        # Ensure output handler is running.
+        self._run_output_handler()
+
+        queue._input_stream_task = asyncio.create_task(handle_inputs())
+        return queue
+
+    @staticmethod
+    def _validate_streaming_input_sampling_params(
+        params: SamplingParams | PoolingParams,
+    ):
+        if (
+            not isinstance(params, SamplingParams)
+            or params.n > 1
+            or params.output_kind == RequestOutputKind.FINAL_ONLY
+            or params.stop
+        ):
+            raise ValueError(
+                "Input streaming not currently supported "
+                "for pooling models, n > 1, output_kind = FINAL_ONLY "
+                "or with stop strings."
+            )
+
     # TODO: we should support multiple prompts in one call, as you
     # can do with LLM.generate. So that for multi-prompt completion
     # requests we don't need to send multiple messages to core proc,
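Per-chunk parameter overrides, sketched (module path assumed, as above): each StreamingInput may carry its own SamplingParams, which handle_inputs validates per chunk; chunks without params fall back to the session-level ones.

from vllm import SamplingParams
from vllm.v1.engine.async_llm import StreamingInput  # path assumed


async def inputs():
    # This chunk overrides the session-level params for its turn...
    yield StreamingInput(
        prompt="Think step by step:",
        sampling_params=SamplingParams(max_tokens=64),
    )
    # ...and this one falls back to the params passed to generate().
    yield StreamingInput(prompt="Now answer in one sentence.")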
@@ -387,7 +523,7 @@ class AsyncLLM(EngineClient):
     # re-multiplexed in the API server anyhow.
     async def generate(
         self,
-        prompt: EngineCoreRequest | PromptType,
+        prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None],
         sampling_params: SamplingParams,
         request_id: str,
         *,
@@ -437,9 +573,10 @@ class AsyncLLM(EngineClient):

                 # Note: both OutputProcessor and EngineCore handle their
                 # own request cleanup based on finished.
-                finished = out.finished
-                assert isinstance(out, RequestOutput)
-                yield out
+                finished = out.finished
+                if out is not STREAM_FINISHED:
+                    yield out

                 # If the request is disconnected by the client, generate()
                 # is cancelled or the generator is garbage collected. So,
@@ -463,6 +600,14 @@ class AsyncLLM(EngineClient):
                 logger.info("Request %s failed (bad request): %s.", request_id, e)
             raise

+        # Error from input stream generator - propagate directly.
+        except InputStreamError as e:
+            if q is not None:
+                await self.abort(q.request_id, internal=True)
+            if self.log_requests:
+                logger.info("Request %s failed (input error): %s.", request_id, e)
+            raise e.cause from e
+
         # Unexpected error in the generate() task (possibly recoverable).
         except Exception as e:
             if q is not None:
@@ -478,6 +623,9 @@ class AsyncLLM(EngineClient):
                 )
                 logger.info("Request %s failed due to %s.", request_id, s)
             raise EngineGenerateError() from e
+        finally:
+            if q is not None:
+                q.close()

     def _run_output_handler(self):
         """Background loop: pulls from EngineCore and pushes to AsyncStreams."""
@@ -703,6 +851,9 @@ class AsyncLLM(EngineClient):
             if self.log_requests:
                 logger.info("Request %s failed.", request_id)
             raise EngineGenerateError() from e
+        finally:
+            if q is not None:
+                q.close()

     @property
     def tokenizer(self) -> TokenizerLike | None:
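The unwrap-on-raise pattern used above, as a self-contained sketch (the class is re-declared locally so the snippet runs standalone): the caller of generate() sees the user's original exception rather than an EngineGenerateError.

class InputStreamError(Exception):
    def __init__(self, cause: Exception):
        self.cause = cause
        super().__init__(str(cause))


def reraise(wrapped: InputStreamError):
    # Mirrors `raise e.cause from e` in generate().
    raise wrapped.cause from wrapped


try:
    reraise(InputStreamError(ValueError("bad chunk")))
except ValueError as e:
    print(f"caught original error: {e}")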
@@ -459,6 +459,7 @@ class InputProcessor:
         trace_headers: Mapping[str, str] | None = None,
         priority: int = 0,
         data_parallel_rank: int | None = None,
+        resumable: bool = False,
     ) -> EngineCoreRequest:
         self._validate_lora(lora_request)
         self._validate_params(params)
@@ -603,6 +604,7 @@ class InputProcessor:
             priority=priority,
             data_parallel_rank=data_parallel_rank,
             trace_headers=trace_headers,
+            resumable=resumable,
         )

     def _validate_model_inputs(
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import asyncio
-from collections import defaultdict
+from collections import defaultdict, deque
 from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Any, cast
@@ -12,6 +12,7 @@ import torch

 from vllm.lora.request import LoRARequest
 from vllm.outputs import (
+    STREAM_FINISHED,
     CompletionOutput,
     PoolingOutput,
     PoolingRequestOutput,
@@ -51,6 +52,8 @@ class RequestOutputCollector:
         self.output: RequestOutput | PoolingRequestOutput | Exception | None = None
         self.ready = asyncio.Event()

+        self._input_stream_task: asyncio.Task | None = None
+
     def put(self, output: RequestOutput | PoolingRequestOutput | Exception) -> None:
         """Non-blocking put operation."""
         if self.output is None or isinstance(output, Exception):
@@ -87,6 +90,16 @@ class RequestOutputCollector:
             raise output
         return output

+    def close(self):
+        if self._input_stream_task is not None:
+            self._input_stream_task.cancel()
+            self._input_stream_task = None
+
+    def __del__(self):
+        if (task := self._input_stream_task) is not None:
+            task.get_loop().call_soon_threadsafe(task.cancel)
+            self._input_stream_task = None
+

 @dataclass
 class OutputProcessorOutput:
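Why __del__ above cancels through call_soon_threadsafe rather than calling task.cancel() directly (a minimal runnable sketch): __del__ can run on a non-event-loop thread (for example during GC), and scheduling the cancellation on the task's own loop is the thread-safe way to do it.

import asyncio


async def main():
    task = asyncio.create_task(asyncio.sleep(10))
    # Safe from any thread: schedule the cancel on the task's own loop.
    task.get_loop().call_soon_threadsafe(task.cancel)
    try:
        await task
    except asyncio.CancelledError:
        print("cancelled")


asyncio.run(main())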
@@ -94,6 +107,20 @@
     reqs_to_abort: list[str]


+@dataclass
+class StreamingUpdate:
+    """Streaming input update data for output processor.
+
+    Contains the incremental prompt data to be applied to a request state
+    when the current sub-request completes.
+    """
+
+    prompt: str | None
+    prompt_token_ids: list[int] | None
+    arrival_time: float
+    final: bool = False
+
+
 class RequestState:
     def __init__(
         self,
@@ -116,6 +143,7 @@ class RequestState:
         top_p: float | None = None,
         n: int | None = None,
         temperature: float | None = None,
+        stream_input: bool = False,
     ):
         self.request_id = request_id
         self.external_req_id = external_req_id
@@ -146,6 +174,31 @@ class RequestState:
         self.stream_interval = stream_interval
         self.sent_tokens_offset = 0  # Offset of sent tokens

+        # Streaming input queue
+        self.streaming_input = stream_input
+        self.input_chunk_queue: deque[StreamingUpdate] | None = (
+            deque() if stream_input else None
+        )
+
+    def apply_streaming_update(self, update: StreamingUpdate) -> None:
+        # Apply the update to the request state.
+        self.streaming_input = not update.final
+        # TODO also include relevant output tokens in new prompt here
+        # (match scheduler behavior).
+        if update.prompt:
+            self.prompt = (
+                (self.prompt + update.prompt) if self.prompt else update.prompt
+            )
+        if self.prompt_token_ids:
+            self.prompt_token_ids.extend(update.prompt_token_ids or ())
+        else:
+            self.prompt_token_ids = update.prompt_token_ids or []
+        assert self.prompt_token_ids is not None
+        self.prompt_len = len(self.prompt_token_ids)
+        if self.stats is not None:
+            self.stats.arrival_time = update.arrival_time
+        self.is_prefilling = True
+
     @classmethod
     def from_new_request(
         cls,
@@ -205,6 +258,7 @@ class RequestState:
             queue=queue,
             log_stats=log_stats,
             stream_interval=stream_interval,
+            stream_input=request.resumable,
         )

     def make_request_output(
@@ -405,7 +459,6 @@
         a parent request, in which case the associated child requests are aborted
         also.
         """
-
         internal_req_ids = []
         for request_id in request_ids:
             if internal:
@@ -464,8 +517,10 @@ class OutputProcessor:
         queue: RequestOutputCollector | None = None,
     ) -> None:
         request_id = request.request_id
-        if request_id in self.request_states:
-            raise ValueError(f"Request id {request_id} already running.")
+        req_state = self.request_states.get(request_id)
+        if req_state is not None:
+            self._update_streaming_request_state(req_state, request, prompt)
+            return

         req_state = RequestState.from_new_request(
             tokenizer=self.tokenizer,
@@ -486,6 +541,39 @@ class OutputProcessor:
         # Track the external_req_id -> [internal_req_id, ...] mapping
         self.external_req_ids[req_state.external_req_id].append(request_id)

+    def _update_streaming_request_state(
+        self, req_state: RequestState, request: EngineCoreRequest, prompt: str | None
+    ) -> None:
+        """Queue a streaming update instead of immediately applying it."""
+        if not request.resumable:
+            # Final request - just mark completion, don't add its dummy tokens.
+            if req_state.input_chunk_queue is None:
+                # Engine already finished - emit final output and clean up.
+                self._finish_request(req_state)
+                if req_state.queue is not None:
+                    # Emit a final output with finished=True
+                    # to unblock the generate() loop.
+                    req_state.queue.put(STREAM_FINISHED)
+            elif req_state.input_chunk_queue:
+                req_state.input_chunk_queue[-1].final = True
+            else:
+                req_state.streaming_input = False
+            return
+
+        update = StreamingUpdate(
+            prompt=prompt,
+            prompt_token_ids=request.prompt_token_ids,
+            arrival_time=request.arrival_time,
+        )
+
+        # Apply request updates now if the last input already completed.
+        if req_state.input_chunk_queue is None:
+            req_state.apply_streaming_update(update)
+            req_state.input_chunk_queue = deque()
+        else:
+            # Queue the streaming update otherwise.
+            req_state.input_chunk_queue.append(update)
+
     def process_outputs(
         self,
         engine_core_outputs: list[EngineCoreOutput],
@@ -561,6 +649,9 @@ class OutputProcessor:
                     kv_transfer_params,
                     routed_experts,
                 ):
+                    if req_state.streaming_input:
+                        request_output.finished = False
+
                     if req_state.queue is not None:
                         # AsyncLLM: put into queue for handling by generate().
                         req_state.queue.put(request_output)
@@ -570,36 +661,48 @@ class OutputProcessor:

             # Free completed requests.
             if finish_reason is not None:
-                self.request_states.pop(req_id)
-
-                internal_ids = self.external_req_ids[req_state.external_req_id]
-                internal_ids.remove(req_id)
-                if not internal_ids:
-                    del self.external_req_ids[req_state.external_req_id]
-
-                # Remove parent request if applicable.
-                parent_req = req_state.parent_req
-                if parent_req and not parent_req.child_requests:
-                    self.parent_requests.pop(parent_req.request_id, None)
-                if not self.request_states:
-                    self._requests_drained.set()
-                if not engine_core_output.finished:
-                    # If req not finished in EngineCore, but Detokenizer
-                    # detected stop string, abort needed in EngineCore.
-                    reqs_to_abort.append(req_id)
-
-                # Track per-request stats
-                self._update_stats_from_finished(
-                    req_state, finish_reason, iteration_stats
-                )
-                if self.tracer:
-                    self.do_tracing(engine_core_output, req_state, iteration_stats)
+                if req_state.streaming_input:
+                    if req_state.input_chunk_queue:
+                        update = req_state.input_chunk_queue.popleft()
+                        req_state.apply_streaming_update(update)
+                    else:
+                        req_state.input_chunk_queue = None
+                else:
+                    self._finish_request(req_state)
+                    if not engine_core_output.finished:
+                        # If req not finished in EngineCore, but Detokenizer
+                        # detected stop string, abort needed in EngineCore.
+                        reqs_to_abort.append(req_id)
+
+                # Track per-request stats
+                self._update_stats_from_finished(
+                    req_state, finish_reason, iteration_stats
+                )
+                if self.tracer:
+                    self.do_tracing(engine_core_output, req_state, iteration_stats)

         return OutputProcessorOutput(
             request_outputs=request_outputs,
             reqs_to_abort=reqs_to_abort,
         )

+    def _finish_request(self, req_state: RequestState) -> None:
+        req_id = req_state.request_id
+        self.request_states.pop(req_id)
+
+        internal_ids = self.external_req_ids[req_state.external_req_id]
+        internal_ids.remove(req_id)
+        if not internal_ids:
+            del self.external_req_ids[req_state.external_req_id]
+
+        # Remove parent request if applicable.
+        parent_req = req_state.parent_req
+        if parent_req and not parent_req.child_requests:
+            self.parent_requests.pop(parent_req.request_id, None)
+
+        if not self.request_states:
+            self._requests_drained.set()
+
     def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
         self.lora_states.update_scheduler_stats(scheduler_stats)
@@ -4,12 +4,12 @@
 import contextlib
 import os
 import weakref
-from collections.abc import Callable, Iterator
+from collections.abc import Callable, Iterator, Mapping
 from dataclasses import dataclass
 from enum import Enum, auto
 from multiprocessing import Process, connection
 from multiprocessing.process import BaseProcess
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, cast
 from unittest.mock import patch

 import msgspec
@@ -224,6 +224,14 @@ def get_device_indices(
     return value


+def get_prompt_text(prompt: Any) -> str | None:
+    if isinstance(prompt, str):
+        return prompt
+    if isinstance(prompt, Mapping):
+        return cast(str | None, prompt.get("prompt"))
+    return None
+
+
 class CoreEngineActorManager:
     """
     Utility class to handle creation, readiness, and shutdown
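Usage sketch for get_prompt_text (the mapping shapes mirror vLLM's TextPrompt/TokensPrompt dicts):

from vllm.v1.engine.utils import get_prompt_text

assert get_prompt_text("hello") == "hello"
assert get_prompt_text({"prompt": "hello"}) == "hello"
assert get_prompt_text({"prompt_token_ids": [1, 2]}) is None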
@@ -3,7 +3,9 @@

 import enum
 import time
+from collections import deque
 from collections.abc import Callable, Mapping
+from dataclasses import dataclass
 from functools import partial
 from typing import TYPE_CHECKING, Any, Optional
@@ -27,6 +29,33 @@ if TYPE_CHECKING:
     from vllm.v1.core.kv_cache_utils import BlockHash


+@dataclass
+class StreamingUpdate:
+    """Lightweight data for streaming session continuation.
+
+    Contains only the fields needed to update an existing streaming session
+    with new input data.
+    """
+
+    mm_features: list[MultiModalFeatureSpec] | None
+    prompt_token_ids: list[int] | None
+    max_tokens: int
+    arrival_time: float
+    sampling_params: SamplingParams | None
+
+    @classmethod
+    def from_request(cls, request: "Request") -> "StreamingUpdate | None":
+        if not request.resumable:
+            return None
+        return cls(
+            mm_features=request.mm_features,
+            prompt_token_ids=request.prompt_token_ids,
+            max_tokens=request.max_tokens,
+            arrival_time=request.arrival_time,
+            sampling_params=request.sampling_params,
+        )
+
+
 class Request:
     def __init__(
         self,
@@ -44,6 +73,7 @@ class Request:
         priority: int = 0,
         trace_headers: Mapping[str, str] | None = None,
         block_hasher: Callable[["Request"], list["BlockHash"]] | None = None,
+        resumable: bool = False,
     ) -> None:
         self.request_id = request_id
         self.client_index = client_index
@@ -105,8 +135,6 @@ class Request:

         # Multi-modal related
         self.mm_features = mm_features or []
-        self.num_encoder_inputs = len(self.mm_features)
-        self.has_encoder_inputs = self.num_encoder_inputs > 0

         # Read-only views
         # Prevent directly appending to these lists since
@@ -137,6 +165,11 @@ class Request:

         self.skip_reading_prefix_cache = self.get_skip_reading_prefix_cache()

+        # Used for streaming
+        self.resumable = resumable
+        # None entry in the queue means finished.
+        self.streaming_queue: deque[StreamingUpdate | None] | None = None
+
     @classmethod
     def from_engine_core_request(
         cls,
@@ -158,6 +191,7 @@ class Request:
             priority=request.priority,
             trace_headers=request.trace_headers,
             block_hasher=block_hasher,
+            resumable=request.resumable,
         )

     def append_output_token_ids(
@@ -190,6 +224,14 @@ class Request:
     def num_output_tokens(self) -> int:
         return len(self._output_token_ids)

+    @property
+    def num_encoder_inputs(self) -> int:
+        return len(self.mm_features)
+
+    @property
+    def has_encoder_inputs(self) -> bool:
+        return self.num_encoder_inputs > 0
+
     def get_skip_reading_prefix_cache(self) -> bool:
         if (
             self.sampling_params is not None
@@ -246,6 +288,7 @@ class RequestStatus(enum.IntEnum):
     WAITING = enum.auto()
     WAITING_FOR_FSM = enum.auto()
     WAITING_FOR_REMOTE_KVS = enum.auto()
+    WAITING_FOR_STREAMING_REQ = enum.auto()
     RUNNING = enum.auto()
     PREEMPTED = enum.auto()
     # Note: anything after PREEMPTED will be considered
@@ -256,7 +299,7 @@ class RequestStatus(enum.IntEnum):
     FINISHED_IGNORED = enum.auto()
     FINISHED_ERROR = enum.auto()

-    def __str__(self):
+    def __str__(self) -> str:
         return self.name

     @staticmethod
@@ -278,4 +321,5 @@ _FINISHED_REASON_MAP = {
     RequestStatus.FINISHED_ABORTED: FinishReason.ABORT,
     RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH,
     RequestStatus.FINISHED_ERROR: FinishReason.ERROR,
+    RequestStatus.WAITING_FOR_STREAMING_REQ: FinishReason.STOP,
 }
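Why the enum placement and the finish-reason mapping above matter, restated as a toy sketch (local enum; the is_finished convention is an assumption based on the "anything after PREEMPTED" note in the diff): WAITING_FOR_STREAMING_REQ sits before RUNNING so a paused session still counts as unfinished, while mapping it to FinishReason.STOP lets a completed chunk report a normal stop.

import enum


class Status(enum.IntEnum):
    WAITING = enum.auto()
    WAITING_FOR_STREAMING_REQ = enum.auto()  # new; placed before RUNNING
    RUNNING = enum.auto()
    PREEMPTED = enum.auto()
    FINISHED_STOPPED = enum.auto()


def is_finished(status: Status) -> bool:
    # The "anything after PREEMPTED is finished" convention (assumed).
    return status > Status.PREEMPTED


assert not is_finished(Status.WAITING_FOR_STREAMING_REQ)
assert is_finished(Status.FINISHED_STOPPED)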
@@ -112,6 +112,7 @@ from vllm.v1.attention.backends.utils import (
     get_dcp_local_seq_lens,
     reorder_batch_to_split_decodes_and_prefills,
 )
+from vllm.v1.core.sched.output import NewRequestData
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
 from vllm.v1.kv_cache_interface import (
     AttentionSpec,
@@ -903,6 +904,12 @@ class GPUModelRunner(
         # Add new requests to the cached states.
         for new_req_data in scheduler_output.scheduled_new_reqs:
             req_id = new_req_data.req_id
+            if req_id in self.requests:
+                # For streaming case only.
+                req_state = self._update_streaming_request(req_id, new_req_data)
+                reqs_to_add.append(req_state)
+                continue
+
             sampling_params = new_req_data.sampling_params
             pooling_params = new_req_data.pooling_params
@@ -1133,6 +1140,40 @@ class GPUModelRunner(
                 self.model.get_mamba_state_copy_func(),
             )

+    def _update_streaming_request(
+        self, req_id: str, new_req_data: NewRequestData
+    ) -> CachedRequestState:
+        """Updates a streaming session request from `scheduled_new_reqs`.
+
+        Removes the request from InputBatch (if present), updates the cached
+        state, and prepares it for re-addition to the batch.
+
+        NOTE: prompt_token_ids includes intermediate output tokens, i.e. tokens
+        that were previously generated but are now input context (part of the
+        prompt).
+        """
+        self.input_batch.remove_request(req_id)
+        req_state = self.requests[req_id]
+
+        req_state.prompt_token_ids = new_req_data.prompt_token_ids
+        req_state.mm_features = new_req_data.mm_features
+        req_state.prompt_embeds = new_req_data.prompt_embeds
+        req_state.sampling_params = new_req_data.sampling_params
+        req_state.pooling_params = new_req_data.pooling_params
+        req_state.block_ids = new_req_data.block_ids
+        req_state.num_computed_tokens = new_req_data.num_computed_tokens
+        req_state.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
+            req_state.prompt_token_ids, req_state.prompt_embeds
+        )
+
+        # Clear `output_token_ids` as previous output tokens are now part of
+        # `prompt_token_ids`.
+        req_state.output_token_ids.clear()
+
+        if self.uses_mrope:
+            self._init_mrope_positions(req_state)
+
+        return req_state
+
     def _init_mrope_positions(self, req_state: CachedRequestState):
         model = self.get_model()
         assert supports_mrope(model), "M-RoPE support is not implemented."