[KVConnector] Remove redundant method KVConnectorOutput::merge() (#38546)
Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
This commit is contained in:
@@ -2,9 +2,8 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, NamedTuple, TypeAlias, TypeVar
|
||||
from typing import TYPE_CHECKING, NamedTuple, TypeAlias
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -125,20 +124,6 @@ class SamplerOutput:
|
||||
logprobs_tensors: LogprobsTensors | None
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _combine_non_none(f: Callable[[T, T], T], items: list[T | None]) -> T | None:
|
||||
non_none = [item for item in items if item is not None]
|
||||
if len(non_none) == 0:
|
||||
return None
|
||||
|
||||
combined = non_none[0]
|
||||
for item in non_none[1:]:
|
||||
combined = f(combined, item)
|
||||
return combined
|
||||
|
||||
|
||||
@dataclass
|
||||
class KVConnectorOutput:
|
||||
# [req_ids]
|
||||
@@ -167,43 +152,6 @@ class KVConnectorOutput:
|
||||
and not self.kv_connector_worker_meta
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def merge(cls, *outputs: "KVConnectorOutput"):
|
||||
assert len(outputs) > 0, "Cannot merge empty outputs"
|
||||
finished_sending = _combine_non_none(
|
||||
set.union, [output.finished_sending for output in outputs]
|
||||
)
|
||||
finished_recving = _combine_non_none(
|
||||
set.union, [output.finished_recving for output in outputs]
|
||||
)
|
||||
kv_connector_stats = _combine_non_none(
|
||||
lambda x, y: x.aggregate(y),
|
||||
[output.kv_connector_stats for output in outputs],
|
||||
)
|
||||
kv_cache_events = _combine_non_none(
|
||||
lambda x, y: x.merge(y),
|
||||
[output.kv_cache_events for output in outputs],
|
||||
)
|
||||
invalid_block_ids = _combine_non_none(
|
||||
set.union, [output.invalid_block_ids for output in outputs]
|
||||
)
|
||||
assert invalid_block_ids is not None
|
||||
|
||||
assert all(
|
||||
output.expected_finished_count == outputs[0].expected_finished_count
|
||||
for output in outputs
|
||||
)
|
||||
expected_finished_count = outputs[0].expected_finished_count
|
||||
|
||||
return cls(
|
||||
finished_sending=finished_sending,
|
||||
finished_recving=finished_recving,
|
||||
kv_connector_stats=kv_connector_stats,
|
||||
kv_cache_events=kv_cache_events,
|
||||
invalid_block_ids=invalid_block_ids,
|
||||
expected_finished_count=expected_finished_count,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ECConnectorOutput:
|
||||
|
||||
Reference in New Issue
Block a user