[Misc] Simplify PoolerOutput and move to v1/outputs (#25629)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -11,7 +11,6 @@ if TYPE_CHECKING:
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import (
|
||||
KVConnectorOutput)
|
||||
else:
|
||||
LoRARequest = Any
|
||||
KVConnectorOutput = Any
|
||||
|
||||
# NOTE(review): judging by the name, this is the element type code used when
# building token-id arrays; "l" is presumably the `array.array` typecode for
# a signed long -- confirm against the call sites.
VLLM_TOKEN_ID_ARRAY_TYPE = "l"
|
||||
@@ -48,29 +47,6 @@ class RequestMetrics:
|
||||
model_execute_time: Optional[float] = None
|
||||
|
||||
|
||||
class PoolingSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
):
    """The model output associated with a pooling sequence group.

    Wraps a single ``data`` payload and exposes its byte size, a readable
    ``repr`` and value-based equality.
    """
    # Annotated as Any to be compatible with msgspec
    # The actual type is in SequenceGroup.pooled_data
    data: Any

    def get_data_nbytes(self) -> int:
        """Return the ``nbytes`` attribute of the wrapped data.

        The payload is treated as a ``torch.Tensor`` here; any object that
        exposes ``nbytes`` works at runtime.
        """
        data: torch.Tensor = self.data
        return data.nbytes

    def __repr__(self) -> str:
        """Return ``PoolingSequenceGroupOutput(data=...)``."""
        # The closing parenthesis was previously missing, which produced an
        # unbalanced representation.
        return f"PoolingSequenceGroupOutput(data={self.data})"

    def __eq__(self, other: object) -> bool:
        """Compare the wrapped data of two outputs.

        Returns ``NotImplemented`` for operands of another type so Python can
        try the reflected comparison and finally fall back to ``False``,
        rather than raising in the middle of an ``==`` expression.
        """
        if not isinstance(other, PoolingSequenceGroupOutput):
            return NotImplemented
        # NOTE(review): the result is whatever ``data.__eq__`` yields; if the
        # payload is a torch.Tensor (as get_data_nbytes assumes) this is
        # presumably an element-wise tensor rather than a plain bool --
        # confirm that callers expect this.
        return self.data == other.data
|
||||
|
||||
|
||||
# cannot use msgspec.Struct here because Dynamo does not support it
|
||||
@dataclass
|
||||
class IntermediateTensors:
|
||||
@@ -119,30 +95,6 @@ class IntermediateTensors:
|
||||
return f"IntermediateTensors(tensors={self.tensors})"
|
||||
|
||||
|
||||
class PoolerOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
):
    """Container for the results of a pooling operation in the pooling model.

    Holds one ``PoolingSequenceGroupOutput`` per entry in ``outputs`` and
    forwards indexing and ``len()`` to that list.
    """
    outputs: list[PoolingSequenceGroupOutput]

    def get_data_nbytes(self) -> int:
        """Return the summed ``get_data_nbytes()`` of every held output."""
        total_nbytes = 0
        for entry in self.outputs:
            total_nbytes += entry.get_data_nbytes()
        return total_nbytes

    def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
        """Return the output stored at position ``idx``."""
        entries = self.outputs
        return entries[idx]

    def __setitem__(self, idx: int, value: PoolingSequenceGroupOutput):
        """Replace the output stored at position ``idx`` with ``value``."""
        entries = self.outputs
        entries[idx] = value

    def __len__(self):
        """Return how many outputs are held."""
        entries = self.outputs
        return len(entries)

    def __eq__(self, other: object):
        """Return whether ``other`` is of this class with equal ``outputs``."""
        # Objects of an unrelated class never compare equal.
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs
|
||||
|
||||
|
||||
class ExecuteModelRequest(
|
||||
msgspec.Struct,
|
||||
array_like=True, # type: ignore[call-arg]
|
||||
|
||||
Reference in New Issue
Block a user