[Feature] [Spec decode]: Enable MLPSpeculator/Medusa and prompt_logprobs with ChunkedPrefill (#10132)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: wallashss <wallashss@ibm.com>
Co-authored-by: wallashss <wallashss@ibm.com>
This commit is contained in:
Nicolò Lucchesi
2025-01-27 22:38:35 +01:00
committed by GitHub
parent 2bc3fbba0c
commit 6116ca8cd7
16 changed files with 468 additions and 165 deletions

View File

@@ -2,6 +2,8 @@
tensor parallelism.
"""
from typing import Optional
import pytest
import torch
@@ -154,15 +156,20 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
"--speculative-draft-tensor-parallel-size",
"1",
])])
@pytest.mark.parametrize("logprobs", [None, 2])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
logprobs: Optional[int],
batch_size: int, seed: int):
"""Verify spec decode works well with same and different TP size for
the draft model with chunked prefill.
"""
if logprobs:
test_llm_kwargs.extend(
["--disable_logprobs_during_spec_decoding", "False"])
run_equality_correctness_test_tp(model,
common_llm_kwargs,
per_test_common_llm_kwargs,
@@ -171,4 +178,5 @@ def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
batch_size,
max_output_len=32,
seed=seed,
temperature=0.0)
temperature=0.0,
logprobs=logprobs)