[bugfix] fix bug when top_logprobs=0 with spec decoding (#30059)

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
realliujiaxu
2025-12-13 01:03:35 +08:00
committed by GitHub
parent f3237f3f6b
commit d2c919dcc2
3 changed files with 5 additions and 3 deletions

View File

@@ -528,9 +528,11 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
),
],
)
@pytest.mark.parametrize("top_logprobs", [0, 3])
def test_spec_decode_logprobs(
logprobs_mode: LogprobsMode,
model_setup: tuple[str, str, str],
top_logprobs: int,
):
"""Spec decode logprobs should match those of the base model.
@@ -543,7 +545,7 @@ def test_spec_decode_logprobs(
prompt = "Hello world " * 50
sampling_params = SamplingParams(
temperature=0, logprobs=3, max_tokens=10, ignore_eos=False
temperature=0, logprobs=top_logprobs, max_tokens=10, ignore_eos=False
)
method, model_name, spec_model_name = model_setup
max_model_len = 256