[Speculative decoding 7/9] Speculative decoding end-to-end correctness tests. (#3951)
This commit is contained in:
@@ -1,11 +1,42 @@
|
||||
"""The tests in this file verify end-to-end speculative decoding correctness.
|
||||
|
||||
This docstring details important information on the testing methodology.
|
||||
|
||||
Most of the tests rely on "greedy equality", where we expect the output of
|
||||
speculative decoding on a sequence to exactly match the output of normal non-
|
||||
speculative decoding.
|
||||
|
||||
Since speculative decoding with rejection sampling guarantees that the output
|
||||
distribution matches the target model's output distribution (up to hardware
|
||||
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
|
||||
equality. This gives us good coverage of temp=0.
|
||||
|
||||
For temp>0, we rely on unit tests on the rejection sampler to verify that the
|
||||
output distribution is the same with spec decode vs. no spec decode (this would
|
||||
be prohibitively expensive to run with a real model).
|
||||
|
||||
NOTE: Speculative decoding's distribution equality requires that the measured
|
||||
distributions of the target model and proposal model be deterministic given the
|
||||
same input. vLLM largely guarantees this.
|
||||
|
||||
@cadedaniel has seen cases where the output probabilities of a draft/target
|
||||
model change slightly with certain batch sizes or prompts, even with Torch
|
||||
determinism flags set. It is unclear if this is a bug in vLLM, due to non-
|
||||
determinism in on-device batched operations, a bug in vLLM's spec decode
|
||||
implementation, or the "hardware numerics" limitations. Either way, rejection
|
||||
sampling ensures the output distribution matches the target model, but it breaks
|
||||
greedy-equality tests for those batch sizes/prompts.
|
||||
"""
|
||||
|
||||
from itertools import cycle
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm import SamplingParams
|
||||
|
||||
from .conftest import get_output_from_llm_generator
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
@@ -14,9 +45,6 @@ from vllm import SamplingParams
|
||||
# Note this is repeated in the test body; to initialize a tokenizer.
|
||||
"model": "JackFram/llama-68m",
|
||||
|
||||
# Skip real loading for fast test.
|
||||
"load_format": "dummy",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
@@ -31,22 +59,15 @@ from vllm import SamplingParams
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 1,
|
||||
},
|
||||
{
|
||||
# No spec decode.
|
||||
# Verify the detokenizer assertions in the test work when spec
|
||||
# decode is disabled.
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("batch_size", [1])
|
||||
# NOTE: We should run more permutations of this test (more BS, more seeds). But
|
||||
# because our spec decode generates gibberish token ids, the likelihood of
|
||||
# emitting an invalid token combination is nontrivial. This causes divergence in
|
||||
# behavior of vLLM detokenization vs. hf tokenizer, for example when two "utf-
|
||||
# start" bytes are emitted.
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
|
||||
def test_spec_decode_e2e_with_detokenization(test_llm_generator,
|
||||
batch_size: int):
|
||||
"""Run generation with speculative decoding on a batch. Verify the engine
|
||||
generates the correct number of tokens (via ignore_eos=True), and that the
|
||||
detokenization matches HF transformers.
|
||||
@@ -67,8 +88,6 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
|
||||
max_tokens=output_len,
|
||||
ignore_eos=True,
|
||||
temperature=temperature,
|
||||
skip_special_tokens=True,
|
||||
spaces_between_special_tokens=False,
|
||||
)
|
||||
|
||||
batch_tokens, batch_token_ids = get_output_from_llm_generator(
|
||||
@@ -77,9 +96,10 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
|
||||
# Expect a generation for each prompt in the batch.
|
||||
assert len(batch_token_ids) == len(prompts)
|
||||
|
||||
# Expect each generation to have expected number of tokens (note
|
||||
# ignore_eos=True).
|
||||
assert all(len(token_ids) == output_len for token_ids in batch_token_ids)
|
||||
# Expect each generation to have expected number of tokens (note ignore_eos
|
||||
# is True).
|
||||
assert [len(token_ids)
|
||||
for token_ids in batch_token_ids] == ([output_len] * batch_size)
|
||||
|
||||
# Expect detokenized string to match.
|
||||
tok = AutoTokenizer.from_pretrained("JackFram/llama-68m")
|
||||
@@ -92,13 +112,293 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Use a small model for a fast test.
|
||||
"model": "JackFram/llama-68m",
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
}])
|
||||
@pytest.mark.parametrize(
|
||||
"per_test_common_llm_kwargs",
|
||||
[
|
||||
# Try two different tiny base models.
|
||||
# Note that one is equal to the draft model, another isn't.
|
||||
{
|
||||
"model": "JackFram/llama-68m",
|
||||
},
|
||||
{
|
||||
"model": "JackFram/llama-160m",
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use long output len for the small model test.
|
||||
1536,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
|
||||
baseline_llm_generator, test_llm_generator, batch_size: int,
|
||||
output_len: int):
|
||||
"""Verify greedy equality on a tiny model with batch size of one.
|
||||
|
||||
# Skip real loading for fast test.
|
||||
"load_format": "dummy",
|
||||
Since this test is cheaper than other e2e correctness tests, we generate
|
||||
with a higher output_len.
|
||||
"""
|
||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
force_output_len=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
}])
|
||||
@pytest.mark.parametrize(
|
||||
"per_test_common_llm_kwargs",
|
||||
[
|
||||
# Try two different tiny base models.
|
||||
# Note that one is equal to the draft model, another isn't.
|
||||
{
|
||||
"model": "JackFram/llama-68m",
|
||||
},
|
||||
{
|
||||
"model": "JackFram/llama-160m",
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use small output len for fast test.
|
||||
256,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [64])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
|
||||
baseline_llm_generator, test_llm_generator, batch_size: int,
|
||||
output_len: int):
|
||||
"""Verify greedy equality on a tiny model and large batch size.
|
||||
"""
|
||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
force_output_len=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True
|
||||
}])
|
||||
@pytest.mark.parametrize(
|
||||
"per_test_common_llm_kwargs",
|
||||
[
|
||||
# Try two different tiny base models.
|
||||
# Note that one is equal to the draft model, another isn't.
|
||||
{
|
||||
"model": "JackFram/llama-68m",
|
||||
},
|
||||
{
|
||||
"model": "JackFram/llama-160m",
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("max_output_len", [
|
||||
256,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
|
||||
baseline_llm_generator, test_llm_generator, batch_size: int,
|
||||
max_output_len: int):
|
||||
"""Verify greedy equality on a tiny model, with a large batch size, and when
|
||||
sampling respects the EOS token.
|
||||
"""
|
||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size,
|
||||
max_output_len,
|
||||
force_output_len=False)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# A "real" model (not tiny).
|
||||
"model": "meta-llama/Llama-2-7b-chat-hf",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use decently long output len for a high quality test.
|
||||
256,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
|
||||
baseline_llm_generator, test_llm_generator, batch_size: int,
|
||||
output_len: int):
|
||||
"""Verify greedy equality on a "real" model and batch size of 1. This is
|
||||
separate from large BS tests to make identifying the source of bugs easier.
|
||||
"""
|
||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
force_output_len=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# A "real" model (not tiny).
|
||||
"model": "meta-llama/Llama-2-7b-chat-hf",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [32])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
64,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
|
||||
baseline_llm_generator, test_llm_generator, batch_size: int,
|
||||
output_len: int):
|
||||
"""Verify greedy equality with a "real" model on a nontrivial batch size.
|
||||
This is the closest test to a real production workload.
|
||||
"""
|
||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
force_output_len=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"block_size": 8,
|
||||
# 2 for small prompt, 256//8 for generated.
|
||||
"num_gpu_blocks_override": 2 + 256 // 8,
|
||||
"max_model_len": (2 + 256 // 8) * 8,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
||||
{
|
||||
"model": "JackFram/llama-160m",
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use small output len for fast test.
|
||||
256,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_spec_decode_e2e_greedy_correctness_with_preemption(
|
||||
baseline_llm_generator, test_llm_generator, batch_size: int,
|
||||
output_len: int):
|
||||
"""Verify greedy equality, even when some sequences are preempted mid-
|
||||
generation.
|
||||
"""
|
||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
force_output_len=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model": "JackFram/llama-160m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
@@ -109,43 +409,189 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
|
||||
@pytest.mark.parametrize(
|
||||
"per_test_common_llm_kwargs",
|
||||
[
|
||||
# As of this writing, vLLM only compiles with these 3 block sizes by
|
||||
# default.
|
||||
{
|
||||
# Expect failure as spec decode not supported by
|
||||
# Ray backend.
|
||||
"worker_use_ray": True,
|
||||
"block_size": 8,
|
||||
},
|
||||
{
|
||||
"block_size": 16,
|
||||
},
|
||||
{
|
||||
"block_size": 32,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_spec_decode_xfail(test_llm_generator):
|
||||
"""Verify that speculative decoding with Ray fails.
|
||||
def test_spec_decode_different_block_size(baseline_llm_generator,
|
||||
test_llm_generator, batch_size: int,
|
||||
output_len: int):
|
||||
"""Verify greedy equality over different block sizes.
|
||||
"""
|
||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
force_output_len=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model": "JackFram/llama-160m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
|
||||
# Artificially limit the draft model max model len; this forces vLLM
|
||||
# to skip speculation once the sequences grow beyond 32-k tokens.
|
||||
"speculative_max_model_len": 32,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# This must be a good bit larger than speculative_max_model_len so that
|
||||
# we can test the case where all seqs are skipped, but still small to
|
||||
# ensure fast test.
|
||||
64,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_skip_speculation(baseline_llm_generator, test_llm_generator,
|
||||
batch_size: int, output_len: int):
|
||||
"""Verify greedy equality when some (or all) sequences skip speculation.
|
||||
We do this by setting the max model len of the draft model to an
|
||||
artificially low value, such that when the sequences grow beyond it, they
|
||||
are skipped in speculative decoding.
|
||||
"""
|
||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
force_output_len=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model": "JackFram/llama-68m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": k,
|
||||
}
|
||||
# Try a range of common k, as well as large speculation.
|
||||
for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
|
||||
output_len: int):
|
||||
"""Verify that speculative decoding produces exact equality to without spec
|
||||
decode with many different values of k.
|
||||
"""
|
||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
force_output_len=True)
|
||||
|
||||
|
||||
def run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size,
|
||||
max_output_len,
|
||||
force_output_len: bool,
|
||||
print_tokens: bool = False):
|
||||
"""Helper method that compares the outputs of both the baseline LLM and
|
||||
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
|
||||
the same when temperature is zero.
|
||||
"""
|
||||
output_len = 128
|
||||
temperature = 0.0
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
"San Francisco is know for its",
|
||||
"Facebook was created in 2004 by",
|
||||
"Curious George is a",
|
||||
"Python 3.11 brings improvements to its",
|
||||
]
|
||||
|
||||
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
|
||||
|
||||
# If the test requires that we generated max_output_len tokens, then set the
|
||||
# sampling params to ignore eos token.
|
||||
ignore_eos = force_output_len
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=output_len,
|
||||
ignore_eos=True,
|
||||
max_tokens=max_output_len,
|
||||
ignore_eos=ignore_eos,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
with pytest.raises(AssertionError,
|
||||
match="Speculative decoding not yet supported for "):
|
||||
get_output_from_llm_generator(test_llm_generator, prompts,
|
||||
sampling_params)
|
||||
spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
|
||||
test_llm_generator, prompts, sampling_params)
|
||||
|
||||
(baseline_batch_tokens,
|
||||
baseline_batch_token_ids) = get_output_from_llm_generator(
|
||||
baseline_llm_generator, prompts, sampling_params)
|
||||
|
||||
def get_output_from_llm_generator(
|
||||
llm_generator, prompts,
|
||||
sampling_params) -> Tuple[List[str], List[List[int]]]:
|
||||
for llm in llm_generator:
|
||||
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
|
||||
token_ids = [output.outputs[0].token_ids for output in outputs]
|
||||
tokens = [output.outputs[0].text for output in outputs]
|
||||
del llm
|
||||
assert len(baseline_batch_token_ids) == len(prompts)
|
||||
assert len(spec_batch_token_ids) == len(prompts)
|
||||
|
||||
return tokens, token_ids
|
||||
for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
|
||||
spec_tokens) in enumerate(
|
||||
zip(baseline_batch_token_ids, baseline_batch_tokens,
|
||||
spec_batch_token_ids, spec_batch_tokens)):
|
||||
if print_tokens:
|
||||
print(f'{i=} {baseline_tokens=}')
|
||||
print(f'{i=} {spec_tokens=}')
|
||||
print(f'{i=} {baseline_token_ids=}')
|
||||
print(f'{i=} {spec_token_ids=}')
|
||||
assert baseline_token_ids == spec_token_ids
|
||||
|
||||
Reference in New Issue
Block a user