[Feature] [Spec decode]: Enable MLPSpeculator/Medusa and prompt_logprobs with ChunkedPrefill (#10132)
Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: wallashss <wallashss@ibm.com> Co-authored-by: wallashss <wallashss@ibm.com>
This commit is contained in:
@@ -2,6 +2,8 @@
|
||||
tensor parallelism.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@@ -154,15 +156,20 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
|
||||
"--speculative-draft-tensor-parallel-size",
|
||||
"1",
|
||||
])])
|
||||
@pytest.mark.parametrize("logprobs", [None, 2])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
logprobs: Optional[int],
|
||||
batch_size: int, seed: int):
|
||||
"""Verify spec decode works well with same and different TP size for
|
||||
the draft model with chunked prefill.
|
||||
"""
|
||||
if logprobs:
|
||||
test_llm_kwargs.extend(
|
||||
["--disable_logprobs_during_spec_decoding", "False"])
|
||||
run_equality_correctness_test_tp(model,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
@@ -171,4 +178,5 @@ def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=32,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
temperature=0.0,
|
||||
logprobs=logprobs)
|
||||
|
||||
Reference in New Issue
Block a user