[Responses API] Add kv_transfer_params for PD disaggregation (#37424)
Signed-off-by: bongwoobak <bongwoobak@gmail.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
This commit is contained in:
@@ -9,7 +9,7 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from contextlib import AsyncExitStack
|
||||
from dataclasses import replace
|
||||
from typing import TYPE_CHECKING, Final, Union
|
||||
from typing import TYPE_CHECKING, Any, Final, Union
|
||||
|
||||
from openai.types.responses.response_function_tool_call_output_item import (
|
||||
ResponseFunctionToolCallOutputItem,
|
||||
@@ -182,6 +182,7 @@ class SimpleContext(ConversationContext):
|
||||
self.all_turn_metrics = []
|
||||
|
||||
self.input_messages: list[ResponseRawMessageAndToken] = []
|
||||
self.kv_transfer_params: dict[str, Any] | None = None
|
||||
|
||||
def append_output(self, output) -> None:
|
||||
self.last_output = output
|
||||
@@ -190,6 +191,8 @@ class SimpleContext(ConversationContext):
|
||||
self.num_prompt_tokens = len(output.prompt_token_ids or [])
|
||||
self.num_cached_tokens = output.num_cached_tokens or 0
|
||||
self.num_output_tokens += len(output.outputs[0].token_ids or [])
|
||||
if output.kv_transfer_params is not None:
|
||||
self.kv_transfer_params = output.kv_transfer_params
|
||||
|
||||
# Accumulate text, token_ids, and logprobs for streaming mode
|
||||
delta_output = output.outputs[0]
|
||||
@@ -308,11 +311,14 @@ class ParsableContext(ConversationContext):
|
||||
self.input_messages: list[ResponseRawMessageAndToken] = []
|
||||
self.output_messages: list[ResponseRawMessageAndToken] = []
|
||||
self._accumulated_token_ids: list[int] = []
|
||||
self.kv_transfer_params: dict[str, Any] | None = None
|
||||
|
||||
def append_output(self, output: RequestOutput) -> None:
|
||||
self.num_prompt_tokens = len(output.prompt_token_ids or [])
|
||||
self.num_cached_tokens = output.num_cached_tokens or 0
|
||||
self.num_output_tokens += len(output.outputs[0].token_ids or [])
|
||||
if output.kv_transfer_params is not None:
|
||||
self.kv_transfer_params = output.kv_transfer_params
|
||||
self.parser.process(output.outputs[0])
|
||||
output_token_ids = output.outputs[0].token_ids or []
|
||||
self._accumulated_token_ids.extend(output_token_ids)
|
||||
@@ -538,6 +544,7 @@ class HarmonyContext(ConversationContext):
|
||||
self.all_turn_metrics: list[TurnMetrics] = []
|
||||
self.is_first_turn = True
|
||||
self.first_tok_of_message = True # For streaming support
|
||||
self.kv_transfer_params: dict[str, Any] | None = None
|
||||
|
||||
def _update_num_reasoning_tokens(self):
|
||||
channel = self.parser.current_channel
|
||||
@@ -557,6 +564,8 @@ class HarmonyContext(ConversationContext):
|
||||
self._update_num_reasoning_tokens()
|
||||
self._update_prefill_token_usage(output)
|
||||
self._update_decode_token_usage(output)
|
||||
if output.kv_transfer_params is not None:
|
||||
self.kv_transfer_params = output.kv_transfer_params
|
||||
# Append current turn to all turn list for next turn's calculations
|
||||
self.all_turn_metrics.append(self.current_turn_metrics.copy())
|
||||
self.current_turn_metrics.reset()
|
||||
@@ -868,6 +877,8 @@ class StreamingHarmonyContext(HarmonyContext):
|
||||
if last_delta_text:
|
||||
self.last_content_delta = last_delta_text
|
||||
self._update_decode_token_usage(output)
|
||||
if output.kv_transfer_params is not None:
|
||||
self.kv_transfer_params = output.kv_transfer_params
|
||||
|
||||
# For streaming, update previous turn when message is complete
|
||||
if output.finished:
|
||||
|
||||
@@ -252,6 +252,10 @@ class ResponsesRequest(OpenAIBaseModel):
|
||||
"numeric values, used by custom extensions."
|
||||
),
|
||||
)
|
||||
kv_transfer_params: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="KVTransfer parameters used for disaggregated serving.",
|
||||
)
|
||||
# --8<-- [end:responses-extra-params]
|
||||
|
||||
def build_chat_params(
|
||||
@@ -351,6 +355,10 @@ class ResponsesRequest(OpenAIBaseModel):
|
||||
if isinstance(stop, str):
|
||||
stop = [stop]
|
||||
|
||||
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
|
||||
if self.kv_transfer_params:
|
||||
extra_args["kv_transfer_params"] = self.kv_transfer_params
|
||||
|
||||
return SamplingParams.from_optional(
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
@@ -367,7 +375,7 @@ class ResponsesRequest(OpenAIBaseModel):
|
||||
),
|
||||
structured_outputs=structured_outputs,
|
||||
logit_bias=self.logit_bias,
|
||||
extra_args=self.vllm_xargs or {},
|
||||
extra_args=extra_args,
|
||||
skip_clone=True, # Created fresh per request, safe to skip clone
|
||||
skip_special_tokens=self.skip_special_tokens,
|
||||
include_stop_str_in_output=self.include_stop_str_in_output,
|
||||
@@ -488,6 +496,11 @@ class ResponsesResponse(OpenAIBaseModel):
|
||||
usage: ResponseUsage | None = None
|
||||
user: str | None = None
|
||||
|
||||
# vLLM-specific fields that are not in OpenAI spec
|
||||
kv_transfer_params: dict[str, Any] | None = Field(
|
||||
default=None, description="KVTransfer parameters."
|
||||
)
|
||||
|
||||
# --8<-- [start:responses-response-extra-params]
|
||||
# These are populated when enable_response_messages is set to True
|
||||
# NOTE: custom serialization is needed
|
||||
@@ -531,6 +544,7 @@ class ResponsesResponse(OpenAIBaseModel):
|
||||
usage: ResponseUsage | None = None,
|
||||
input_messages: ResponseInputOutputMessage | None = None,
|
||||
output_messages: ResponseInputOutputMessage | None = None,
|
||||
kv_transfer_params: dict[str, Any] | None = None,
|
||||
) -> "ResponsesResponse":
|
||||
incomplete_details: IncompleteDetails | None = None
|
||||
if status == "incomplete":
|
||||
@@ -566,6 +580,7 @@ class ResponsesResponse(OpenAIBaseModel):
|
||||
truncation=request.truncation,
|
||||
user=request.user,
|
||||
usage=usage,
|
||||
kv_transfer_params=kv_transfer_params,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -873,6 +873,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
output=output,
|
||||
status=status,
|
||||
usage=usage,
|
||||
kv_transfer_params=context.kv_transfer_params,
|
||||
)
|
||||
|
||||
if request.store:
|
||||
|
||||
Reference in New Issue
Block a user