[V1] [P/D] Refactor KV Connector Path (#21980)

Signed-off-by: David Ben-David <davidb@pliops.com>
Co-authored-by: David Ben-David <davidb@pliops.com>
This commit is contained in:
David Ben-David
2025-08-03 14:03:40 +03:00
committed by GitHub
parent 24d1dffbeb
commit aefeea0fde
12 changed files with 142 additions and 80 deletions

View File

@@ -10,7 +10,7 @@ from collections.abc import Mapping
from collections.abc import Sequence as GenericSequence
from dataclasses import dataclass, field
from functools import reduce
from typing import Any, Callable, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import msgspec
import torch
@@ -21,6 +21,10 @@ from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams
if TYPE_CHECKING:
from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorOutput)
VLLM_TOKEN_ID_ARRAY_TYPE = "l"
VLLM_INVALID_TOKEN_ID = -1
@@ -1159,14 +1163,11 @@ class IntermediateTensors:
states and residuals to be sent to the next stage. This data structure
contains the hidden states and residuals for a request.
Each stage also needs to handle its own finished_sending and
finished_recving in case of kv transfer.
Each stage also needs to handle its own kv_connector_output.
"""
tensors: dict[str, torch.Tensor]
# [req_ids]
finished_sending: Optional[set[str]] = None
finished_recving: Optional[set[str]] = None
kv_connector_output: Optional["KVConnectorOutput"]
def __init__(self, tensors):
# manually define this function, so that