[Feature] [Spec decode]: Enable MLPSpeculator/Medusa and prompt_logprobs with ChunkedPrefill (#10132)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: wallashss <wallashss@ibm.com>
Co-authored-by: wallashss <wallashss@ibm.com>
This commit is contained in:
Nicolò Lucchesi
2025-01-27 22:38:35 +01:00
committed by GitHub
parent 2bc3fbba0c
commit 6116ca8cd7
16 changed files with 468 additions and 165 deletions

View File

@@ -2,6 +2,7 @@ from itertools import cycle
from typing import List, Optional, Sequence, Tuple, Union
import pytest
import torch
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
@@ -154,6 +155,8 @@ def _check_logprobs_when_output_disabled(
spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
assert spec_pos_logprob.rank == -1
assert spec_pos_logprob.logprob == 0.0
if isinstance(spec_pos_logprob_token_id, torch.Tensor):
spec_pos_logprob_token_id = spec_pos_logprob_token_id.item()
assert spec_pos_logprob_token_id in baseline_pos_logprobs
@@ -244,7 +247,8 @@ def run_equality_correctness_test_tp(model,
batch_size: int,
max_output_len: int,
seed: int = 0,
temperature: float = 0.0):
temperature: float = 0.0,
logprobs: Optional[int] = None):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
the same when temperature is zero.
@@ -257,7 +261,6 @@ def run_equality_correctness_test_tp(model,
results = []
prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
for args, env in ((arg1, env1), (arg2, env2)):
with RemoteOpenAIServer(model,
args,
@@ -269,12 +272,14 @@ def run_equality_correctness_test_tp(model,
prompt=prompts,
max_tokens=max_output_len,
seed=seed,
temperature=temperature)
temperature=temperature,
logprobs=logprobs)
results.append({
"test":
"seeded_sampling",
"text": [choice.text for choice in completion.choices],
"logprobs": [choice.logprobs for choice in completion.choices],
"finish_reason":
[choice.finish_reason for choice in completion.choices],
"usage":
@@ -284,7 +289,15 @@ def run_equality_correctness_test_tp(model,
n = len(results) // 2
arg1_results = results[:n]
arg2_results = results[n:]
# Separate logprobs to avoid asserting exact equality.
arg1_logprobs = [r.pop("logprobs") for r in arg1_results]
arg2_logprobs = [r.pop("logprobs") for r in arg2_results]
for arg1_result, arg2_result in zip(arg1_results, arg2_results):
assert arg1_result == arg2_result, (
f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
f"{arg1_result=} != {arg2_result=}")
if logprobs:
for logs1, logs2 in zip(arg1_logprobs, arg2_logprobs):
for l1, l2 in zip(logs1, logs2):
assert l1.tokens == l2.tokens