From 17ee641c4510c93d8d2b826b19daa9f86126894e Mon Sep 17 00:00:00 2001 From: Bongwoo Bak Date: Sat, 21 Mar 2026 14:48:54 +0900 Subject: [PATCH] [Responses API] Add kv_transfer_params for PD disaggregation (#37424) Signed-off-by: bongwoobak Co-authored-by: Chauncey --- vllm/entrypoints/openai/responses/context.py | 13 ++++++++++++- vllm/entrypoints/openai/responses/protocol.py | 17 ++++++++++++++++- vllm/entrypoints/openai/responses/serving.py | 1 + 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/responses/context.py b/vllm/entrypoints/openai/responses/context.py index bab59e0aa..a4c55c23c 100644 --- a/vllm/entrypoints/openai/responses/context.py +++ b/vllm/entrypoints/openai/responses/context.py @@ -9,7 +9,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable from contextlib import AsyncExitStack from dataclasses import replace -from typing import TYPE_CHECKING, Final, Union +from typing import TYPE_CHECKING, Any, Final, Union from openai.types.responses.response_function_tool_call_output_item import ( ResponseFunctionToolCallOutputItem, @@ -182,6 +182,7 @@ class SimpleContext(ConversationContext): self.all_turn_metrics = [] self.input_messages: list[ResponseRawMessageAndToken] = [] + self.kv_transfer_params: dict[str, Any] | None = None def append_output(self, output) -> None: self.last_output = output @@ -190,6 +191,8 @@ class SimpleContext(ConversationContext): self.num_prompt_tokens = len(output.prompt_token_ids or []) self.num_cached_tokens = output.num_cached_tokens or 0 self.num_output_tokens += len(output.outputs[0].token_ids or []) + if output.kv_transfer_params is not None: + self.kv_transfer_params = output.kv_transfer_params # Accumulate text, token_ids, and logprobs for streaming mode delta_output = output.outputs[0] @@ -308,11 +311,14 @@ class ParsableContext(ConversationContext): self.input_messages: list[ResponseRawMessageAndToken] = [] self.output_messages: list[ResponseRawMessageAndToken] = [] self._accumulated_token_ids: list[int] = [] + self.kv_transfer_params: dict[str, Any] | None = None def append_output(self, output: RequestOutput) -> None: self.num_prompt_tokens = len(output.prompt_token_ids or []) self.num_cached_tokens = output.num_cached_tokens or 0 self.num_output_tokens += len(output.outputs[0].token_ids or []) + if output.kv_transfer_params is not None: + self.kv_transfer_params = output.kv_transfer_params self.parser.process(output.outputs[0]) output_token_ids = output.outputs[0].token_ids or [] self._accumulated_token_ids.extend(output_token_ids) @@ -538,6 +544,7 @@ class HarmonyContext(ConversationContext): self.all_turn_metrics: list[TurnMetrics] = [] self.is_first_turn = True self.first_tok_of_message = True # For streaming support + self.kv_transfer_params: dict[str, Any] | None = None def _update_num_reasoning_tokens(self): channel = self.parser.current_channel @@ -557,6 +564,8 @@ class HarmonyContext(ConversationContext): self._update_num_reasoning_tokens() self._update_prefill_token_usage(output) self._update_decode_token_usage(output) + if output.kv_transfer_params is not None: + self.kv_transfer_params = output.kv_transfer_params # Append current turn to all turn list for next turn's calculations self.all_turn_metrics.append(self.current_turn_metrics.copy()) self.current_turn_metrics.reset() @@ -868,6 +877,8 @@ class StreamingHarmonyContext(HarmonyContext): if last_delta_text: self.last_content_delta = last_delta_text self._update_decode_token_usage(output) + if output.kv_transfer_params is not None: + self.kv_transfer_params = output.kv_transfer_params # For streaming, update previous turn when message is complete if output.finished: diff --git a/vllm/entrypoints/openai/responses/protocol.py b/vllm/entrypoints/openai/responses/protocol.py index a5f62bdd8..43fbba1dd 100644 --- a/vllm/entrypoints/openai/responses/protocol.py +++ b/vllm/entrypoints/openai/responses/protocol.py @@ -252,6 +252,10 @@ class ResponsesRequest(OpenAIBaseModel): "numeric values, used by custom extensions." ), ) + kv_transfer_params: dict[str, Any] | None = Field( + default=None, + description="KVTransfer parameters used for disaggregated serving.", + ) # --8<-- [end:responses-extra-params] def build_chat_params( @@ -351,6 +355,10 @@ class ResponsesRequest(OpenAIBaseModel): if isinstance(stop, str): stop = [stop] + extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {} + if self.kv_transfer_params: + extra_args["kv_transfer_params"] = self.kv_transfer_params + return SamplingParams.from_optional( temperature=temperature, top_p=top_p, @@ -367,7 +375,7 @@ class ResponsesRequest(OpenAIBaseModel): ), structured_outputs=structured_outputs, logit_bias=self.logit_bias, - extra_args=self.vllm_xargs or {}, + extra_args=extra_args, skip_clone=True, # Created fresh per request, safe to skip clone skip_special_tokens=self.skip_special_tokens, include_stop_str_in_output=self.include_stop_str_in_output, @@ -488,6 +496,11 @@ class ResponsesResponse(OpenAIBaseModel): usage: ResponseUsage | None = None user: str | None = None + # vLLM-specific fields that are not in OpenAI spec + kv_transfer_params: dict[str, Any] | None = Field( + default=None, description="KVTransfer parameters." + ) + # --8<-- [start:responses-response-extra-params] # These are populated when enable_response_messages is set to True # NOTE: custom serialization is needed @@ -531,6 +544,7 @@ class ResponsesResponse(OpenAIBaseModel): usage: ResponseUsage | None = None, input_messages: ResponseInputOutputMessage | None = None, output_messages: ResponseInputOutputMessage | None = None, + kv_transfer_params: dict[str, Any] | None = None, ) -> "ResponsesResponse": incomplete_details: IncompleteDetails | None = None if status == "incomplete": @@ -566,6 +580,7 @@ class ResponsesResponse(OpenAIBaseModel): truncation=request.truncation, user=request.user, usage=usage, + kv_transfer_params=kv_transfer_params, ) diff --git a/vllm/entrypoints/openai/responses/serving.py b/vllm/entrypoints/openai/responses/serving.py index 574282c4c..53c28693a 100644 --- a/vllm/entrypoints/openai/responses/serving.py +++ b/vllm/entrypoints/openai/responses/serving.py @@ -873,6 +873,7 @@ class OpenAIServingResponses(OpenAIServing): output=output, status=status, usage=usage, + kv_transfer_params=context.kv_transfer_params, ) if request.store: