diff --git a/tests/v1/streaming_input/test_gpu_model_runner_v2_streaming.py b/tests/v1/streaming_input/test_gpu_model_runner_v2_streaming.py
new file mode 100644
index 000000000..8fde0f117
--- /dev/null
+++ b/tests/v1/streaming_input/test_gpu_model_runner_v2_streaming.py
@@ -0,0 +1,207 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""Unit tests for MRv2 GPUModelRunner.add_requests streaming input support."""
+
+from unittest.mock import Mock
+
+import pytest
+import torch
+
+from vllm.v1.core.sched.output import (
+    CachedRequestData,
+    NewRequestData,
+    SchedulerOutput,
+)
+from vllm.v1.worker.gpu.model_runner import GPUModelRunner
+from vllm.v1.worker.gpu.states import RequestState
+
+pytestmark = pytest.mark.cpu_test
+
+
+@pytest.fixture
+def mock_model_runner_with_req_states():
+    """Create a mock MRv2 GPUModelRunner with a real RequestState."""
+
+    runner = Mock(spec=GPUModelRunner)
+    runner.req_states = RequestState(
+        max_num_reqs=10,
+        max_model_len=1024,
+        max_num_batched_tokens=1024,
+        num_speculative_steps=0,
+        vocab_size=32000,
+        device=torch.device("cpu"),
+        model_dtype=torch.float32,
+        cache_draft_logits=False,
+    )
+    runner.encoder_cache = None
+    runner.model_state = Mock()
+    runner.block_tables = Mock()
+    runner.lora_state = Mock()
+    runner.sampler = None
+    runner.prompt_logprobs_worker = None
+    runner.is_last_pp_rank = False
+
+    # Mock staged writes — they use Triton kernels that require GPU
+    runner.req_states.apply_staged_writes = Mock()
+
+    # Bind the real methods to our mock
+    runner._remove_request = GPUModelRunner._remove_request.__get__(runner)
+    runner.add_requests = GPUModelRunner.add_requests.__get__(runner)
+    return runner
+
+
+def _make_scheduler_output(new_reqs):
+    return SchedulerOutput(
+        scheduled_new_reqs=new_reqs,
+        scheduled_cached_reqs=CachedRequestData.make_empty(),
+        num_scheduled_tokens={},
+        total_num_scheduled_tokens=0,
+        scheduled_spec_decode_tokens={},
+        scheduled_encoder_inputs={},
+        num_common_prefix_blocks=[],
+        finished_req_ids=set(),
+        free_encoder_mm_hashes=[],
+    )
+
+
+def test_e2e_streaming_request_update_basic_flow(
+    mock_model_runner_with_req_states,
+):
+    """Test that streaming sessions are updated correctly.
+
+    This test validates that when a streaming session is updated with new
+    prompt tokens:
+    1. The old request state is removed (no free_indices leak)
+    2. The new state is written with updated prefill_token_ids
+    3. model_state and block_tables are re-registered for the new state
+    """
+    runner = mock_model_runner_with_req_states
+    req_states = runner.req_states
+    req_id = "streaming_req_0"
+    initial_free = len(req_states.free_indices)
+
+    # Step 1: Add initial request with 3 prompt tokens, all computed
+    initial_req_data = NewRequestData(
+        req_id=req_id,
+        prompt_token_ids=[1, 2, 3],
+        prefill_token_ids=[1, 2, 3],
+        mm_features=[],
+        sampling_params=None,
+        pooling_params=None,
+        block_ids=([0],),
+        num_computed_tokens=3,
+        lora_request=None,
+    )
+    runner.add_requests(_make_scheduler_output([initial_req_data]))
+    assert req_id in req_states.req_id_to_index
+    assert len(req_states.free_indices) == initial_free - 1
+
+    # Step 2: Create streaming update with extended prompt
+    # The scheduler has already set prefill_token_ids to the full sequence
+    # (original prompt + intermediate output + new prompt tokens)
+    updated_req_data = NewRequestData(
+        req_id=req_id,
+        prompt_token_ids=[1, 2, 3],
+        prefill_token_ids=[1, 2, 3, 10, 4, 5],
+        mm_features=[],
+        sampling_params=None,
+        pooling_params=None,
+        block_ids=([0, 1],),
+        num_computed_tokens=4,  # 3 original prompt + 1 intermediate output
+        lora_request=None,
+    )
+    runner.add_requests(_make_scheduler_output([updated_req_data]))
+
+    # Step 3: Verify no free_indices leak (old slot recycled)
+    assert len(req_states.free_indices) == initial_free - 1
+
+    # Verify the request is still tracked with exactly one index
+    assert req_id in req_states.req_id_to_index
+    assert sum(1 for v in req_states.index_to_req_id.values() if v == req_id) == 1
+
+    # Verify state was updated with new values
+    new_idx = req_states.req_id_to_index[req_id]
+    assert req_states.prompt_len.np[new_idx] == 3
+    assert req_states.prefill_len.np[new_idx] == 6
+    assert req_states.num_computed_prefill_tokens[new_idx] == 4
+
+    # Verify model_state and block_tables were re-registered
+    runner.model_state.add_request.assert_called_with(new_idx, updated_req_data)
+    runner.block_tables.append_block_ids.assert_called_with(
+        new_idx, ([0, 1],), overwrite=True
+    )
+
+
+def test_e2e_streaming_with_multimodal_features(
+    mock_model_runner_with_req_states,
+):
+    """Test that streaming sessions with multimodal features are updated.
+
+    This test validates that when a streaming session with mm features
+    is updated:
+    1. The old request state is removed (no free_indices leak)
+    2. encoder_cache is cleaned up and re-registered with new mm_features
+    3. model_state is re-registered (recomputes M-RoPE positions etc.)
+ """ + runner = mock_model_runner_with_req_states + req_states = runner.req_states + req_id = "streaming_mm_req_0" + initial_free = len(req_states.free_indices) + + # Enable encoder_cache for multimodal + runner.encoder_cache = Mock() + + # Step 1: Add initial request with one audio feature + mm_feature_1 = Mock() + initial_req_data = NewRequestData( + req_id=req_id, + prompt_token_ids=[1, 2] + [0] * 10 + [3, 4], + prefill_token_ids=[1, 2] + [0] * 10 + [3, 4], + mm_features=[mm_feature_1], + sampling_params=None, + pooling_params=None, + block_ids=([0],), + num_computed_tokens=14, + lora_request=None, + ) + runner.add_requests(_make_scheduler_output([initial_req_data])) + assert req_id in req_states.req_id_to_index + + # Reset mocks to track only the streaming update calls + runner.encoder_cache.reset_mock() + runner.model_state.reset_mock() + + # Step 2: Create streaming update with additional multimodal feature + # The scheduler has folded the intermediate output (100) into + # prefill_token_ids and added a new audio chunk + mm_feature_2 = Mock() + updated_req_data = NewRequestData( + req_id=req_id, + prompt_token_ids=[1, 2] + [0] * 10 + [3, 4], + prefill_token_ids=[1, 2] + [0] * 10 + [3, 4, 100] + [0] * 5 + [5], + mm_features=[mm_feature_1, mm_feature_2], + sampling_params=None, + pooling_params=None, + block_ids=([0, 1],), + num_computed_tokens=14, + lora_request=None, + ) + runner.add_requests(_make_scheduler_output([updated_req_data])) + + # Step 3: Verify no free_indices leak + assert len(req_states.free_indices) == initial_free - 1 + assert sum(1 for v in req_states.index_to_req_id.values() if v == req_id) == 1 + + # Verify encoder_cache was cleaned up and re-registered + runner.encoder_cache.remove_request.assert_called_once_with(req_id) + runner.encoder_cache.add_request.assert_called_once_with( + req_id, [mm_feature_1, mm_feature_2] + ) + + # Verify model_state was re-registered with new data + new_idx = req_states.req_id_to_index[req_id] + runner.model_state.add_request.assert_called_once_with(new_idx, updated_req_data) + + # Verify updated prefill length + assert req_states.prefill_len.np[new_idx] == 21 diff --git a/vllm/model_executor/models/whisper_causal.py b/vllm/model_executor/models/whisper_causal.py index 6774ea11d..8e4322ea3 100644 --- a/vllm/model_executor/models/whisper_causal.py +++ b/vllm/model_executor/models/whisper_causal.py @@ -150,8 +150,10 @@ def create_whisper_attention_backend_with_block_pooling( new_common_attn_metadata.query_start_loc *= block_pool_size new_common_attn_metadata.query_start_loc_cpu *= block_pool_size new_common_attn_metadata.seq_lens *= block_pool_size - new_common_attn_metadata._seq_lens_cpu *= block_pool_size - new_common_attn_metadata._num_computed_tokens_cpu *= block_pool_size + if new_common_attn_metadata._seq_lens_cpu is not None: + new_common_attn_metadata._seq_lens_cpu *= block_pool_size + if new_common_attn_metadata._num_computed_tokens_cpu is not None: + new_common_attn_metadata._num_computed_tokens_cpu *= block_pool_size new_common_attn_metadata.num_actual_tokens *= block_pool_size new_common_attn_metadata.max_query_len *= block_pool_size new_common_attn_metadata.max_seq_len *= block_pool_size diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py index 59786ed7a..8e5bb11e4 100644 --- a/vllm/v1/worker/gpu/attn_utils.py +++ b/vllm/v1/worker/gpu/attn_utils.py @@ -111,6 +111,7 @@ def _reshape_kv_cache( kv_cache_config: KVCacheConfig, kv_cache_raw_tensors: dict[str, torch.Tensor], attn_backends: dict[str, 
AttentionBackend], + cache_dtype: str, ) -> dict[str, torch.Tensor]: kv_caches: dict[str, torch.Tensor] = {} for kv_cache_group_spec in kv_cache_config.kv_cache_groups: @@ -127,6 +128,7 @@ def _reshape_kv_cache( kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size, + cache_dtype, ) # FIXME(woosuk): Add kv_cache_stride_order to all attention backends. @@ -155,9 +157,12 @@ def init_kv_cache( kv_cache_config: KVCacheConfig, attn_backends: dict[str, AttentionBackend], device: torch.device, + cache_dtype: str, ) -> dict[str, torch.Tensor]: kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device) - kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors, attn_backends) + kv_caches = _reshape_kv_cache( + kv_cache_config, kv_cache_raw_tensors, attn_backends, cache_dtype + ) bind_kv_cache(kv_caches, forward_context, runner_kv_caches) return kv_caches diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 8051442d2..5788b31d2 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -359,6 +359,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.kv_cache_config, self.attn_backends, self.device, + self.cache_config.cache_dtype, ) self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict) @@ -555,18 +556,23 @@ class GPUModelRunner(LoRAModelRunnerMixin): ) return cuda_graph_size + def _remove_request(self, req_id: str) -> bool: + if not self.req_states.remove_request(req_id): + return False + if self.encoder_cache is not None: + self.encoder_cache.remove_request(req_id) + if self.prompt_logprobs_worker is not None: + self.prompt_logprobs_worker.remove_request(req_id) + self.lora_state.remove_request(req_id) + return True + def finish_requests(self, scheduler_output: SchedulerOutput) -> None: finished_req_ids = scheduler_output.finished_req_ids preempted_req_ids = scheduler_output.preempted_req_ids if preempted_req_ids: finished_req_ids = finished_req_ids.union(preempted_req_ids) for req_id in finished_req_ids: - self.req_states.remove_request(req_id) - if self.encoder_cache is not None: - self.encoder_cache.remove_request(req_id) - if self.prompt_logprobs_worker is not None: - self.prompt_logprobs_worker.remove_request(req_id) - self.lora_state.remove_request(req_id) + self._remove_request(req_id) def free_states(self, scheduler_output: SchedulerOutput) -> None: if self.encoder_cache is not None: @@ -578,6 +584,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): assert new_req_data.prompt_token_ids is not None assert new_req_data.prefill_token_ids is not None req_id = new_req_data.req_id + + # Streaming input update: request already exists from a prior + # chunk. Remove old state so it can be cleanly re-added below + # with the updated prompt_token_ids and mm_features. 
+            self._remove_request(req_id)
+
             prompt_len = len(new_req_data.prompt_token_ids)
             self.req_states.add_request(
                 req_id=req_id,
diff --git a/vllm/v1/worker/gpu/model_states/default.py b/vllm/v1/worker/gpu/model_states/default.py
index 104e4c194..8e73867de 100644
--- a/vllm/v1/worker/gpu/model_states/default.py
+++ b/vllm/v1/worker/gpu/model_states/default.py
@@ -7,6 +7,7 @@ import torch.nn as nn
 
 from vllm.config import VllmConfig
 from vllm.config.compilation import CUDAGraphMode
+from vllm.tasks import GenerationTask
 from vllm.v1.core.sched.output import NewRequestData
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
@@ -61,6 +62,28 @@ class DefaultModelState(ModelState):
             device=self.device,
         )
 
+    def get_supported_generation_tasks(self) -> tuple[GenerationTask, ...]:
+        from vllm.model_executor.models.interfaces import (
+            supports_realtime,
+            supports_transcription,
+        )
+        from vllm.model_executor.models.interfaces_base import is_text_generation_model
+
+        supported_tasks = list[GenerationTask]()
+
+        if is_text_generation_model(self.model):
+            supported_tasks.append("generate")
+
+        if supports_transcription(self.model):
+            if self.model.supports_transcription_only:
+                return ("transcription",)
+            supported_tasks.append("transcription")
+
+        if supports_realtime(self.model):
+            supported_tasks.append("realtime")
+
+        return tuple(supported_tasks)
+
     def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
         if self.rope_state is not None:
             assert new_req_data.prefill_token_ids is not None
diff --git a/vllm/v1/worker/gpu/model_states/interface.py b/vllm/v1/worker/gpu/model_states/interface.py
index 1c114496d..d83ab2fc5 100644
--- a/vllm/v1/worker/gpu/model_states/interface.py
+++ b/vllm/v1/worker/gpu/model_states/interface.py
@@ -28,8 +28,9 @@ class ModelState(ABC):
     ) -> None:
         raise NotImplementedError
 
+    @abstractmethod
     def get_supported_generation_tasks(self) -> tuple[GenerationTask, ...]:
-        return ("generate",)
+        raise NotImplementedError
 
     def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
         return None
diff --git a/vllm/v1/worker/gpu/states.py b/vllm/v1/worker/gpu/states.py
index f929b5edd..24d225886 100644
--- a/vllm/v1/worker/gpu/states.py
+++ b/vllm/v1/worker/gpu/states.py
@@ -109,13 +109,14 @@ class RequestState:
         self.all_token_ids.apply_write()
         self.num_computed_tokens.apply_write()
 
-    def remove_request(self, req_id: str) -> None:
+    def remove_request(self, req_id: str) -> bool:
         req_idx = self.req_id_to_index.pop(req_id, None)
         if req_idx is None:
             # Request not found.
-            return
+            return False
         self.index_to_req_id.pop(req_idx, None)
         self.free_indices.append(req_idx)
+        return True
 
     def any_prefills(self, idx_mapping_np: np.ndarray) -> bool:
         return np.any(