[Bugfix] fix IntermediateTensors equal method (#23027)

Signed-off-by: Andy Xie <andy.xning@gmail.com>
This commit is contained in:
Ning Xie
2025-08-18 17:58:11 +08:00
committed by GitHub
parent 27e8d1ea3e
commit 5a30bd10d8
2 changed files with 45 additions and 3 deletions

View File

@@ -2,10 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (CompletionSequenceGroupOutput, SequenceData,
SequenceOutput)
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
SequenceData, SequenceOutput)
from .core.utils import create_dummy_prompt
@@ -98,3 +99,38 @@ def test_sequence_group_stage():
assert seq_group.is_prefill() is True
seq_group.update_num_computed_tokens(1)
assert seq_group.is_prefill() is False
def test_sequence_intermediate_tensors_equal():
class AnotherIntermediateTensors(IntermediateTensors):
pass
intermediate_tensors = IntermediateTensors({})
another_intermediate_tensors = AnotherIntermediateTensors({})
assert intermediate_tensors != another_intermediate_tensors
empty_intermediate_tensors_1 = IntermediateTensors({})
empty_intermediate_tensors_2 = IntermediateTensors({})
assert empty_intermediate_tensors_1 == empty_intermediate_tensors_2
different_key_intermediate_tensors_1 = IntermediateTensors(
{"1": torch.zeros([2, 4], dtype=torch.int32)})
difference_key_intermediate_tensors_2 = IntermediateTensors(
{"2": torch.zeros([2, 4], dtype=torch.int32)})
assert (different_key_intermediate_tensors_1
!= difference_key_intermediate_tensors_2)
same_key_different_value_intermediate_tensors_1 = IntermediateTensors(
{"1": torch.zeros([2, 4], dtype=torch.int32)})
same_key_different_value_intermediate_tensors_2 = IntermediateTensors(
{"1": torch.zeros([2, 5], dtype=torch.int32)})
assert (same_key_different_value_intermediate_tensors_1
!= same_key_different_value_intermediate_tensors_2)
same_key_same_value_intermediate_tensors_1 = IntermediateTensors(
{"1": torch.zeros([2, 4], dtype=torch.int32)})
same_key_same_value_intermediate_tensors_2 = IntermediateTensors(
{"1": torch.zeros([2, 4], dtype=torch.int32)})
assert (same_key_same_value_intermediate_tensors_1 ==
same_key_same_value_intermediate_tensors_2)