[Model Runner V2] Support Streaming Inputs (#37028)
Signed-off-by: Santino Ramos <elsantinoramos@gmail.com>
This commit is contained in:
207
tests/v1/streaming_input/test_gpu_model_runner_v2_streaming.py
Normal file
207
tests/v1/streaming_input/test_gpu_model_runner_v2_streaming.py
Normal file
@@ -0,0 +1,207 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""Unit tests for MRv2 GPUModelRunner.add_requests streaming input support."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.v1.core.sched.output import (
|
||||
CachedRequestData,
|
||||
NewRequestData,
|
||||
SchedulerOutput,
|
||||
)
|
||||
from vllm.v1.worker.gpu.model_runner import GPUModelRunner
|
||||
from vllm.v1.worker.gpu.states import RequestState
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model_runner_with_req_states():
|
||||
"""Create a mock MRv2 GPUModelRunner with a real RequestState."""
|
||||
|
||||
runner = Mock(spec=GPUModelRunner)
|
||||
runner.req_states = RequestState(
|
||||
max_num_reqs=10,
|
||||
max_model_len=1024,
|
||||
max_num_batched_tokens=1024,
|
||||
num_speculative_steps=0,
|
||||
vocab_size=32000,
|
||||
device=torch.device("cpu"),
|
||||
model_dtype=torch.float32,
|
||||
cache_draft_logits=False,
|
||||
)
|
||||
runner.encoder_cache = None
|
||||
runner.model_state = Mock()
|
||||
runner.block_tables = Mock()
|
||||
runner.lora_state = Mock()
|
||||
runner.sampler = None
|
||||
runner.prompt_logprobs_worker = None
|
||||
runner.is_last_pp_rank = False
|
||||
|
||||
# Mock staged writes — they use Triton kernels that require GPU
|
||||
runner.req_states.apply_staged_writes = Mock()
|
||||
|
||||
# Bind the real methods to our mock
|
||||
runner._remove_request = GPUModelRunner._remove_request.__get__(runner)
|
||||
runner.add_requests = GPUModelRunner.add_requests.__get__(runner)
|
||||
return runner
|
||||
|
||||
|
||||
def _make_scheduler_output(new_reqs):
|
||||
return SchedulerOutput(
|
||||
scheduled_new_reqs=new_reqs,
|
||||
scheduled_cached_reqs=CachedRequestData.make_empty(),
|
||||
num_scheduled_tokens={},
|
||||
total_num_scheduled_tokens=0,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
)
|
||||
|
||||
|
||||
def test_e2e_streaming_request_update_basic_flow(
|
||||
mock_model_runner_with_req_states,
|
||||
):
|
||||
"""Test that streaming sessions are updated correctly.
|
||||
|
||||
This test validates that when a streaming session is updated with new
|
||||
prompt tokens:
|
||||
1. The old request state is removed (no free_indices leak)
|
||||
2. The new state is written with updated prefill_token_ids
|
||||
3. model_state and block_tables are re-registered for the new state
|
||||
"""
|
||||
runner = mock_model_runner_with_req_states
|
||||
req_states = runner.req_states
|
||||
req_id = "streaming_req_0"
|
||||
initial_free = len(req_states.free_indices)
|
||||
|
||||
# Step 1: Add initial request with 3 prompt tokens, all computed
|
||||
initial_req_data = NewRequestData(
|
||||
req_id=req_id,
|
||||
prompt_token_ids=[1, 2, 3],
|
||||
prefill_token_ids=[1, 2, 3],
|
||||
mm_features=[],
|
||||
sampling_params=None,
|
||||
pooling_params=None,
|
||||
block_ids=([0],),
|
||||
num_computed_tokens=3,
|
||||
lora_request=None,
|
||||
)
|
||||
runner.add_requests(_make_scheduler_output([initial_req_data]))
|
||||
assert req_id in req_states.req_id_to_index
|
||||
assert len(req_states.free_indices) == initial_free - 1
|
||||
|
||||
# Step 2: Create streaming update with extended prompt
|
||||
# The scheduler has already set prefill_token_ids to the full sequence
|
||||
# (original prompt + intermediate output + new prompt tokens)
|
||||
updated_req_data = NewRequestData(
|
||||
req_id=req_id,
|
||||
prompt_token_ids=[1, 2, 3],
|
||||
prefill_token_ids=[1, 2, 3, 10, 4, 5],
|
||||
mm_features=[],
|
||||
sampling_params=None,
|
||||
pooling_params=None,
|
||||
block_ids=([0, 1],),
|
||||
num_computed_tokens=4, # 3 original prompt + 1 intermediate output
|
||||
lora_request=None,
|
||||
)
|
||||
runner.add_requests(_make_scheduler_output([updated_req_data]))
|
||||
|
||||
# Step 3: Verify no free_indices leak (old slot recycled)
|
||||
assert len(req_states.free_indices) == initial_free - 1
|
||||
|
||||
# Verify the request is still tracked with exactly one index
|
||||
assert req_id in req_states.req_id_to_index
|
||||
assert sum(1 for v in req_states.index_to_req_id.values() if v == req_id) == 1
|
||||
|
||||
# Verify state was updated with new values
|
||||
new_idx = req_states.req_id_to_index[req_id]
|
||||
assert req_states.prompt_len.np[new_idx] == 3
|
||||
assert req_states.prefill_len.np[new_idx] == 6
|
||||
assert req_states.num_computed_prefill_tokens[new_idx] == 4
|
||||
|
||||
# Verify model_state and block_tables were re-registered
|
||||
runner.model_state.add_request.assert_called_with(new_idx, updated_req_data)
|
||||
runner.block_tables.append_block_ids.assert_called_with(
|
||||
new_idx, ([0, 1],), overwrite=True
|
||||
)
|
||||
|
||||
|
||||
def test_e2e_streaming_with_multimodal_features(
|
||||
mock_model_runner_with_req_states,
|
||||
):
|
||||
"""Test that streaming sessions with multimodal features are updated.
|
||||
|
||||
This test validates that when a streaming session with mm features
|
||||
is updated:
|
||||
1. The old request state is removed (no free_indices leak)
|
||||
2. encoder_cache is cleaned up and re-registered with new mm_features
|
||||
3. model_state is re-registered (recomputes M-RoPE positions etc.)
|
||||
"""
|
||||
runner = mock_model_runner_with_req_states
|
||||
req_states = runner.req_states
|
||||
req_id = "streaming_mm_req_0"
|
||||
initial_free = len(req_states.free_indices)
|
||||
|
||||
# Enable encoder_cache for multimodal
|
||||
runner.encoder_cache = Mock()
|
||||
|
||||
# Step 1: Add initial request with one audio feature
|
||||
mm_feature_1 = Mock()
|
||||
initial_req_data = NewRequestData(
|
||||
req_id=req_id,
|
||||
prompt_token_ids=[1, 2] + [0] * 10 + [3, 4],
|
||||
prefill_token_ids=[1, 2] + [0] * 10 + [3, 4],
|
||||
mm_features=[mm_feature_1],
|
||||
sampling_params=None,
|
||||
pooling_params=None,
|
||||
block_ids=([0],),
|
||||
num_computed_tokens=14,
|
||||
lora_request=None,
|
||||
)
|
||||
runner.add_requests(_make_scheduler_output([initial_req_data]))
|
||||
assert req_id in req_states.req_id_to_index
|
||||
|
||||
# Reset mocks to track only the streaming update calls
|
||||
runner.encoder_cache.reset_mock()
|
||||
runner.model_state.reset_mock()
|
||||
|
||||
# Step 2: Create streaming update with additional multimodal feature
|
||||
# The scheduler has folded the intermediate output (100) into
|
||||
# prefill_token_ids and added a new audio chunk
|
||||
mm_feature_2 = Mock()
|
||||
updated_req_data = NewRequestData(
|
||||
req_id=req_id,
|
||||
prompt_token_ids=[1, 2] + [0] * 10 + [3, 4],
|
||||
prefill_token_ids=[1, 2] + [0] * 10 + [3, 4, 100] + [0] * 5 + [5],
|
||||
mm_features=[mm_feature_1, mm_feature_2],
|
||||
sampling_params=None,
|
||||
pooling_params=None,
|
||||
block_ids=([0, 1],),
|
||||
num_computed_tokens=14,
|
||||
lora_request=None,
|
||||
)
|
||||
runner.add_requests(_make_scheduler_output([updated_req_data]))
|
||||
|
||||
# Step 3: Verify no free_indices leak
|
||||
assert len(req_states.free_indices) == initial_free - 1
|
||||
assert sum(1 for v in req_states.index_to_req_id.values() if v == req_id) == 1
|
||||
|
||||
# Verify encoder_cache was cleaned up and re-registered
|
||||
runner.encoder_cache.remove_request.assert_called_once_with(req_id)
|
||||
runner.encoder_cache.add_request.assert_called_once_with(
|
||||
req_id, [mm_feature_1, mm_feature_2]
|
||||
)
|
||||
|
||||
# Verify model_state was re-registered with new data
|
||||
new_idx = req_states.req_id_to_index[req_id]
|
||||
runner.model_state.add_request.assert_called_once_with(new_idx, updated_req_data)
|
||||
|
||||
# Verify updated prefill length
|
||||
assert req_states.prefill_len.np[new_idx] == 21
|
||||
@@ -150,8 +150,10 @@ def create_whisper_attention_backend_with_block_pooling(
|
||||
new_common_attn_metadata.query_start_loc *= block_pool_size
|
||||
new_common_attn_metadata.query_start_loc_cpu *= block_pool_size
|
||||
new_common_attn_metadata.seq_lens *= block_pool_size
|
||||
new_common_attn_metadata._seq_lens_cpu *= block_pool_size
|
||||
new_common_attn_metadata._num_computed_tokens_cpu *= block_pool_size
|
||||
if new_common_attn_metadata._seq_lens_cpu is not None:
|
||||
new_common_attn_metadata._seq_lens_cpu *= block_pool_size
|
||||
if new_common_attn_metadata._num_computed_tokens_cpu is not None:
|
||||
new_common_attn_metadata._num_computed_tokens_cpu *= block_pool_size
|
||||
new_common_attn_metadata.num_actual_tokens *= block_pool_size
|
||||
new_common_attn_metadata.max_query_len *= block_pool_size
|
||||
new_common_attn_metadata.max_seq_len *= block_pool_size
|
||||
|
||||
@@ -111,6 +111,7 @@ def _reshape_kv_cache(
|
||||
kv_cache_config: KVCacheConfig,
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor],
|
||||
attn_backends: dict[str, AttentionBackend],
|
||||
cache_dtype: str,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
kv_caches: dict[str, torch.Tensor] = {}
|
||||
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
|
||||
@@ -127,6 +128,7 @@ def _reshape_kv_cache(
|
||||
kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size,
|
||||
cache_dtype,
|
||||
)
|
||||
|
||||
# FIXME(woosuk): Add kv_cache_stride_order to all attention backends.
|
||||
@@ -155,9 +157,12 @@ def init_kv_cache(
|
||||
kv_cache_config: KVCacheConfig,
|
||||
attn_backends: dict[str, AttentionBackend],
|
||||
device: torch.device,
|
||||
cache_dtype: str,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
|
||||
kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors, attn_backends)
|
||||
kv_caches = _reshape_kv_cache(
|
||||
kv_cache_config, kv_cache_raw_tensors, attn_backends, cache_dtype
|
||||
)
|
||||
bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
|
||||
return kv_caches
|
||||
|
||||
|
||||
@@ -359,6 +359,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.kv_cache_config,
|
||||
self.attn_backends,
|
||||
self.device,
|
||||
self.cache_config.cache_dtype,
|
||||
)
|
||||
self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict)
|
||||
|
||||
@@ -555,18 +556,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
)
|
||||
return cuda_graph_size
|
||||
|
||||
def _remove_request(self, req_id: str) -> bool:
|
||||
if not self.req_states.remove_request(req_id):
|
||||
return False
|
||||
if self.encoder_cache is not None:
|
||||
self.encoder_cache.remove_request(req_id)
|
||||
if self.prompt_logprobs_worker is not None:
|
||||
self.prompt_logprobs_worker.remove_request(req_id)
|
||||
self.lora_state.remove_request(req_id)
|
||||
return True
|
||||
|
||||
def finish_requests(self, scheduler_output: SchedulerOutput) -> None:
|
||||
finished_req_ids = scheduler_output.finished_req_ids
|
||||
preempted_req_ids = scheduler_output.preempted_req_ids
|
||||
if preempted_req_ids:
|
||||
finished_req_ids = finished_req_ids.union(preempted_req_ids)
|
||||
for req_id in finished_req_ids:
|
||||
self.req_states.remove_request(req_id)
|
||||
if self.encoder_cache is not None:
|
||||
self.encoder_cache.remove_request(req_id)
|
||||
if self.prompt_logprobs_worker is not None:
|
||||
self.prompt_logprobs_worker.remove_request(req_id)
|
||||
self.lora_state.remove_request(req_id)
|
||||
self._remove_request(req_id)
|
||||
|
||||
def free_states(self, scheduler_output: SchedulerOutput) -> None:
|
||||
if self.encoder_cache is not None:
|
||||
@@ -578,6 +584,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
assert new_req_data.prompt_token_ids is not None
|
||||
assert new_req_data.prefill_token_ids is not None
|
||||
req_id = new_req_data.req_id
|
||||
|
||||
# Streaming input update: request already exists from a prior
|
||||
# chunk. Remove old state so it can be cleanly re-added below
|
||||
# with the updated prompt_token_ids and mm_features.
|
||||
self._remove_request(req_id)
|
||||
|
||||
prompt_len = len(new_req_data.prompt_token_ids)
|
||||
self.req_states.add_request(
|
||||
req_id=req_id,
|
||||
|
||||
@@ -7,6 +7,7 @@ import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.tasks import GenerationTask
|
||||
from vllm.v1.core.sched.output import NewRequestData
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
|
||||
@@ -61,6 +62,28 @@ class DefaultModelState(ModelState):
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def get_supported_generation_tasks(self) -> tuple[GenerationTask, ...]:
|
||||
from vllm.model_executor.models.interfaces import (
|
||||
supports_realtime,
|
||||
supports_transcription,
|
||||
)
|
||||
from vllm.model_executor.models.interfaces_base import is_text_generation_model
|
||||
|
||||
supported_tasks = list[GenerationTask]()
|
||||
|
||||
if is_text_generation_model(self.model):
|
||||
supported_tasks.append("generate")
|
||||
|
||||
if supports_transcription(self.model):
|
||||
if self.model.supports_transcription_only:
|
||||
return ("transcription",)
|
||||
supported_tasks.append("transcription")
|
||||
|
||||
if supports_realtime(self.model):
|
||||
supported_tasks.append("realtime")
|
||||
|
||||
return tuple(supported_tasks)
|
||||
|
||||
def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
|
||||
if self.rope_state is not None:
|
||||
assert new_req_data.prefill_token_ids is not None
|
||||
|
||||
@@ -28,8 +28,9 @@ class ModelState(ABC):
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_supported_generation_tasks(self) -> tuple[GenerationTask, ...]:
|
||||
return ("generate",)
|
||||
raise NotImplementedError
|
||||
|
||||
def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
|
||||
return None
|
||||
|
||||
@@ -109,13 +109,14 @@ class RequestState:
|
||||
self.all_token_ids.apply_write()
|
||||
self.num_computed_tokens.apply_write()
|
||||
|
||||
def remove_request(self, req_id: str) -> None:
|
||||
def remove_request(self, req_id: str) -> bool:
|
||||
req_idx = self.req_id_to_index.pop(req_id, None)
|
||||
if req_idx is None:
|
||||
# Request not found.
|
||||
return
|
||||
return False
|
||||
self.index_to_req_id.pop(req_idx, None)
|
||||
self.free_indices.append(req_idx)
|
||||
return True
|
||||
|
||||
def any_prefills(self, idx_mapping_np: np.ndarray) -> bool:
|
||||
return np.any(
|
||||
|
||||
Reference in New Issue
Block a user