[Bugfix] Fix NemotronH MTP + Chunked Prefill (#35447)
This commit is contained in:
committed by
GitHub
parent
20b14095a4
commit
8a680463fa
104
tests/v1/e2e/test_hybrid_chunked_prefill.py
Normal file
104
tests/v1/e2e/test_hybrid_chunked_prefill.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ...utils import large_gpu_mark, multi_gpu_marks
|
||||
|
||||
# Single-turn chat with a very short prompt. It is submitted alongside the
# longer prompt below so that the engine ends up running a mixed batch.
SMALL_MESSAGE = [
    dict(
        role="user",
        content="The secret beta value is 64. What is the secret beta?",
    )
]
|
||||
|
||||
# Longer single-turn chat: the critical fact comes first, the actual question
# comes last, and unrelated filler sits in between. The model can only answer
# correctly if both ends of the prompt are processed properly.
MESSAGES = [
    {
        "role": "user",
        "content": " ".join(
            [
                "Important: The secret number is 42.",
                "The sky is green in this hypothetical world.",
                "Apples grow on trees in the forest.",
                "Rivers flow through the valleys and mountains.",
                "Birds sing songs in the early morning light.",
                "The weather today is sunny with clear skies ahead.",
                "Flowers bloom in the garden during spring season.",
                "Now answer with ONLY the number and nothing else:",
                "What is the secret number plus one?",
            ]
        ),
    }
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
@pytest.mark.parametrize(
    "model_name",
    [
        # The engine below is always created with tensor_parallel_size=2, so
        # every case needs the two-GPU mark in addition to the memory mark;
        # otherwise the case would run (and fail) on a single large GPU.
        pytest.param(
            "Qwen/Qwen3.5-4B",
            marks=[large_gpu_mark(min_gb=40)] + multi_gpu_marks(num_gpus=2),
        ),
        pytest.param(
            "nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-FP8",
            marks=[large_gpu_mark(min_gb=80)] + multi_gpu_marks(num_gpus=2),
        ),
    ],
)
@pytest.mark.parametrize("enable_prefix_caching", [False, True])
def test_mtp_speculative_mixed_batch_short_prefill(
    vllm_runner, model_name, enable_prefix_caching
):
    """Test to ensure MTP speculative decoding correctly handles
    short prefill chunks that fall below the reorder_batch_threshold.

    Two chat requests (SMALL_MESSAGE and MESSAGES) are submitted together with
    greedy sampling, and each response is checked for the number that its
    prompt implies ("64" and "43" respectively).
    """

    # note, with prefix caching we require chunk_size >= mamba_block_size
    chunk_size = 16384 if enable_prefix_caching else 256
    # Set so large that both prefills will be classified as decodes in a mixed batch
    num_draft_tokens = 100

    with vllm_runner(
        model_name,
        speculative_config={
            "method": "mtp",
            "num_speculative_tokens": num_draft_tokens,
        },
        max_num_batched_tokens=chunk_size,
        max_model_len=512,
        enforce_eager=True,
        tensor_parallel_size=2,
        trust_remote_code=True,
        enable_chunked_prefill=True,
        enable_prefix_caching=enable_prefix_caching,
        mamba_cache_mode="align" if enable_prefix_caching else "none",
    ) as llm:
        # Greedy sampling so the generated text is deterministic.
        sampling_params = SamplingParams(
            temperature=0.0,
            max_tokens=128,
        )

        # First small message gets prefilled first, under normal conditions since the
        # batch is not yet mixed. Then the second prefill arrives as a mixed batch, but
        # is shorter than num_speculative_tokens, so it gets misclassified as a decode
        # and processed with the wrong state management logic, causing the critical
        # fact from the first chunk to be lost and the model to generate nonsense.
        outputs = llm.get_llm().chat(
            [SMALL_MESSAGE, MESSAGES],
            sampling_params,
            chat_template_kwargs={"enable_thinking": False},
        )

    responses = []
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"Generated text: {generated_text!r}")
        responses.append(generated_text)

    # Fail with a clear message instead of an IndexError if a request is missing.
    assert len(responses) == 2, (
        f"Expected one response per request, got {len(responses)}."
    )
    assert "64" in responses[0], (
        "The first response should contain the correct value of 64."
    )
    assert "43" in responses[1], (
        "The second response should contain the correct value of 42+1=43."
    )
|
||||
@@ -334,13 +334,13 @@ def selective_state_update(
|
||||
dt_bias = dt_bias.unsqueeze(0)
|
||||
if out.dim() == 2:
|
||||
out = out.unsqueeze(1)
|
||||
if num_accepted_tokens is not None:
|
||||
assert state_batch_indices is not None and state_batch_indices.dim() == 2
|
||||
assert dst_state_batch_indices is None or dst_state_batch_indices.dim() == 2
|
||||
if state_batch_indices is not None and state_batch_indices.dim() == 1:
|
||||
state_batch_indices = state_batch_indices.unsqueeze(1)
|
||||
if dst_state_batch_indices is not None and dst_state_batch_indices.dim() == 1:
|
||||
dst_state_batch_indices = dst_state_batch_indices.unsqueeze(1)
|
||||
if num_accepted_tokens is not None:
|
||||
assert state_batch_indices is not None and state_batch_indices.dim() == 2
|
||||
assert dst_state_batch_indices is None or dst_state_batch_indices.dim() == 2
|
||||
|
||||
_, nheads, dim, dstate = state.shape
|
||||
batch = x.shape[0]
|
||||
|
||||
@@ -414,8 +414,11 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
||||
]
|
||||
state_indices_tensor_p = state_indices_tensor_p[:, 0]
|
||||
|
||||
if num_decodes > 0 and self.use_spec_decode:
|
||||
assert num_accepted_tokens is not None
|
||||
# Sometimes even with specdec enabled we get single-token prefill chunks that
|
||||
# should be treated as decodes but don't have num_accepted_tokens set.
|
||||
# These should be fine to process as non-spec decodes since there's only
|
||||
# one token, so no risk of placing accepted tokens in the wrong slot.
|
||||
if num_decodes > 0 and self.use_spec_decode and num_accepted_tokens is not None:
|
||||
query_start_loc_d = common_attn_metadata.query_start_loc[: num_decodes + 1]
|
||||
num_accepted_tokens = num_accepted_tokens[:num_decodes]
|
||||
|
||||
@@ -501,9 +504,8 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
||||
state_indices_tensor_d = self.state_indices_tensor_d[:padded_bs]
|
||||
state_indices_tensor_d[metadata.num_decodes :] = PAD_SLOT_ID
|
||||
|
||||
if self.use_spec_decode:
|
||||
if self.use_spec_decode and num_accepted_tokens is not None:
|
||||
assert query_start_loc_d is not None
|
||||
assert num_accepted_tokens is not None
|
||||
query_start_loc_d = query_start_loc_d[: padded_bs + 1]
|
||||
self.decode_num_accepted_tokens[: metadata.num_decodes].copy_(
|
||||
num_accepted_tokens, non_blocking=True
|
||||
|
||||
@@ -739,6 +739,19 @@ class GPUModelRunner(
|
||||
|
||||
self.uniform_decode_query_len = 1 + self.num_spec_tokens
|
||||
|
||||
# When spec decode is active, the mamba backend classifies requests
|
||||
# with query_len <= reorder_batch_threshold as "decodes". Prefill
|
||||
# chunks that fall under this threshold get processed via the decode
|
||||
# path, which stores intermediate states at sequential slots. We must
|
||||
# set num_accepted_tokens to the chunk's query_len for those requests
|
||||
# so the next iteration reads from the correct final-state slot.
|
||||
# Prefills that went through the actual prefill path should keep the
|
||||
# default value of 1 (the prefill path stores state at slot 0 only).
|
||||
self.needs_prefill_as_decode_slots: bool = False
|
||||
self.prefill_as_decode_num_tokens = self._make_buffer(
|
||||
self.max_num_reqs, dtype=torch.int32
|
||||
)
|
||||
|
||||
# Cudagraph dispatcher for runtime cudagraph dispatching.
|
||||
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
|
||||
|
||||
@@ -1355,12 +1368,22 @@ class GPUModelRunner(
|
||||
.int()
|
||||
.argmax(-1)
|
||||
)
|
||||
spec_decode_active = bool(scheduler_output.scheduled_spec_decode_tokens)
|
||||
if self.needs_prefill_as_decode_slots and spec_decode_active:
|
||||
mamba_utils.update_accepted_tokens_for_prefill_as_decode(
|
||||
self.input_batch,
|
||||
self.prefill_as_decode_num_tokens,
|
||||
self.num_accepted_tokens.gpu,
|
||||
scheduler_output,
|
||||
self.reorder_batch_threshold,
|
||||
num_reqs,
|
||||
)
|
||||
|
||||
if self.cache_config.mamba_cache_mode == "align":
|
||||
for i, num_tokens in enumerate(
|
||||
self.num_accepted_tokens.gpu[:num_reqs].cpu().numpy()
|
||||
):
|
||||
self.input_batch.num_accepted_tokens_cpu[i] = num_tokens
|
||||
|
||||
mamba_utils.postprocess_mamba(
|
||||
scheduler_output,
|
||||
self.kv_cache_config,
|
||||
@@ -2024,6 +2047,8 @@ class GPUModelRunner(
|
||||
else 0
|
||||
)
|
||||
|
||||
if isinstance(builder, Mamba2AttentionMetadataBuilder):
|
||||
self.needs_prefill_as_decode_slots = True
|
||||
extra_attn_metadata_args = {}
|
||||
if use_spec_decode and isinstance(
|
||||
builder, (Mamba2AttentionMetadataBuilder, GDNAttentionMetadataBuilder)
|
||||
|
||||
@@ -266,3 +266,45 @@ def postprocess_mamba(
|
||||
if src_block_idx == dest_block_idx:
|
||||
num_accepted_tokens_cpu[i] = 1
|
||||
do_mamba_copy_block(copy_bufs)
|
||||
|
||||
|
||||
def update_accepted_tokens_for_prefill_as_decode(
    input_batch: GPUInputBatch,
    prefill_as_decode_num_tokens: CpuGpuBuffer,
    num_accepted_tokens_gpu: torch.Tensor,
    scheduler_output: SchedulerOutput,
    decode_qlen_threshold: int | None,
    num_reqs: int,
):
    """
    Overwrite num_accepted_tokens for requests that are still prefilling.

    For each of the first ``num_reqs`` requests, a request counts as a prefill
    when its computed-token count is below its prompt length. A prefill whose
    scheduled token count is at or below ``decode_qlen_threshold`` went through
    the decode path, so its accepted-token count is set to that scheduled
    count; any other prefill gets 1. Requests that are not prefilling keep
    their existing value. This lets the next iteration read from the matching
    sequential state slot rather than the default prefill slot 0. Not used by
    GDN attention, which manually separates short prefills and short decodes
    when building the attention metadata.

    Both ``prefill_as_decode_num_tokens`` and the first ``num_reqs`` entries of
    ``num_accepted_tokens_gpu`` are modified in place; nothing is returned.
    """
    staged = prefill_as_decode_num_tokens.np
    has_threshold = decode_qlen_threshold is not None
    saw_prefill = False

    for row in range(num_reqs):
        still_prefilling = (
            input_batch.num_computed_tokens_cpu[row]
            < input_batch.num_prompt_tokens[row]
        )
        scheduled = scheduler_output.num_scheduled_tokens[input_batch.req_ids[row]]

        if not still_prefilling:
            # -1 is the "leave untouched" marker consumed by torch.where below.
            staged[row] = -1
            continue

        saw_prefill = True
        if has_threshold and scheduled <= decode_qlen_threshold:
            staged[row] = scheduled
        else:
            staged[row] = 1

    # We can skip the GPU transfer if there aren't any values to update
    if not saw_prefill:
        return

    prefill_as_decode_num_tokens.copy_to_gpu(num_reqs)
    overrides = prefill_as_decode_num_tokens.gpu[:num_reqs]
    current = num_accepted_tokens_gpu[:num_reqs]
    num_accepted_tokens_gpu[:num_reqs] = torch.where(
        overrides != -1, overrides, current
    )
|
||||
|
||||
Reference in New Issue
Block a user