[Model Runner V2] Support Streaming Inputs (#37028)

Signed-off-by: Santino Ramos <elsantinoramos@gmail.com>
This commit is contained in:
Santino Ramos
2026-03-20 13:42:25 -07:00
committed by GitHub
parent 8bc6b5cdb0
commit 85f671b8e1
7 changed files with 263 additions and 12 deletions

View File

@@ -0,0 +1,207 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for MRv2 GPUModelRunner.add_requests streaming input support."""
from unittest.mock import Mock
import pytest
import torch
from vllm.v1.core.sched.output import (
CachedRequestData,
NewRequestData,
SchedulerOutput,
)
from vllm.v1.worker.gpu.model_runner import GPUModelRunner
from vllm.v1.worker.gpu.states import RequestState
pytestmark = pytest.mark.cpu_test
@pytest.fixture
def mock_model_runner_with_req_states():
"""Create a mock MRv2 GPUModelRunner with a real RequestState."""
runner = Mock(spec=GPUModelRunner)
runner.req_states = RequestState(
max_num_reqs=10,
max_model_len=1024,
max_num_batched_tokens=1024,
num_speculative_steps=0,
vocab_size=32000,
device=torch.device("cpu"),
model_dtype=torch.float32,
cache_draft_logits=False,
)
runner.encoder_cache = None
runner.model_state = Mock()
runner.block_tables = Mock()
runner.lora_state = Mock()
runner.sampler = None
runner.prompt_logprobs_worker = None
runner.is_last_pp_rank = False
# Mock staged writes — they use Triton kernels that require GPU
runner.req_states.apply_staged_writes = Mock()
# Bind the real methods to our mock
runner._remove_request = GPUModelRunner._remove_request.__get__(runner)
runner.add_requests = GPUModelRunner.add_requests.__get__(runner)
return runner
def _make_scheduler_output(new_reqs):
return SchedulerOutput(
scheduled_new_reqs=new_reqs,
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
)
def test_e2e_streaming_request_update_basic_flow(
mock_model_runner_with_req_states,
):
"""Test that streaming sessions are updated correctly.
This test validates that when a streaming session is updated with new
prompt tokens:
1. The old request state is removed (no free_indices leak)
2. The new state is written with updated prefill_token_ids
3. model_state and block_tables are re-registered for the new state
"""
runner = mock_model_runner_with_req_states
req_states = runner.req_states
req_id = "streaming_req_0"
initial_free = len(req_states.free_indices)
# Step 1: Add initial request with 3 prompt tokens, all computed
initial_req_data = NewRequestData(
req_id=req_id,
prompt_token_ids=[1, 2, 3],
prefill_token_ids=[1, 2, 3],
mm_features=[],
sampling_params=None,
pooling_params=None,
block_ids=([0],),
num_computed_tokens=3,
lora_request=None,
)
runner.add_requests(_make_scheduler_output([initial_req_data]))
assert req_id in req_states.req_id_to_index
assert len(req_states.free_indices) == initial_free - 1
# Step 2: Create streaming update with extended prompt
# The scheduler has already set prefill_token_ids to the full sequence
# (original prompt + intermediate output + new prompt tokens)
updated_req_data = NewRequestData(
req_id=req_id,
prompt_token_ids=[1, 2, 3],
prefill_token_ids=[1, 2, 3, 10, 4, 5],
mm_features=[],
sampling_params=None,
pooling_params=None,
block_ids=([0, 1],),
num_computed_tokens=4, # 3 original prompt + 1 intermediate output
lora_request=None,
)
runner.add_requests(_make_scheduler_output([updated_req_data]))
# Step 3: Verify no free_indices leak (old slot recycled)
assert len(req_states.free_indices) == initial_free - 1
# Verify the request is still tracked with exactly one index
assert req_id in req_states.req_id_to_index
assert sum(1 for v in req_states.index_to_req_id.values() if v == req_id) == 1
# Verify state was updated with new values
new_idx = req_states.req_id_to_index[req_id]
assert req_states.prompt_len.np[new_idx] == 3
assert req_states.prefill_len.np[new_idx] == 6
assert req_states.num_computed_prefill_tokens[new_idx] == 4
# Verify model_state and block_tables were re-registered
runner.model_state.add_request.assert_called_with(new_idx, updated_req_data)
runner.block_tables.append_block_ids.assert_called_with(
new_idx, ([0, 1],), overwrite=True
)
def test_e2e_streaming_with_multimodal_features(
mock_model_runner_with_req_states,
):
"""Test that streaming sessions with multimodal features are updated.
This test validates that when a streaming session with mm features
is updated:
1. The old request state is removed (no free_indices leak)
2. encoder_cache is cleaned up and re-registered with new mm_features
3. model_state is re-registered (recomputes M-RoPE positions etc.)
"""
runner = mock_model_runner_with_req_states
req_states = runner.req_states
req_id = "streaming_mm_req_0"
initial_free = len(req_states.free_indices)
# Enable encoder_cache for multimodal
runner.encoder_cache = Mock()
# Step 1: Add initial request with one audio feature
mm_feature_1 = Mock()
initial_req_data = NewRequestData(
req_id=req_id,
prompt_token_ids=[1, 2] + [0] * 10 + [3, 4],
prefill_token_ids=[1, 2] + [0] * 10 + [3, 4],
mm_features=[mm_feature_1],
sampling_params=None,
pooling_params=None,
block_ids=([0],),
num_computed_tokens=14,
lora_request=None,
)
runner.add_requests(_make_scheduler_output([initial_req_data]))
assert req_id in req_states.req_id_to_index
# Reset mocks to track only the streaming update calls
runner.encoder_cache.reset_mock()
runner.model_state.reset_mock()
# Step 2: Create streaming update with additional multimodal feature
# The scheduler has folded the intermediate output (100) into
# prefill_token_ids and added a new audio chunk
mm_feature_2 = Mock()
updated_req_data = NewRequestData(
req_id=req_id,
prompt_token_ids=[1, 2] + [0] * 10 + [3, 4],
prefill_token_ids=[1, 2] + [0] * 10 + [3, 4, 100] + [0] * 5 + [5],
mm_features=[mm_feature_1, mm_feature_2],
sampling_params=None,
pooling_params=None,
block_ids=([0, 1],),
num_computed_tokens=14,
lora_request=None,
)
runner.add_requests(_make_scheduler_output([updated_req_data]))
# Step 3: Verify no free_indices leak
assert len(req_states.free_indices) == initial_free - 1
assert sum(1 for v in req_states.index_to_req_id.values() if v == req_id) == 1
# Verify encoder_cache was cleaned up and re-registered
runner.encoder_cache.remove_request.assert_called_once_with(req_id)
runner.encoder_cache.add_request.assert_called_once_with(
req_id, [mm_feature_1, mm_feature_2]
)
# Verify model_state was re-registered with new data
new_idx = req_states.req_id_to_index[req_id]
runner.model_state.add_request.assert_called_once_with(new_idx, updated_req_data)
# Verify updated prefill length
assert req_states.prefill_len.np[new_idx] == 21

View File

@@ -150,8 +150,10 @@ def create_whisper_attention_backend_with_block_pooling(
new_common_attn_metadata.query_start_loc *= block_pool_size
new_common_attn_metadata.query_start_loc_cpu *= block_pool_size
new_common_attn_metadata.seq_lens *= block_pool_size
new_common_attn_metadata._seq_lens_cpu *= block_pool_size
new_common_attn_metadata._num_computed_tokens_cpu *= block_pool_size
if new_common_attn_metadata._seq_lens_cpu is not None:
new_common_attn_metadata._seq_lens_cpu *= block_pool_size
if new_common_attn_metadata._num_computed_tokens_cpu is not None:
new_common_attn_metadata._num_computed_tokens_cpu *= block_pool_size
new_common_attn_metadata.num_actual_tokens *= block_pool_size
new_common_attn_metadata.max_query_len *= block_pool_size
new_common_attn_metadata.max_seq_len *= block_pool_size

View File

@@ -111,6 +111,7 @@ def _reshape_kv_cache(
kv_cache_config: KVCacheConfig,
kv_cache_raw_tensors: dict[str, torch.Tensor],
attn_backends: dict[str, AttentionBackend],
cache_dtype: str,
) -> dict[str, torch.Tensor]:
kv_caches: dict[str, torch.Tensor] = {}
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
@@ -127,6 +128,7 @@ def _reshape_kv_cache(
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
cache_dtype,
)
# FIXME(woosuk): Add kv_cache_stride_order to all attention backends.
@@ -155,9 +157,12 @@ def init_kv_cache(
kv_cache_config: KVCacheConfig,
attn_backends: dict[str, AttentionBackend],
device: torch.device,
cache_dtype: str,
) -> dict[str, torch.Tensor]:
kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors, attn_backends)
kv_caches = _reshape_kv_cache(
kv_cache_config, kv_cache_raw_tensors, attn_backends, cache_dtype
)
bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
return kv_caches

View File

@@ -359,6 +359,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.kv_cache_config,
self.attn_backends,
self.device,
self.cache_config.cache_dtype,
)
self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict)
@@ -555,18 +556,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
return cuda_graph_size
def _remove_request(self, req_id: str) -> bool:
if not self.req_states.remove_request(req_id):
return False
if self.encoder_cache is not None:
self.encoder_cache.remove_request(req_id)
if self.prompt_logprobs_worker is not None:
self.prompt_logprobs_worker.remove_request(req_id)
self.lora_state.remove_request(req_id)
return True
def finish_requests(self, scheduler_output: SchedulerOutput) -> None:
finished_req_ids = scheduler_output.finished_req_ids
preempted_req_ids = scheduler_output.preempted_req_ids
if preempted_req_ids:
finished_req_ids = finished_req_ids.union(preempted_req_ids)
for req_id in finished_req_ids:
self.req_states.remove_request(req_id)
if self.encoder_cache is not None:
self.encoder_cache.remove_request(req_id)
if self.prompt_logprobs_worker is not None:
self.prompt_logprobs_worker.remove_request(req_id)
self.lora_state.remove_request(req_id)
self._remove_request(req_id)
def free_states(self, scheduler_output: SchedulerOutput) -> None:
if self.encoder_cache is not None:
@@ -578,6 +584,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert new_req_data.prompt_token_ids is not None
assert new_req_data.prefill_token_ids is not None
req_id = new_req_data.req_id
# Streaming input update: request already exists from a prior
# chunk. Remove old state so it can be cleanly re-added below
# with the updated prompt_token_ids and mm_features.
self._remove_request(req_id)
prompt_len = len(new_req_data.prompt_token_ids)
self.req_states.add_request(
req_id=req_id,

View File

@@ -7,6 +7,7 @@ import torch.nn as nn
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.tasks import GenerationTask
from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
@@ -61,6 +62,28 @@ class DefaultModelState(ModelState):
device=self.device,
)
def get_supported_generation_tasks(self) -> tuple[GenerationTask, ...]:
from vllm.model_executor.models.interfaces import (
supports_realtime,
supports_transcription,
)
from vllm.model_executor.models.interfaces_base import is_text_generation_model
supported_tasks = list[GenerationTask]()
if is_text_generation_model(self.model):
supported_tasks.append("generate")
if supports_transcription(self.model):
if self.model.supports_transcription_only:
return ("transcription",)
supported_tasks.append("transcription")
if supports_realtime(self.model):
supported_tasks.append("realtime")
return tuple(supported_tasks)
def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
if self.rope_state is not None:
assert new_req_data.prefill_token_ids is not None

View File

@@ -28,8 +28,9 @@ class ModelState(ABC):
) -> None:
raise NotImplementedError
@abstractmethod
def get_supported_generation_tasks(self) -> tuple[GenerationTask, ...]:
return ("generate",)
raise NotImplementedError
def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
return None

View File

@@ -109,13 +109,14 @@ class RequestState:
self.all_token_ids.apply_write()
self.num_computed_tokens.apply_write()
def remove_request(self, req_id: str) -> None:
def remove_request(self, req_id: str) -> bool:
req_idx = self.req_id_to_index.pop(req_id, None)
if req_idx is None:
# Request not found.
return
return False
self.index_to_req_id.pop(req_idx, None)
self.free_indices.append(req_idx)
return True
def any_prefills(self, idx_mapping_np: np.ndarray) -> bool:
return np.any(