[Tests] Add GSM8k check to SpecDec E2E tests (#34772)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
committed by
GitHub
parent
709eadbb0b
commit
ee59a7c615
@@ -8,6 +8,7 @@ from typing import Any
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.evals.gsm8k.gsm8k_eval import _build_gsm8k_prompts, evaluate_gsm8k_offline
|
||||
from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.assets.base import VLLM_S3_BUCKET_URL
|
||||
@@ -35,53 +36,57 @@ def _skip_if_insufficient_gpus_for_tp(tp_size: int):
|
||||
Messages = list[dict[str, Any]]
|
||||
|
||||
|
||||
def get_test_prompts(
|
||||
mm_enabled: bool, quiet: bool = False, num_prompts: int = 100
|
||||
) -> list[Messages]:
|
||||
prompt_types = ["repeat", "sentence"]
|
||||
def get_test_prompts(mm_enabled: bool, num_prompts: int = 100) -> list[Messages]:
|
||||
prompt_types = ["repeat", "gsm8k"]
|
||||
if mm_enabled:
|
||||
prompt_types.append("mm")
|
||||
prompts = []
|
||||
prompts: list[Messages] = []
|
||||
|
||||
random.seed(0)
|
||||
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
|
||||
|
||||
if not quiet:
|
||||
print(f"Prompt types: {random_prompt_type_choices}")
|
||||
num_repeat_prompts = num_prompts // len(prompt_types)
|
||||
if mm_enabled:
|
||||
num_gsm8k_prompts = num_prompts // len(prompt_types)
|
||||
num_mm_prompts = num_prompts - num_repeat_prompts - num_gsm8k_prompts
|
||||
else:
|
||||
num_mm_prompts = 0
|
||||
num_gsm8k_prompts = num_prompts - num_repeat_prompts
|
||||
|
||||
# Generate a mixed batch of prompts, some of which can be easily
|
||||
# predicted by n-gram matching and some which likely cannot.
|
||||
for kind in random_prompt_type_choices:
|
||||
random.seed(0)
|
||||
for _ in range(num_repeat_prompts):
|
||||
word_choices = ["test", "temp", "hello", "where"]
|
||||
word = random.choice(word_choices)
|
||||
prompt: str | list[dict[str, Any]] = ""
|
||||
if kind == "repeat":
|
||||
prompt = f"""
|
||||
please repeat the word '{word}' 10 times.
|
||||
give no other output than the word at least ten times in a row,
|
||||
in lowercase with spaces between each word and without quotes.
|
||||
"""
|
||||
elif kind == "sentence":
|
||||
prompt = f"""
|
||||
please give a ten-word sentence that
|
||||
uses the word {word} at least once.
|
||||
give no other output than that simple sentence without quotes.
|
||||
"""
|
||||
elif kind == "mm":
|
||||
placeholders = [
|
||||
prompts.append(
|
||||
[
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
|
||||
},
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
please repeat the word '{word}' 10 times.
|
||||
give no other output than the word at least ten times in a row,
|
||||
in lowercase with spaces between each word and without quotes.
|
||||
""",
|
||||
}
|
||||
]
|
||||
prompt = [
|
||||
*placeholders,
|
||||
{"type": "text", "text": "The meaning of the image is"},
|
||||
]
|
||||
else:
|
||||
raise ValueError(f"Unknown prompt type: {kind}")
|
||||
)
|
||||
prompts.extend(
|
||||
[{"role": "user", "content": prompt}]
|
||||
for prompt in _build_gsm8k_prompts(
|
||||
num_questions=num_gsm8k_prompts, num_shots=5
|
||||
)[0]
|
||||
)
|
||||
for _ in range(num_mm_prompts):
|
||||
placeholders = [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
|
||||
},
|
||||
}
|
||||
]
|
||||
prompt = [
|
||||
*placeholders,
|
||||
{"type": "text", "text": "The meaning of the image is"},
|
||||
]
|
||||
prompts.append([{"role": "user", "content": prompt}])
|
||||
|
||||
return prompts
|
||||
@@ -113,6 +118,25 @@ def model_name():
|
||||
return "meta-llama/Llama-3.1-8B-Instruct"
|
||||
|
||||
|
||||
def evaluate_llm_for_gsm8k(llm: LLM, expected_accuracy_threshold: float = 0.70) -> None:
    """Run the offline GSM8K evaluation on ``llm`` and assert a minimum accuracy.

    The default threshold is meant for an LLM whose target model is the one
    returned by the ``model_name`` fixture, configured with max model len 4096.
    The precomputed reference accuracy for that setup is 75% to 80% on GSM8K
    with greedy decoding, so 70% serves as a sanity floor rather than a tight
    bound.

    A threshold of 0.0 or lower disables the check: a notice is printed and
    the evaluation is not run at all.

    Args:
        llm: The engine to evaluate; it is handed unchanged to
            ``evaluate_gsm8k_offline``.
        expected_accuracy_threshold: Lowest acceptable accuracy.

    Raises:
        AssertionError: If the measured accuracy is below the threshold.
    """
    if expected_accuracy_threshold <= 0.0:
        print("Skipping GSM8K evaluation")
    else:
        # Only the "accuracy" entry of the evaluation results is consulted.
        gsm8k_accuracy = evaluate_gsm8k_offline(llm)["accuracy"]
        print(f"GSM8K accuracy: {gsm8k_accuracy:.3f}")
        assert gsm8k_accuracy >= expected_accuracy_threshold, (
            f"Expected GSM8K accuracy >= {expected_accuracy_threshold}, "
            f"got {gsm8k_accuracy:.3f}"
        )
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_torch_dynamo():
|
||||
"""Reset torch dynamo cache before each test"""
|
||||
@@ -138,41 +162,14 @@ def reset_torch_dynamo():
|
||||
)
|
||||
def test_ngram_and_suffix_correctness(
|
||||
speculative_config: dict,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
sampling_config: SamplingParams,
|
||||
model_name: str,
|
||||
):
|
||||
"""
|
||||
Compare the outputs of an original LLM and a speculative LLM
|
||||
should be the same when using ngram speculative decoding.
|
||||
"""
|
||||
test_prompts = get_test_prompts(mm_enabled=False)
|
||||
|
||||
ref_llm = LLM(model=model_name, max_model_len=1024)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
del ref_llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
spec_llm = LLM(
|
||||
model=model_name,
|
||||
speculative_config=speculative_config,
|
||||
max_model_len=1024,
|
||||
max_model_len=4096,
|
||||
)
|
||||
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
||||
matches = 0
|
||||
misses = 0
|
||||
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
|
||||
if ref_output.outputs[0].text == spec_output.outputs[0].text:
|
||||
matches += 1
|
||||
else:
|
||||
misses += 1
|
||||
print(f"ref_output: {ref_output.outputs[0].text}")
|
||||
print(f"spec_output: {spec_output.outputs[0].text}")
|
||||
|
||||
# Heuristic: expect at least 66% of the prompts to match exactly
|
||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||
assert matches >= int(0.66 * len(ref_outputs))
|
||||
evaluate_llm_for_gsm8k(spec_llm)
|
||||
del spec_llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
@@ -238,10 +235,10 @@ def test_suffix_decoding_acceptance(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_path",
|
||||
["model_path", "expected_accuracy_threshold"],
|
||||
[
|
||||
"RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3",
|
||||
"RedHatAI/Qwen3-8B-speculator.eagle3",
|
||||
("RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3", 0.7), # ref: 75%-80%
|
||||
("RedHatAI/Qwen3-8B-speculator.eagle3", 0.8), # ref: 87%-92%
|
||||
],
|
||||
ids=["llama3_eagle3_speculator", "qwen3_eagle3_speculator"],
|
||||
)
|
||||
@@ -249,6 +246,7 @@ def test_speculators_model_integration(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
sampling_config: SamplingParams,
|
||||
model_path: str,
|
||||
expected_accuracy_threshold: float,
|
||||
):
|
||||
"""
|
||||
Test that speculators models work with the simplified integration.
|
||||
@@ -262,7 +260,8 @@ def test_speculators_model_integration(
|
||||
2. Verifier model is extracted from speculator config
|
||||
3. Speculative decoding is automatically enabled
|
||||
4. Text generation works correctly
|
||||
5. Output matches reference (non-speculative) generation
|
||||
5. GSM8k accuracy of the model passes a sanity check when speculative decoding is on
|
||||
6. Output matches reference (non-speculative) generation
|
||||
"""
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
@@ -270,7 +269,10 @@ def test_speculators_model_integration(
|
||||
test_prompts = get_test_prompts(mm_enabled=False)
|
||||
|
||||
# First run: Direct speculator model (simplified integration)
|
||||
spec_llm = LLM(model=model_path, max_model_len=1024)
|
||||
spec_llm = LLM(model=model_path, max_model_len=4096)
|
||||
evaluate_llm_for_gsm8k(
|
||||
spec_llm, expected_accuracy_threshold=expected_accuracy_threshold
|
||||
)
|
||||
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
||||
|
||||
# Verify speculative config was auto-detected
|
||||
@@ -297,7 +299,7 @@ def test_speculators_model_integration(
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
# Second run: Reference without speculative decoding
|
||||
ref_llm = LLM(model=verifier_model, max_model_len=1024)
|
||||
ref_llm = LLM(model=verifier_model, max_model_len=4096)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
del ref_llm
|
||||
torch.cuda.empty_cache()
|
||||
@@ -318,19 +320,27 @@ def test_speculators_model_integration(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["model_setup", "mm_enabled", "enable_chunked_prefill", "model_impl"],
|
||||
[
|
||||
"model_setup",
|
||||
"mm_enabled",
|
||||
"enable_chunked_prefill",
|
||||
"model_impl",
|
||||
"expected_accuracy_threshold",
|
||||
],
|
||||
[
|
||||
(
|
||||
("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
|
||||
False,
|
||||
False,
|
||||
"auto",
|
||||
0.8, # ref: 90%
|
||||
),
|
||||
(
|
||||
("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
|
||||
False,
|
||||
False,
|
||||
"transformers",
|
||||
0.8, # ref: 90%
|
||||
),
|
||||
pytest.param(
|
||||
(
|
||||
@@ -342,6 +352,7 @@ def test_speculators_model_integration(
|
||||
False,
|
||||
False,
|
||||
"auto",
|
||||
0.8, # ref: 90%
|
||||
marks=pytest.mark.skip(
|
||||
reason="architecture of its eagle3 is LlamaForCausalLMEagle3"
|
||||
),
|
||||
@@ -356,6 +367,7 @@ def test_speculators_model_integration(
|
||||
False,
|
||||
False,
|
||||
"auto",
|
||||
0.7, # TODO, update this with a reference value when re-enabling this case
|
||||
marks=pytest.mark.skip(
|
||||
reason="Skipping due to its head_dim not being a a multiple of 32"
|
||||
),
|
||||
@@ -370,6 +382,7 @@ def test_speculators_model_integration(
|
||||
False,
|
||||
True,
|
||||
"auto",
|
||||
0.7, # ref: 75%-80%
|
||||
marks=large_gpu_mark(min_gb=40),
|
||||
), # works on 4x H100
|
||||
(
|
||||
@@ -382,6 +395,7 @@ def test_speculators_model_integration(
|
||||
False,
|
||||
False,
|
||||
"auto",
|
||||
0.7, # ref: 75%-80%
|
||||
),
|
||||
pytest.param(
|
||||
(
|
||||
@@ -393,7 +407,8 @@ def test_speculators_model_integration(
|
||||
False,
|
||||
False,
|
||||
"auto",
|
||||
marks=large_gpu_mark(min_gb=80),
|
||||
0.8, # ref: 90%
|
||||
# marks=large_gpu_mark(min_gb=80),
|
||||
), # works on 4x H100
|
||||
pytest.param(
|
||||
(
|
||||
@@ -405,6 +420,7 @@ def test_speculators_model_integration(
|
||||
True,
|
||||
True,
|
||||
"auto",
|
||||
0.8, # ref: 90%
|
||||
marks=large_gpu_mark(min_gb=80),
|
||||
), # works on 4x H100
|
||||
(
|
||||
@@ -417,6 +433,7 @@ def test_speculators_model_integration(
|
||||
False,
|
||||
False,
|
||||
"auto",
|
||||
0.0, # dummy model, skip gsm8k check
|
||||
),
|
||||
],
|
||||
ids=[
|
||||
@@ -437,10 +454,18 @@ def test_eagle_correctness(
|
||||
sampling_config: SamplingParams,
|
||||
model_setup: tuple[str, str, str, int],
|
||||
mm_enabled: bool,
|
||||
expected_accuracy_threshold: float,
|
||||
enable_chunked_prefill: bool,
|
||||
model_impl: str,
|
||||
attn_backend: str,
|
||||
):
|
||||
"""
|
||||
Compare the outputs of an original LLM and a speculative LLM
|
||||
which should be the same when using eagle speculative decoding. Due to some variance
|
||||
in the engine, it is possible for some outputs to differ, so we expect that at least
|
||||
6/10 output tokens match exactly, and that the GSM8k accuracy is above
|
||||
a precomputed reference threshold for each model.
|
||||
"""
|
||||
if attn_backend == "TREE_ATTN":
|
||||
# TODO: Fix this flaky test
|
||||
pytest.skip(
|
||||
@@ -461,11 +486,6 @@ def test_eagle_correctness(
|
||||
|
||||
# Generate test prompts inside the function instead of using fixture
|
||||
test_prompts = get_test_prompts(mm_enabled)
|
||||
"""
|
||||
Compare the outputs of an original LLM and a speculative LLM
|
||||
should be the same when using eagle speculative decoding.
|
||||
model_setup: (method, model_name, eagle_model_name, tp_size)
|
||||
"""
|
||||
# Determine attention config
|
||||
# Scout requires default backend selection because vision encoder has
|
||||
# head_dim 88 being incompatible with FLASH_ATTN and needs to fall back
|
||||
@@ -505,6 +525,9 @@ def test_eagle_correctness(
|
||||
tensor_parallel_size=tp_size,
|
||||
attention_config=attention_config,
|
||||
)
|
||||
evaluate_llm_for_gsm8k(
|
||||
ref_llm, expected_accuracy_threshold=expected_accuracy_threshold
|
||||
)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
del ref_llm
|
||||
torch.cuda.empty_cache()
|
||||
@@ -526,6 +549,9 @@ def test_eagle_correctness(
|
||||
model_impl=model_impl,
|
||||
attention_config=attention_config,
|
||||
)
|
||||
evaluate_llm_for_gsm8k(
|
||||
spec_llm, expected_accuracy_threshold=expected_accuracy_threshold
|
||||
)
|
||||
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
||||
matches = 0
|
||||
misses = 0
|
||||
@@ -546,10 +572,10 @@ def test_eagle_correctness(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["model_setup", "mm_enabled"],
|
||||
["model_setup", "mm_enabled", "expected_accuracy_threshold"],
|
||||
[
|
||||
(("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
|
||||
(("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False),
|
||||
(("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False, 0.5), # ref: 65%-70%
|
||||
(("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False, 0.0), # dummy model
|
||||
],
|
||||
ids=["mimo", "deepseek"],
|
||||
)
|
||||
@@ -558,14 +584,17 @@ def test_mtp_correctness(
|
||||
sampling_config: SamplingParams,
|
||||
model_setup: tuple[str, str, int],
|
||||
mm_enabled: bool,
|
||||
expected_accuracy_threshold: float,
|
||||
):
|
||||
# Generate test prompts inside the function instead of using fixture
|
||||
test_prompts = get_test_prompts(mm_enabled)
|
||||
"""
|
||||
Compare the outputs of an original LLM and a speculative LLM
|
||||
should be the same when using MTP speculative decoding.
|
||||
model_setup: (method, model_name, tp_size)
|
||||
which should be the same when using MTP speculative decoding. Due to some variance
|
||||
in the engine, it is possible for some outputs to differ, so we expect that at least
|
||||
6/10 output tokens match exactly, and that the GSM8k accuracy is above a precomputed
|
||||
reference threshold for each model.
|
||||
"""
|
||||
# Generate test prompts inside the function instead of using fixture
|
||||
test_prompts = get_test_prompts(mm_enabled)
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_MLA_DISABLE", "1")
|
||||
|
||||
@@ -579,6 +608,9 @@ def test_mtp_correctness(
|
||||
trust_remote_code=True,
|
||||
)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
evaluate_llm_for_gsm8k(
|
||||
ref_llm, expected_accuracy_threshold=expected_accuracy_threshold
|
||||
)
|
||||
del ref_llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
@@ -594,6 +626,9 @@ def test_mtp_correctness(
|
||||
},
|
||||
max_model_len=2048,
|
||||
)
|
||||
evaluate_llm_for_gsm8k(
|
||||
spec_llm, expected_accuracy_threshold=expected_accuracy_threshold
|
||||
)
|
||||
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
||||
matches = 0
|
||||
misses = 0
|
||||
@@ -621,12 +656,13 @@ class ArgsTest:
|
||||
num_speculative_tokens: int
|
||||
expected_acceptance_rate: float
|
||||
expected_acceptance_len: float
|
||||
expected_gsm8k_accuracy: float = 0.0 # skip by default
|
||||
# Defaults
|
||||
enforce_eager: bool = True
|
||||
parallel_drafting: bool = False
|
||||
target_tensor_parallel_size: int = 1
|
||||
draft_tensor_parallel_size: int = 1
|
||||
max_model_len: int = 1024
|
||||
max_model_len: int = 2048
|
||||
gpu_memory_utilization: float = 0.5
|
||||
dataset: str = "test_prompts"
|
||||
num_prompts: int = 100
|
||||
@@ -639,8 +675,9 @@ cases = [
|
||||
draft_model="Qwen/Qwen3-0.6B",
|
||||
sampling_config=greedy_sampling(),
|
||||
num_speculative_tokens=3, # K
|
||||
expected_acceptance_len=3 + 1, # K + 1
|
||||
expected_acceptance_rate=1.0,
|
||||
expected_acceptance_len=0.98 * (3 + 1), # epsilon discount of K + 1
|
||||
expected_acceptance_rate=0.98, # slight epsilon
|
||||
expected_gsm8k_accuracy=0.25, # ref: 35-40%
|
||||
),
|
||||
# Smaller draft model, stochastic sampling.
|
||||
ArgsTest(
|
||||
@@ -648,8 +685,9 @@ cases = [
|
||||
draft_model="Qwen/Qwen3-0.6B",
|
||||
sampling_config=stochastic_sampling(),
|
||||
num_speculative_tokens=3,
|
||||
expected_acceptance_len=2.8 + 1,
|
||||
expected_acceptance_rate=0.9,
|
||||
expected_acceptance_len=3.4, # ref: 3.7
|
||||
expected_acceptance_rate=0.80, # ref: 0.90
|
||||
expected_gsm8k_accuracy=0.5, # ref: 60%. Note gsm8k always runs greedy sampling
|
||||
),
|
||||
]
|
||||
|
||||
@@ -669,9 +707,8 @@ def test_draft_model_realistic_example():
|
||||
num_speculative_tokens=3,
|
||||
sampling_config=greedy_sampling(),
|
||||
enforce_eager=False,
|
||||
# values below are not derived, but just prevent a regression
|
||||
expected_acceptance_len=2.8,
|
||||
expected_acceptance_rate=0.55,
|
||||
expected_acceptance_len=2.6, # ref: 2.86
|
||||
expected_acceptance_rate=0.5, # ref: 0.62
|
||||
)
|
||||
assert_draft_model_correctness(args)
|
||||
|
||||
@@ -685,9 +722,8 @@ def test_draft_model_parallel_drafting():
|
||||
sampling_config=greedy_sampling(),
|
||||
parallel_drafting=True,
|
||||
enforce_eager=False,
|
||||
# values below are collected from a stable run, with ~5% tolerance
|
||||
expected_acceptance_len=2.375,
|
||||
expected_acceptance_rate=0.45,
|
||||
expected_acceptance_len=2.3, # ref: 2.52
|
||||
expected_acceptance_rate=0.4, # ref: 0.51
|
||||
)
|
||||
assert_draft_model_correctness(args)
|
||||
|
||||
@@ -723,6 +759,7 @@ def test_draft_model_tensor_parallelism():
|
||||
draft_tensor_parallel_size=2,
|
||||
**some_high_acceptance_metrics(),
|
||||
enforce_eager=False,
|
||||
expected_gsm8k_accuracy=0.5,
|
||||
)
|
||||
assert_draft_model_correctness(sd_case)
|
||||
|
||||
@@ -797,9 +834,14 @@ def assert_draft_model_correctness(args: ArgsTest):
|
||||
# we don't check the outputs, only check the metrics
|
||||
spec_llm.chat(test_prompts, args.sampling_config)
|
||||
metrics = spec_llm.get_metrics()
|
||||
|
||||
acceptance_rate: float = compute_acceptance_rate(metrics)
|
||||
acceptance_len: float = compute_acceptance_len(metrics)
|
||||
|
||||
# Need to evaluate after getting metrics to avoid polluting the acceptance rate (AR)
|
||||
evaluate_llm_for_gsm8k(
|
||||
spec_llm, expected_accuracy_threshold=args.expected_gsm8k_accuracy
|
||||
)
|
||||
|
||||
del spec_llm # CLEANUP
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
@@ -817,7 +859,7 @@ def assert_draft_model_correctness(args: ArgsTest):
|
||||
|
||||
def get_messages(dataset: str, n: int) -> list[Messages]:
|
||||
if dataset == "test_prompts":
|
||||
return get_test_prompts(mm_enabled=False, quiet=True, num_prompts=n)
|
||||
return get_test_prompts(mm_enabled=False, num_prompts=n)
|
||||
elif dataset == "likaixin/InstructCoder":
|
||||
return get_instruct_coder_messages(n=n)
|
||||
else:
|
||||
@@ -828,8 +870,8 @@ def some_high_acceptance_metrics() -> dict:
|
||||
return {
|
||||
"sampling_config": greedy_sampling(),
|
||||
"num_speculative_tokens": 3,
|
||||
"expected_acceptance_len": 2.8 + 1,
|
||||
"expected_acceptance_rate": 0.90,
|
||||
"expected_acceptance_len": 3.4, # ref: 3.75
|
||||
"expected_acceptance_rate": 0.8, # ref: 0.9
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user