[2/N] Chunked prefill data update (#3538)

This commit is contained in:
SangBin Cho
2024-03-29 02:06:01 +09:00
committed by GitHub
parent ce567a2926
commit b51c1cc9d2
11 changed files with 272 additions and 76 deletions

View File

@@ -1,6 +1,7 @@
import pytest
from vllm.sequence import SamplerOutput, SequenceGroupOutput, SequenceOutput
from vllm.sequence import (SamplerOutput, SequenceData, SequenceGroupOutput,
SequenceOutput)
@pytest.fixture
@@ -48,3 +49,24 @@ def test_sampler_output_eq(sample_outputs):
sampler_output3 = SamplerOutput(outputs=sample_outputs[:-1])
assert sampler_output1 == sampler_output2
assert sampler_output1 != sampler_output3
def test_sequence_data_prefill():
    """Exercise the computed/uncomputed token counters of SequenceData.

    Starts from a 4-token prompt, advances the computed-token counter in
    two steps, then appends one token and resets the counter, checking
    both counters after every step.
    """
    data = SequenceData(prompt_token_ids=[1, 2, 3, 4])

    def expect_counts(uncomputed, computed):
        # Uncomputed is checked first, mirroring the order in which the
        # two counters are asserted throughout this test.
        assert data.get_num_uncomputed_tokens() == uncomputed
        assert data.get_num_computed_tokens() == computed

    # Fresh sequence: nothing has been computed yet.
    expect_counts(uncomputed=4, computed=0)

    # Advance by 2.
    data.update_num_computed_tokens(2)
    expect_counts(uncomputed=2, computed=2)

    # Advance by 1.
    data.update_num_computed_tokens(1)
    expect_counts(uncomputed=1, computed=3)

    # Append a token and reset, simulating recompute: all 5 tokens
    # (4 prompt + 1 appended) are expected to count as uncomputed again.
    data.append_token_id(1, logprob=0.0)
    data.reset_num_computed_tokens()
    expect_counts(uncomputed=5, computed=0)