[Speculative Decoding 2/2 ] Integrate typical acceptance sampler into Spec Decode Worker (#5348)

This commit is contained in:
sroy745
2024-07-01 00:33:05 -07:00
committed by GitHub
parent 614aa51203
commit 80ca1e6a3a
14 changed files with 480 additions and 208 deletions

View File

@@ -11,9 +11,15 @@ distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality. This gives us good coverage of temp=0.
At temp=0, the TypicalAcceptanceSampler ensures that only the tokens with the
highest probability in the target distribution are accepted. Therefore, we can
expect greedy equality for the TypicalAcceptanceSampler at temp=0.
For temp>0, we rely on unit tests on the rejection sampler to verify that the
output distribution is the same with spec decode vs. no spec decode (this would
be prohibitively expensive to run with a real model).
be prohibitively expensive to run with a real model). Similarly, for the
TypicalAcceptance sampler also, we rely on unit tests to validate temp>0
test cases.
NOTE: Speculative decoding's distribution equality requires that the measured
distributions of the target model and proposal model be deterministic given the
@@ -611,3 +617,49 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-160m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"spec_decoding_acceptance_method": "typical_acceptance_sampler"
}
# Try a range of common k.
for k in [1, 2, 3]
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_typical_acceptance_sampling(baseline_llm_generator,
test_llm_generator, batch_size: int,
output_len: int):
"""Verify that speculative decoding produces exact equality to without spec
decode with TypicalAcceptanceSampler as the draft token acceptance
sampling method.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)