[Speculative decoding 3/9] Worker which speculates, scores, and applies rejection sampling (#3103)
This commit is contained in:
50
tests/test_sequence.py
Normal file
50
tests/test_sequence.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import pytest
|
||||
|
||||
from vllm.sequence import SequenceGroupOutput, SamplerOutput, SequenceOutput
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_outputs():
|
||||
return [
|
||||
SequenceGroupOutput(samples=[
|
||||
SequenceOutput(parent_seq_id=0, output_token=i, logprobs={})
|
||||
],
|
||||
prompt_logprobs=None) for i in range(5)
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sampler_output(sample_outputs):
|
||||
return SamplerOutput(outputs=sample_outputs)
|
||||
|
||||
|
||||
def test_sampler_output_initialization(sampler_output, sample_outputs):
|
||||
assert len(sampler_output) == len(sample_outputs)
|
||||
assert sampler_output.sampled_token_probs is None
|
||||
assert sampler_output.sampled_token_ids is None
|
||||
assert sampler_output.spec_decode_worker_metrics is None
|
||||
|
||||
|
||||
def test_sampler_output_getitem(sampler_output, sample_outputs):
|
||||
assert sampler_output[2] == sample_outputs[2]
|
||||
|
||||
|
||||
def test_sampler_output_setitem(sampler_output):
|
||||
new_output = SequenceGroupOutput(samples=[
|
||||
SequenceOutput(parent_seq_id=0, output_token=99, logprobs={})
|
||||
],
|
||||
prompt_logprobs=None)
|
||||
sampler_output[2] = new_output
|
||||
assert sampler_output[2] == new_output
|
||||
|
||||
|
||||
def test_sampler_output_len(sampler_output, sample_outputs):
|
||||
assert len(sampler_output) == len(sample_outputs)
|
||||
|
||||
|
||||
def test_sampler_output_eq(sample_outputs):
|
||||
sampler_output1 = SamplerOutput(outputs=sample_outputs)
|
||||
sampler_output2 = SamplerOutput(outputs=sample_outputs.copy())
|
||||
sampler_output3 = SamplerOutput(outputs=sample_outputs[:-1])
|
||||
assert sampler_output1 == sampler_output2
|
||||
assert sampler_output1 != sampler_output3
|
||||
Reference in New Issue
Block a user