[Feature] [Spec decode]: Enable MLPSpeculator/Medusa and prompt_logprobs with ChunkedPrefill (#10132)
Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: wallashss <wallashss@ibm.com> Co-authored-by: wallashss <wallashss@ibm.com>
This commit is contained in:
@@ -2,6 +2,7 @@ from itertools import cycle
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
@@ -154,6 +155,8 @@ def _check_logprobs_when_output_disabled(
|
||||
spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
|
||||
assert spec_pos_logprob.rank == -1
|
||||
assert spec_pos_logprob.logprob == 0.0
|
||||
if isinstance(spec_pos_logprob_token_id, torch.Tensor):
|
||||
spec_pos_logprob_token_id = spec_pos_logprob_token_id.item()
|
||||
assert spec_pos_logprob_token_id in baseline_pos_logprobs
|
||||
|
||||
|
||||
@@ -244,7 +247,8 @@ def run_equality_correctness_test_tp(model,
|
||||
batch_size: int,
|
||||
max_output_len: int,
|
||||
seed: int = 0,
|
||||
temperature: float = 0.0):
|
||||
temperature: float = 0.0,
|
||||
logprobs: Optional[int] = None):
|
||||
"""Helper method that compares the outputs of both the baseline LLM and
|
||||
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
|
||||
the same when temperature is zero.
|
||||
@@ -257,7 +261,6 @@ def run_equality_correctness_test_tp(model,
|
||||
results = []
|
||||
|
||||
prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
|
||||
|
||||
for args, env in ((arg1, env1), (arg2, env2)):
|
||||
with RemoteOpenAIServer(model,
|
||||
args,
|
||||
@@ -269,12 +272,14 @@ def run_equality_correctness_test_tp(model,
|
||||
prompt=prompts,
|
||||
max_tokens=max_output_len,
|
||||
seed=seed,
|
||||
temperature=temperature)
|
||||
temperature=temperature,
|
||||
logprobs=logprobs)
|
||||
|
||||
results.append({
|
||||
"test":
|
||||
"seeded_sampling",
|
||||
"text": [choice.text for choice in completion.choices],
|
||||
"logprobs": [choice.logprobs for choice in completion.choices],
|
||||
"finish_reason":
|
||||
[choice.finish_reason for choice in completion.choices],
|
||||
"usage":
|
||||
@@ -284,7 +289,15 @@ def run_equality_correctness_test_tp(model,
|
||||
n = len(results) // 2
|
||||
arg1_results = results[:n]
|
||||
arg2_results = results[n:]
|
||||
# Separate logprobs to avoid asserting exact equality.
|
||||
arg1_logprobs = [r.pop("logprobs") for r in arg1_results]
|
||||
arg2_logprobs = [r.pop("logprobs") for r in arg2_results]
|
||||
|
||||
for arg1_result, arg2_result in zip(arg1_results, arg2_results):
|
||||
assert arg1_result == arg2_result, (
|
||||
f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
|
||||
f"{arg1_result=} != {arg2_result=}")
|
||||
if logprobs:
|
||||
for logs1, logs2 in zip(arg1_logprobs, arg2_logprobs):
|
||||
for l1, l2 in zip(logs1, logs2):
|
||||
assert l1.tokens == l2.tokens
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
tensor parallelism.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@@ -154,15 +156,20 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
|
||||
"--speculative-draft-tensor-parallel-size",
|
||||
"1",
|
||||
])])
|
||||
@pytest.mark.parametrize("logprobs", [None, 2])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
logprobs: Optional[int],
|
||||
batch_size: int, seed: int):
|
||||
"""Verify spec decode works well with same and different TP size for
|
||||
the draft model with chunked prefill.
|
||||
"""
|
||||
if logprobs:
|
||||
test_llm_kwargs.extend(
|
||||
["--disable_logprobs_during_spec_decoding", "False"])
|
||||
run_equality_correctness_test_tp(model,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
@@ -171,4 +178,5 @@ def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=32,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
temperature=0.0,
|
||||
logprobs=logprobs)
|
||||
|
||||
@@ -4,26 +4,27 @@ import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
|
||||
from ..utils import maybe_enable_chunked_prefill
|
||||
from .conftest import run_equality_correctness_test
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model_name": "JackFram/llama-68m",
|
||||
"model_name": "JackFram/llama-160m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
"enforce_eager": True
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_model": "JackFram/llama-160m",
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs_during_spec_decoding": False,
|
||||
}, {
|
||||
"speculative_model": "JackFram/llama-160m",
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs_during_spec_decoding": True,
|
||||
}])
|
||||
@@ -36,12 +37,15 @@ from .conftest import run_equality_correctness_test
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("logprobs", [1, 6])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4, 12])
|
||||
def test_logprobs_equality(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int, logprobs: int):
|
||||
"""Verify output logprobs are equal with and without speculative decoding.
|
||||
seed: int, logprobs: int, prefill_chunk_size: int):
|
||||
"""Verify output logprobs are equal with and without speculative decoding,
|
||||
as well as with and without chunked prefill.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
|
||||
@@ -21,6 +21,7 @@ correctess for the target model outputs.
|
||||
|
||||
import pytest
|
||||
|
||||
from ..utils import maybe_enable_chunked_prefill
|
||||
from .conftest import run_equality_correctness_test
|
||||
|
||||
# main model
|
||||
@@ -67,12 +68,14 @@ PRECISION = "float32"
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
|
||||
def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
seed: int, prefill_chunk_size: int):
|
||||
"""Verify greedy equality with different batch size."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
@@ -119,12 +122,15 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("logprobs", [1, 6])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
|
||||
def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
seed: int, logprobs: int):
|
||||
seed: int, logprobs: int,
|
||||
prefill_chunk_size: int):
|
||||
"""Verify greedy equality with different batch size."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
@@ -167,12 +173,14 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
|
||||
def test_medusa_e2e_greedy_correctness_cuda_graph(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
seed: int, prefill_chunk_size: int):
|
||||
"""Verify greedy equality with cuda graph enabled and different
|
||||
batch sizes."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
@@ -217,13 +225,15 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
|
||||
def test_medusa_e2e_greedy_correctness_with_preemption(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
seed: int, prefill_chunk_size: int):
|
||||
"""Verify greedy equality, even when some sequences are preempted mid-
|
||||
generation.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
@@ -267,13 +277,15 @@ def test_medusa_e2e_greedy_correctness_with_preemption(
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
|
||||
def test_medusa_different_k(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
seed: int, prefill_chunk_size: int):
|
||||
"""Verify that medusa speculative decoding produces exact equality
|
||||
to without spec decode with different values of num_speculative_tokens.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
@@ -313,14 +325,17 @@ def test_medusa_different_k(vllm_runner, common_llm_kwargs,
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
|
||||
def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int,
|
||||
output_len: int, seed: int):
|
||||
output_len: int, seed: int,
|
||||
prefill_chunk_size: int):
|
||||
"""Verify that medusa speculative decoding produces exact equality
|
||||
to without spec decode when speculation is disabled for large
|
||||
batch sizes.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
@@ -361,12 +376,14 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
|
||||
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
|
||||
output_len: int, seed: int):
|
||||
output_len: int, seed: int, prefill_chunk_size: int):
|
||||
"""Verify that speculative decoding generates the same output
|
||||
with batch expansion scorer and mqa scorer.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
|
||||
@@ -25,6 +25,7 @@ import pytest
|
||||
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import pad_vocab_size
|
||||
|
||||
from ..utils import maybe_enable_chunked_prefill
|
||||
from .conftest import run_equality_correctness_test
|
||||
|
||||
# main model
|
||||
@@ -66,14 +67,16 @@ PRECISION = "float32"
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("batch_size", [4, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
|
||||
def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
seed: int, prefill_chunk_size: int):
|
||||
"""Verify greedy equality with different batch size."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
@@ -116,12 +119,19 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("logprobs", [1, 6])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
|
||||
def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int, seed: int,
|
||||
logprobs: int):
|
||||
logprobs: int, prefill_chunk_size: int):
|
||||
"""Verify greedy equality with different batch size."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
# NOTE Test is sensitive enough st if we don't enable chunked prefill
|
||||
# scheduling on baseline too, we get slightly different logprobs, ending
|
||||
# up sampling different tokens at the tail (ie top tokens don't change).
|
||||
# TL;DR: sd+cp == org+cp but sd+cp != org..is this expected?
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
@@ -162,12 +172,15 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("output_len", [2048])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
|
||||
def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int, seed: int):
|
||||
batch_size: int, output_len: int,
|
||||
prefill_chunk_size: int, seed: int):
|
||||
"""Verify acceptance rate with different batch size and large output
|
||||
length."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
@@ -204,13 +217,17 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("output_len", [64])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("temperature", [1.0])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
temperature: float, seed: int):
|
||||
temperature: float,
|
||||
prefill_chunk_size: int, seed: int):
|
||||
"""Verify seeded runs produce the same output."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
@@ -266,14 +283,16 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mlp_e2e_greedy_correctness_with_preemption(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
prefill_chunk_size: int, seed: int):
|
||||
"""Verify greedy equality, even when some sequences are preempted mid-
|
||||
generation.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
@@ -317,12 +336,14 @@ def test_mlp_e2e_greedy_correctness_with_preemption(
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
|
||||
def test_mlp_e2e_greedy_correctness_with_padding(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
prefill_chunk_size: int, seed: int):
|
||||
"""Verify greedy equality when the vocab dimension is padded
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
|
||||
# Default pad_to is 64, test model has vocab_size of 32000
|
||||
def patched_pad_vocab_size(vocab_size, pad_to=None):
|
||||
@@ -373,14 +394,16 @@ def test_mlp_e2e_greedy_correctness_with_padding(
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mlp_different_k(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, seed: int,
|
||||
output_len: int):
|
||||
test_llm_kwargs, batch_size: int,
|
||||
prefill_chunk_size: int, seed: int, output_len: int):
|
||||
"""Verify that mlp speculative decoding produces exact equality
|
||||
to without spec decode with different values of num_speculative_tokens.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
@@ -418,15 +441,21 @@ def test_mlp_different_k(vllm_runner, common_llm_kwargs,
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
# Speculative decoding is disabled when sequences reach decoding and the batch
|
||||
# consists of single-token requests. Hence we set `max_num_seqs`
|
||||
# >= `speculative_disable_by_batch_size` to test feature interaction.
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, seed: int,
|
||||
test_llm_kwargs, batch_size: int,
|
||||
prefill_chunk_size: int, seed: int,
|
||||
output_len: int):
|
||||
"""Verify that mlp speculative decoding produces exact equality
|
||||
to without spec decode when speculation is disabled for large
|
||||
batch sizes.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
@@ -460,13 +489,15 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
|
||||
output_len: int, seed: int):
|
||||
output_len: int, prefill_chunk_size: int, seed: int):
|
||||
"""Verify that speculative decoding generates the same output
|
||||
with batch expansion scorer and mqa scorer.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
|
||||
@@ -147,20 +147,20 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"enable_chunked_prefill": False,
|
||||
"disable_logprobs_during_spec_decoding": False
|
||||
}, {
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4,
|
||||
"disable_logprobs_during_spec_decoding": False
|
||||
}])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
@@ -192,6 +192,9 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
prompt_logprobs=2,
|
||||
logprobs=2,
|
||||
disable_logprobs=False,
|
||||
temperature=0.0,
|
||||
ensure_all_accepted=ensure_all_accepted)
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ for the target model outputs.
|
||||
|
||||
import pytest
|
||||
|
||||
from ..utils import maybe_enable_chunked_prefill
|
||||
from .conftest import run_equality_correctness_test
|
||||
|
||||
|
||||
@@ -49,11 +50,13 @@ from .conftest import run_equality_correctness_test
|
||||
"speculative_model": "[ngram]",
|
||||
"num_speculative_tokens": 5,
|
||||
"ngram_prompt_lookup_max": 3,
|
||||
"speculative_disable_mqa_scorer": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": "[ngram]",
|
||||
"num_speculative_tokens": 5,
|
||||
"ngram_prompt_lookup_max": 3,
|
||||
"speculative_disable_mqa_scorer": True,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
@@ -68,15 +71,7 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
prefill_chunk_size: int, seed: int):
|
||||
"""Verify greedy equality on a tiny model with different batch size."""
|
||||
if prefill_chunk_size > 0:
|
||||
common_llm_kwargs.update(
|
||||
**{
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": prefill_chunk_size,
|
||||
"max_num_seqs": prefill_chunk_size
|
||||
})
|
||||
else:
|
||||
common_llm_kwargs["enable_chunked_prefill"] = False
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
|
||||
Reference in New Issue
Block a user