[Bugfix] Fix NemotronH MTP + Chunked Prefill (#35447)
This commit is contained in:
committed by
GitHub
parent
20b14095a4
commit
8a680463fa
104
tests/v1/e2e/test_hybrid_chunked_prefill.py
Normal file
104
tests/v1/e2e/test_hybrid_chunked_prefill.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ...utils import large_gpu_mark, multi_gpu_marks
|
||||
|
||||
# A trivial request with a short prompt to ensure we run a mixed batch
|
||||
SMALL_MESSAGE = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "The secret beta value is 64. What is the secret beta?",
|
||||
}
|
||||
]
|
||||
|
||||
# Sample prompt with a bunch of filler in between the critical fact and the request.
|
||||
# Both parts need to be processed properly for the model to generate the correct answer
|
||||
MESSAGES = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Important: The secret number is 42. "
|
||||
"The sky is green in this hypothetical world. "
|
||||
"Apples grow on trees in the forest. "
|
||||
"Rivers flow through the valleys and mountains. "
|
||||
"Birds sing songs in the early morning light. "
|
||||
"The weather today is sunny with clear skies ahead. "
|
||||
"Flowers bloom in the garden during spring season. "
|
||||
"Now answer with ONLY the number and nothing else: "
|
||||
"What is the secret number plus one?"
|
||||
),
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[
|
||||
pytest.param("Qwen/Qwen3.5-4B", marks=[large_gpu_mark(min_gb=40)]),
|
||||
pytest.param(
|
||||
"nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-FP8",
|
||||
marks=[large_gpu_mark(min_gb=80)] + multi_gpu_marks(num_gpus=2),
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("enable_prefix_caching", [False, True])
|
||||
def test_mtp_speculative_mixed_batch_short_prefill(
|
||||
vllm_runner, model_name, enable_prefix_caching
|
||||
):
|
||||
"""Test to ensure MTP speculative decoding correctly handles
|
||||
short prefill chunks that fall below the reorder_batch_threshold."""
|
||||
|
||||
# Set so large that both prefills will be classified as decodes in a mixed batch
|
||||
# note, with prefix caching we require chunk_size >= mamba_block_size
|
||||
chunk_size = 256 if not enable_prefix_caching else 16384
|
||||
num_draft_tokens = 100
|
||||
|
||||
with vllm_runner(
|
||||
model_name,
|
||||
speculative_config={
|
||||
"method": "mtp",
|
||||
"num_speculative_tokens": num_draft_tokens,
|
||||
},
|
||||
max_num_batched_tokens=chunk_size,
|
||||
max_model_len=512,
|
||||
enforce_eager=True,
|
||||
tensor_parallel_size=2,
|
||||
trust_remote_code=True,
|
||||
enable_chunked_prefill=True,
|
||||
enable_prefix_caching=enable_prefix_caching,
|
||||
mamba_cache_mode="align" if enable_prefix_caching else "none",
|
||||
) as llm:
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.0,
|
||||
max_tokens=128,
|
||||
)
|
||||
|
||||
# First small message gets prefilled first, under normal conditions since the
|
||||
# batch is not yet mixed. Then the second prefill arrives as a mixed batch, but
|
||||
# is shorter than num_speculative_tokens, so it gets misclassified as a decode
|
||||
# and processed with the wrong state management logic, causing the critical
|
||||
# fact from the first chunk to be lost and the model to generate nonsense.
|
||||
outputs = llm.get_llm().chat(
|
||||
[SMALL_MESSAGE, MESSAGES],
|
||||
sampling_params,
|
||||
chat_template_kwargs={"enable_thinking": False},
|
||||
)
|
||||
|
||||
responses = []
|
||||
for output in outputs:
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Generated text: {generated_text!r}")
|
||||
responses.append(generated_text)
|
||||
|
||||
assert "64" in responses[0], (
|
||||
"The first response should contain the correct value of 64."
|
||||
)
|
||||
assert "43" in responses[1], (
|
||||
"The second response should contain the correct value of 42+1=43."
|
||||
)
|
||||
Reference in New Issue
Block a user