[Responses API] Add kv_transfer_params for PD disaggregation (#37424)

Signed-off-by: bongwoobak <bongwoobak@gmail.com>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
This commit is contained in:
Bongwoo Bak
2026-03-21 14:48:54 +09:00
committed by GitHub
parent 0d50fa1db6
commit 17ee641c45
3 changed files with 29 additions and 2 deletions

View File

@@ -9,7 +9,7 @@ from abc import ABC, abstractmethod
from collections.abc import Callable
from contextlib import AsyncExitStack
from dataclasses import replace
from typing import TYPE_CHECKING, Final, Union
from typing import TYPE_CHECKING, Any, Final, Union
from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem,
@@ -182,6 +182,7 @@ class SimpleContext(ConversationContext):
self.all_turn_metrics = []
self.input_messages: list[ResponseRawMessageAndToken] = []
self.kv_transfer_params: dict[str, Any] | None = None
def append_output(self, output) -> None:
self.last_output = output
@@ -190,6 +191,8 @@ class SimpleContext(ConversationContext):
self.num_prompt_tokens = len(output.prompt_token_ids or [])
self.num_cached_tokens = output.num_cached_tokens or 0
self.num_output_tokens += len(output.outputs[0].token_ids or [])
if output.kv_transfer_params is not None:
self.kv_transfer_params = output.kv_transfer_params
# Accumulate text, token_ids, and logprobs for streaming mode
delta_output = output.outputs[0]
@@ -308,11 +311,14 @@ class ParsableContext(ConversationContext):
self.input_messages: list[ResponseRawMessageAndToken] = []
self.output_messages: list[ResponseRawMessageAndToken] = []
self._accumulated_token_ids: list[int] = []
self.kv_transfer_params: dict[str, Any] | None = None
def append_output(self, output: RequestOutput) -> None:
self.num_prompt_tokens = len(output.prompt_token_ids or [])
self.num_cached_tokens = output.num_cached_tokens or 0
self.num_output_tokens += len(output.outputs[0].token_ids or [])
if output.kv_transfer_params is not None:
self.kv_transfer_params = output.kv_transfer_params
self.parser.process(output.outputs[0])
output_token_ids = output.outputs[0].token_ids or []
self._accumulated_token_ids.extend(output_token_ids)
@@ -538,6 +544,7 @@ class HarmonyContext(ConversationContext):
self.all_turn_metrics: list[TurnMetrics] = []
self.is_first_turn = True
self.first_tok_of_message = True # For streaming support
self.kv_transfer_params: dict[str, Any] | None = None
def _update_num_reasoning_tokens(self):
channel = self.parser.current_channel
@@ -557,6 +564,8 @@ class HarmonyContext(ConversationContext):
self._update_num_reasoning_tokens()
self._update_prefill_token_usage(output)
self._update_decode_token_usage(output)
if output.kv_transfer_params is not None:
self.kv_transfer_params = output.kv_transfer_params
# Append current turn to all turn list for next turn's calculations
self.all_turn_metrics.append(self.current_turn_metrics.copy())
self.current_turn_metrics.reset()
@@ -868,6 +877,8 @@ class StreamingHarmonyContext(HarmonyContext):
if last_delta_text:
self.last_content_delta = last_delta_text
self._update_decode_token_usage(output)
if output.kv_transfer_params is not None:
self.kv_transfer_params = output.kv_transfer_params
# For streaming, update previous turn when message is complete
if output.finished:

View File

@@ -252,6 +252,10 @@ class ResponsesRequest(OpenAIBaseModel):
"numeric values, used by custom extensions."
),
)
kv_transfer_params: dict[str, Any] | None = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.",
)
# --8<-- [end:responses-extra-params]
def build_chat_params(
@@ -351,6 +355,10 @@ class ResponsesRequest(OpenAIBaseModel):
if isinstance(stop, str):
stop = [stop]
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
if self.kv_transfer_params:
extra_args["kv_transfer_params"] = self.kv_transfer_params
return SamplingParams.from_optional(
temperature=temperature,
top_p=top_p,
@@ -367,7 +375,7 @@ class ResponsesRequest(OpenAIBaseModel):
),
structured_outputs=structured_outputs,
logit_bias=self.logit_bias,
extra_args=self.vllm_xargs or {},
extra_args=extra_args,
skip_clone=True, # Created fresh per request, safe to skip clone
skip_special_tokens=self.skip_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
@@ -488,6 +496,11 @@ class ResponsesResponse(OpenAIBaseModel):
usage: ResponseUsage | None = None
user: str | None = None
# vLLM-specific fields that are not in OpenAI spec
kv_transfer_params: dict[str, Any] | None = Field(
default=None, description="KVTransfer parameters."
)
# --8<-- [start:responses-response-extra-params]
# These are populated when enable_response_messages is set to True
# NOTE: custom serialization is needed
@@ -531,6 +544,7 @@ class ResponsesResponse(OpenAIBaseModel):
usage: ResponseUsage | None = None,
input_messages: ResponseInputOutputMessage | None = None,
output_messages: ResponseInputOutputMessage | None = None,
kv_transfer_params: dict[str, Any] | None = None,
) -> "ResponsesResponse":
incomplete_details: IncompleteDetails | None = None
if status == "incomplete":
@@ -566,6 +580,7 @@ class ResponsesResponse(OpenAIBaseModel):
truncation=request.truncation,
user=request.user,
usage=usage,
kv_transfer_params=kv_transfer_params,
)

View File

@@ -873,6 +873,7 @@ class OpenAIServingResponses(OpenAIServing):
output=output,
status=status,
usage=usage,
kv_transfer_params=context.kv_transfer_params,
)
if request.store: