[2/N] Chunked prefill data update (#3538)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from vllm.sequence import SamplerOutput, SequenceGroupOutput, SequenceOutput
|
||||
from vllm.sequence import (SamplerOutput, SequenceData, SequenceGroupOutput,
|
||||
SequenceOutput)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -48,3 +49,24 @@ def test_sampler_output_eq(sample_outputs):
|
||||
sampler_output3 = SamplerOutput(outputs=sample_outputs[:-1])
|
||||
assert sampler_output1 == sampler_output2
|
||||
assert sampler_output1 != sampler_output3
|
||||
|
||||
|
||||
def test_sequence_data_prefill():
|
||||
seq_data = SequenceData(prompt_token_ids=[1, 2, 3, 4])
|
||||
assert seq_data.get_num_uncomputed_tokens() == 4
|
||||
assert seq_data.get_num_computed_tokens() == 0
|
||||
# advance by 2
|
||||
seq_data.update_num_computed_tokens(2)
|
||||
assert seq_data.get_num_uncomputed_tokens() == 2
|
||||
assert seq_data.get_num_computed_tokens() == 2
|
||||
|
||||
# advance by 1
|
||||
seq_data.update_num_computed_tokens(1)
|
||||
assert seq_data.get_num_uncomputed_tokens() == 1
|
||||
assert seq_data.get_num_computed_tokens() == 3
|
||||
|
||||
# append tokens and reset, simulating recompute
|
||||
seq_data.append_token_id(1, logprob=0.0)
|
||||
seq_data.reset_num_computed_tokens()
|
||||
assert seq_data.get_num_uncomputed_tokens() == 5
|
||||
assert seq_data.get_num_computed_tokens() == 0
|
||||
|
||||
Reference in New Issue
Block a user