diff --git a/tests/v1/e2e/test_hybrid_chunked_prefill.py b/tests/v1/e2e/test_hybrid_chunked_prefill.py
new file mode 100644
index 000000000..030081a38
--- /dev/null
+++ b/tests/v1/e2e/test_hybrid_chunked_prefill.py
@@ -0,0 +1,104 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+
+from vllm import SamplingParams
+from vllm.platforms import current_platform
+
+from ...utils import large_gpu_mark, multi_gpu_marks
+
+# A trivial request with a short prompt to ensure we run a mixed batch
+SMALL_MESSAGE = [
+    {
+        "role": "user",
+        "content": "The secret beta value is 64. What is the secret beta?",
+    }
+]
+
+# Sample prompt with a bunch of filler between the critical fact and the request.
+# Both parts must be processed correctly for the model to generate the right answer.
+MESSAGES = [
+    {
+        "role": "user",
+        "content": (
+            "Important: The secret number is 42. "
+            "The sky is green in this hypothetical world. "
+            "Apples grow on trees in the forest. "
+            "Rivers flow through the valleys and mountains. "
+            "Birds sing songs in the early morning light. "
+            "The weather today is sunny with clear skies ahead. "
+            "Flowers bloom in the garden during spring season. "
+            "Now answer with ONLY the number and nothing else: "
+            "What is the secret number plus one?"
+        ),
+    }
+]
+
+
+@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        pytest.param("Qwen/Qwen3.5-4B", marks=[large_gpu_mark(min_gb=40)]),
+        pytest.param(
+            "nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-FP8",
+            marks=[large_gpu_mark(min_gb=80)] + multi_gpu_marks(num_gpus=2),
+        ),
+    ],
+)
+@pytest.mark.parametrize("enable_prefix_caching", [False, True])
+def test_mtp_speculative_mixed_batch_short_prefill(
+    vllm_runner, model_name, enable_prefix_caching
+):
+    """Test that MTP speculative decoding correctly handles short prefill
+    chunks that fall below the reorder_batch_threshold."""
+
+    # num_draft_tokens is set so large that both prefills are classified as decodes
+    # in a mixed batch. Note: prefix caching requires chunk_size >= mamba_block_size.
+    chunk_size = 256 if not enable_prefix_caching else 16384
+    num_draft_tokens = 100
+
+    with vllm_runner(
+        model_name,
+        speculative_config={
+            "method": "mtp",
+            "num_speculative_tokens": num_draft_tokens,
+        },
+        max_num_batched_tokens=chunk_size,
+        max_model_len=512,
+        enforce_eager=True,
+        tensor_parallel_size=2,
+        trust_remote_code=True,
+        enable_chunked_prefill=True,
+        enable_prefix_caching=enable_prefix_caching,
+        mamba_cache_mode="align" if enable_prefix_caching else "none",
+    ) as llm:
+        sampling_params = SamplingParams(
+            temperature=0.0,
+            max_tokens=128,
+        )
+
+        # The small message is prefilled first under normal conditions, since the
+        # batch is not yet mixed. The second prefill then arrives in a mixed batch
+        # but is shorter than num_speculative_tokens, so it was misclassified as a
+        # decode and processed with the wrong state management logic, losing the
+        # critical fact from the first chunk and producing nonsense output.
+        outputs = llm.get_llm().chat(
+            [SMALL_MESSAGE, MESSAGES],
+            sampling_params,
+            chat_template_kwargs={"enable_thinking": False},
+        )
+
+        responses = []
+        for output in outputs:
+            generated_text = output.outputs[0].text
+            print(f"Generated text: {generated_text!r}")
+            responses.append(generated_text)
+
+        assert "64" in responses[0], (
+            "The first response should contain the correct value of 64."
+        )
+        assert "43" in responses[1], (
+            "The second response should contain the correct value of 42 + 1 = 43."
+        )
diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
index 22a99596a..1cd077758 100644
--- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
+++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
@@ -334,13 +334,13 @@ def selective_state_update(
         dt_bias = dt_bias.unsqueeze(0)
     if out.dim() == 2:
         out = out.unsqueeze(1)
-    if num_accepted_tokens is not None:
-        assert state_batch_indices is not None and state_batch_indices.dim() == 2
-        assert dst_state_batch_indices is None or dst_state_batch_indices.dim() == 2
     if state_batch_indices is not None and state_batch_indices.dim() == 1:
         state_batch_indices = state_batch_indices.unsqueeze(1)
     if dst_state_batch_indices is not None and dst_state_batch_indices.dim() == 1:
         dst_state_batch_indices = dst_state_batch_indices.unsqueeze(1)
+    if num_accepted_tokens is not None:
+        assert state_batch_indices is not None and state_batch_indices.dim() == 2
+        assert dst_state_batch_indices is None or dst_state_batch_indices.dim() == 2
     _, nheads, dim, dstate = state.shape
     batch = x.shape[0]
 
diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py
index 0364d6aee..bdb820eac 100644
--- a/vllm/v1/attention/backends/mamba_attn.py
+++ b/vllm/v1/attention/backends/mamba_attn.py
@@ -414,8 +414,11 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
                 ]
             state_indices_tensor_p = state_indices_tensor_p[:, 0]
 
-        if num_decodes > 0 and self.use_spec_decode:
-            assert num_accepted_tokens is not None
+        # Even with spec decode enabled, we sometimes get single-token prefill chunks
+        # that should be treated as decodes but don't have num_accepted_tokens set.
+        # These are safe to process as non-spec decodes: with only one token there is
+        # no risk of placing accepted tokens in the wrong slot.
+        if num_decodes > 0 and self.use_spec_decode and num_accepted_tokens is not None:
             query_start_loc_d = common_attn_metadata.query_start_loc[: num_decodes + 1]
             num_accepted_tokens = num_accepted_tokens[:num_decodes]
 
@@ -501,9 +504,8 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
             state_indices_tensor_d = self.state_indices_tensor_d[:padded_bs]
             state_indices_tensor_d[metadata.num_decodes :] = PAD_SLOT_ID
 
-            if self.use_spec_decode:
+            if self.use_spec_decode and num_accepted_tokens is not None:
                 assert query_start_loc_d is not None
-                assert num_accepted_tokens is not None
                 query_start_loc_d = query_start_loc_d[: padded_bs + 1]
                 self.decode_num_accepted_tokens[: metadata.num_decodes].copy_(
                     num_accepted_tokens, non_blocking=True
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 98e1dab36..22459bc49 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -739,6 +739,19 @@ class GPUModelRunner(
 
         self.uniform_decode_query_len = 1 + self.num_spec_tokens
 
+        # When spec decode is active, the mamba backend classifies requests
+        # with query_len <= reorder_batch_threshold as "decodes". Prefill
+        # chunks that fall under this threshold get processed via the decode
+        # path, which stores intermediate states at sequential slots. We must
+        # set num_accepted_tokens to the chunk's query_len for those requests
+        # so the next iteration reads from the correct final-state slot.
+        # Prefills that went through the actual prefill path should keep the
+        # default value of 1 (the prefill path stores state at slot 0 only).
+        self.needs_prefill_as_decode_slots: bool = False
+        self.prefill_as_decode_num_tokens = self._make_buffer(
+            self.max_num_reqs, dtype=torch.int32
+        )
+
         # Cudagraph dispatcher for runtime cudagraph dispatching.
         self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
 
@@ -1355,12 +1368,22 @@ class GPUModelRunner(
                 .int()
                 .argmax(-1)
             )
+            spec_decode_active = bool(scheduler_output.scheduled_spec_decode_tokens)
+            if self.needs_prefill_as_decode_slots and spec_decode_active:
+                mamba_utils.update_accepted_tokens_for_prefill_as_decode(
+                    self.input_batch,
+                    self.prefill_as_decode_num_tokens,
+                    self.num_accepted_tokens.gpu,
+                    scheduler_output,
+                    self.reorder_batch_threshold,
+                    num_reqs,
+                )
+
             if self.cache_config.mamba_cache_mode == "align":
                 for i, num_tokens in enumerate(
                     self.num_accepted_tokens.gpu[:num_reqs].cpu().numpy()
                 ):
                     self.input_batch.num_accepted_tokens_cpu[i] = num_tokens
-
             mamba_utils.postprocess_mamba(
                 scheduler_output,
                 self.kv_cache_config,
@@ -2024,6 +2047,8 @@ class GPUModelRunner(
                 else 0
             )
 
+            if isinstance(builder, Mamba2AttentionMetadataBuilder):
+                self.needs_prefill_as_decode_slots = True
             extra_attn_metadata_args = {}
             if use_spec_decode and isinstance(
                 builder, (Mamba2AttentionMetadataBuilder, GDNAttentionMetadataBuilder)
diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py
index 2bd5d2b3f..68172133e 100644
--- a/vllm/v1/worker/mamba_utils.py
+++ b/vllm/v1/worker/mamba_utils.py
@@ -266,3 +266,45 @@ def postprocess_mamba(
             if src_block_idx == dest_block_idx:
                 num_accepted_tokens_cpu[i] = 1
     do_mamba_copy_block(copy_bufs)
+
+
+def update_accepted_tokens_for_prefill_as_decode(
+    input_batch: GPUInputBatch,
+    prefill_as_decode_num_tokens: CpuGpuBuffer,
+    num_accepted_tokens_gpu: torch.Tensor,
+    scheduler_output: SchedulerOutput,
+    decode_qlen_threshold: int | None,
+    num_reqs: int,
+):
+    """
+    Adjusts num_accepted_tokens for prefill chunks processed via the decode path.
+    This ensures subsequent iterations read from the correct sequential state slot
+    instead of the default prefill slot 0. Not used by GDN attention, which manually
+    separates short prefills and short decodes when building the attention metadata.
+    """
+    any_is_prefill = False
+    for i in range(num_reqs):
+        num_computed = input_batch.num_computed_tokens_cpu[i]
+        num_prompt = input_batch.num_prompt_tokens[i]
+        is_prefill = num_computed < num_prompt
+        req_id = input_batch.req_ids[i]
+        query_len = scheduler_output.num_scheduled_tokens[req_id]
+
+        if is_prefill:
+            classified_as_decode = (
+                decode_qlen_threshold is not None and query_len <= decode_qlen_threshold
+            )
+            num_tokens = query_len if classified_as_decode else 1
+            any_is_prefill = True
+        else:
+            num_tokens = -1
+        prefill_as_decode_num_tokens.np[i] = num_tokens
+
+    # We can skip the GPU transfer if there aren't any values to update
+    if any_is_prefill:
+        prefill_as_decode_num_tokens.copy_to_gpu(num_reqs)
+        num_accepted_tokens_gpu[:num_reqs] = torch.where(
+            prefill_as_decode_num_tokens.gpu[:num_reqs] != -1,
+            prefill_as_decode_num_tokens.gpu[:num_reqs],
+            num_accepted_tokens_gpu[:num_reqs],
+        )
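
A minimal, self-contained sketch (separate from the patch) of the state-slot bookkeeping that update_accepted_tokens_for_prefill_as_decode relies on. The read_slot helper, tensor shapes, and 0-indexed slot convention below are illustrative assumptions rather than vLLM APIs; the point is only that a prefill chunk routed down the decode path must report its query_len as num_accepted_tokens so the next step resumes from the chunk's final state instead of slot 0.

# Illustrative sketch only -- names, shapes, and the 0-indexed slot convention are
# assumptions, not the real mamba kernel interface.
import torch


def read_slot(
    intermediate_states: torch.Tensor, num_accepted_tokens: torch.Tensor
) -> torch.Tensor:
    """Pick the state each request resumes from on the next step.

    The decode path writes one intermediate state per scheduled token into
    sequential slots [0, query_len), so the state to carry forward sits at slot
    num_accepted_tokens - 1. The true prefill path writes only slot 0, which is
    why real prefills keep the default num_accepted_tokens == 1.
    """
    batch = torch.arange(intermediate_states.shape[0])
    return intermediate_states[batch, num_accepted_tokens - 1]


# Request 0: a 3-token prefill chunk that was classified as a decode.
# Request 1: an ordinary request with one accepted token.
states = torch.arange(2 * 4 * 2, dtype=torch.float32).reshape(2, 4, 2)
num_accepted = torch.tensor([3, 1])
print(read_slot(states, num_accepted))  # request 0 resumes from slot 2, request 1 from slot 0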