[Feature] add session based streaming input support to v1 (#28973)

Signed-off-by: Joshua Deng <joshuakdeng@gmail.com>
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
Authored by Joshua Deng on 2026-01-24 13:06:28 -07:00, committed by GitHub
parent d4dbb7af63, commit 91601ff478
16 changed files with 2151 additions and 63 deletions
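
Usage, in brief: a minimal sketch assembled from the tests in this diff. StreamingInput and the async-generator prompt form of AsyncLLM.generate() are what this change introduces; the engine setup and the stream_session helper are assumed here for illustration.

from collections.abc import AsyncGenerator

from vllm.sampling_params import SamplingParams
from vllm.v1.engine.async_llm import AsyncLLM, StreamingInput


async def stream_session(llm: AsyncLLM) -> None:
    params = SamplingParams(max_tokens=32)

    async def inputs() -> AsyncGenerator[StreamingInput, None]:
        # Each yielded chunk extends the same session's prompt.
        yield StreamingInput(prompt="Hello", sampling_params=params)
        yield StreamingInput(prompt=" world", sampling_params=params)
        # Finishing the generator signals that the input stream is complete.

    # Each chunk produces intermediate outputs; the final one has finished=True.
    async for output in llm.generate(inputs(), params, request_id="session-1"):
        print(output.finished, output.outputs)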

@@ -0,0 +1,171 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from collections.abc import AsyncGenerator
from unittest.mock import AsyncMock, MagicMock
import pytest
from vllm.outputs import RequestOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.v1.engine.async_llm import AsyncLLM, StreamingInput
from vllm.v1.engine.output_processor import RequestOutputCollector
@pytest.fixture
def mock_async_llm():
"""Create a mock AsyncLLM with mocked dependencies."""
# Create a minimal mock without initializing the full engine
llm = MagicMock(spec=AsyncLLM)
# Mock the essential attributes
llm.vllm_config = MagicMock()
llm.vllm_config.cache_config.kv_sharing_fast_prefill = False
llm.model_config = MagicMock()
llm.model_config.max_model_len = 2048
llm.log_requests = False
llm.errored = False
llm._pause_cond = asyncio.Condition()
llm._paused = False
# Mock methods
llm._run_output_handler = MagicMock()
llm.abort = AsyncMock()
# Use the real generate method from AsyncLLM
llm.generate = AsyncLLM.generate.__get__(llm, AsyncLLM)
return llm
@pytest.mark.asyncio
async def test_generate_normal_flow(mock_async_llm):
"""Test normal generation flow with streaming requests."""
request_id = "test_request"
prompt = "Tell me about Paris"
sampling_params = SamplingParams(max_tokens=10)
# Create a mock queue with outputs
queue = RequestOutputCollector(RequestOutputKind.FINAL_ONLY, request_id)
output1 = RequestOutput(
request_id=request_id,
prompt="Tell me about Paris",
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[],
finished=False,
)
output2 = RequestOutput(
request_id=request_id,
prompt="Tell me about Paris",
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[],
finished=True,
)
# Feed outputs to queue as they're consumed to avoid aggregation
async def feed_outputs():
queue.put(output1)
await asyncio.sleep(1) # Let first output be consumed
queue.put(output2)
asyncio.create_task(feed_outputs()) # noqa
# Mock add_request to return the queue
async def mock_add_request(*args, **kwargs):
return queue
mock_async_llm.add_request = mock_add_request
# Collect outputs from generate
outputs = []
async for output in mock_async_llm.generate(
prompt=prompt,
sampling_params=sampling_params,
request_id=request_id,
):
outputs.append(output)
assert len(outputs) == 2
assert outputs[0].finished is False
assert outputs[1].finished is True
def make_output(request_id: str, finished: bool) -> RequestOutput:
"""Helper to create a RequestOutput."""
return RequestOutput(
request_id=request_id,
prompt="test",
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[],
finished=finished,
)
@pytest.mark.asyncio
async def test_generate_with_async_generator():
"""Test generate with an async input generator.
With the new streaming input API, completion is signaled by finishing
the input generator (not via a resumable flag). Each input chunk
produces intermediate outputs, and the final output has finished=True.
"""
request_id = "test"
sampling_params = SamplingParams(max_tokens=10)
llm = MagicMock(spec=AsyncLLM)
llm.vllm_config = MagicMock()
llm.vllm_config.cache_config.kv_sharing_fast_prefill = False
llm.model_config = MagicMock()
llm.model_config.max_model_len = 2048
llm.log_requests = False
llm.errored = False
llm._pause_cond = asyncio.Condition()
llm._paused = False
llm._run_output_handler = MagicMock()
llm.abort = AsyncMock()
# Bind the real generate method
llm.generate = AsyncLLM.generate.__get__(llm, AsyncLLM)
# Track inputs processed
inputs_received = []
queue = RequestOutputCollector(RequestOutputKind.DELTA, request_id)
async def mock_add_request(req_id, prompt, params, *args, **kwargs):
# When prompt is an AsyncGenerator, process streaming inputs
if isinstance(prompt, AsyncGenerator):
# Process inputs in background, produce outputs
async def handle_stream():
async for input_chunk in prompt:
inputs_received.append(input_chunk.prompt)
# Each input produces an intermediate output
queue.put(make_output(req_id, finished=False))
await asyncio.sleep(0.01)
# Final output when stream ends
queue.put(make_output(req_id, finished=True))
asyncio.create_task(handle_stream())
return queue
return queue
llm.add_request = mock_add_request
async def input_generator() -> AsyncGenerator[StreamingInput, None]:
yield StreamingInput(prompt="Hello", sampling_params=sampling_params)
yield StreamingInput(prompt=" world", sampling_params=sampling_params)
outputs = []
async for output in llm.generate(input_generator(), sampling_params, request_id):
outputs.append(output)
# Two intermediate outputs + one final output
assert len(outputs) == 3
assert outputs[0].finished is False
assert outputs[1].finished is False
assert outputs[2].finished is True
# Both inputs were processed
assert inputs_received == ["Hello", " world"]

@@ -0,0 +1,210 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for GPUModelRunner._update_streaming_request function."""
from unittest.mock import Mock
import pytest
from vllm.multimodal.inputs import (
MultiModalFeatureSpec,
MultiModalKwargsItem,
PlaceholderRange,
)
from vllm.sampling_params import SamplingParams
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
pytestmark = pytest.mark.cpu_test
@pytest.fixture
def mock_model_runner_with_input_batch():
"""Create a mock GPUModelRunner with a real InputBatch for e2e testing."""
runner = Mock(spec=GPUModelRunner)
runner.uses_mrope = False
runner.requests = {}
runner.max_num_reqs = 10
runner.max_model_len = 1024
# Create a real InputBatch for e2e testing
runner.input_batch = InputBatch(
max_num_reqs=10,
max_model_len=1024,
max_num_batched_tokens=1024,
device="cpu",
pin_memory=False,
vocab_size=32000,
block_sizes=[16],
kernel_block_sizes=[16],
is_spec_decode=False,
logitsprocs=None,
is_pooling_model=False,
)
return runner
def test_e2e_streaming_request_update_basic_flow(mock_model_runner_with_input_batch):
"""Test that streaming session are updated correctly.
This test validates that when a streaming session is updated with new prompt tokens:
1. The request is removed from InputBatch before updating (avoids duplication)
2. Request state fields are updated correctly
3. output_token_ids is cleared (intermediate outputs are now in prompt_token_ids)
"""
runner = mock_model_runner_with_input_batch
req_id = "streaming_req_0"
# Step 1: Create initial request state with some computed tokens
initial_req_state = CachedRequestState(
req_id=req_id,
prompt_token_ids=[1, 2, 3],
mm_features=[],
sampling_params=SamplingParams(temperature=0.5),
pooling_params=None,
generator=None,
block_ids=([0],),
num_computed_tokens=3,
output_token_ids=[10, 11], # Generated 2 tokens
)
runner.requests[req_id] = initial_req_state
# Add request to InputBatch
runner.input_batch.add_request(initial_req_state)
assert req_id in runner.input_batch.req_id_to_index
# Step 2: Create new request data with extended prompt
# The scheduler has already set prompt_token_ids to the full sequence
# (original prompt + intermediate outputs + new prompt)
new_req_data = Mock()
new_req_data.prompt_token_ids = [
1,
2,
3,
10,
4,
5,
] # Full sequence with intermediate output (10)
new_req_data.mm_features = []
new_req_data.prompt_embeds = None
new_req_data.sampling_params = SamplingParams(temperature=0.8, max_tokens=50)
new_req_data.pooling_params = None
new_req_data.block_ids = ([0, 1],)
new_req_data.num_computed_tokens = 4 # 3 original prompt + 1 intermediate output
# Step 3: Update the request
updated_req_state = GPUModelRunner._update_streaming_request(
runner, req_id, new_req_data
)
# Step 4: Verify the request state was updated correctly
assert updated_req_state.prompt_token_ids == [1, 2, 3, 10, 4, 5]
assert updated_req_state.num_computed_tokens == 4
assert updated_req_state.sampling_params.temperature == 0.8
assert updated_req_state.sampling_params.max_tokens == 50
assert updated_req_state.block_ids == ([0, 1],)
# Verify output_token_ids were cleared
# (intermediate outputs are now in prompt_token_ids)
assert updated_req_state.output_token_ids == []
# Verify the same object is returned
assert runner.requests[req_id] is updated_req_state
# Verify request was removed from InputBatch during update (avoids duplication)
assert req_id not in runner.input_batch.req_id_to_index
def test_e2e_streaming_with_multimodal_features(mock_model_runner_with_input_batch):
"""Test that streaming session with multimodal features are updated correctly.
This test validates that when a streaming session with mm features is updated:
1. The request is removed from InputBatch before updating (avoids duplication)
2. Multimodal features from both requests are preserved and merged correctly
3. New prompt tokens (including intermediate outputs) are appended correctly
4. output_token_ids is cleared (intermediate outputs are now in prompt_token_ids)
"""
runner = mock_model_runner_with_input_batch
req_id = "streaming_mm_req_0"
# Step 1: Create initial request state with one multimodal feature
mm_feature_1 = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("audio"),
modality="audio",
identifier="audio_1",
mm_position=PlaceholderRange(offset=2, length=10),
)
initial_req_state = CachedRequestState(
req_id=req_id,
prompt_token_ids=[1, 2] + [0] * 10 + [3, 4], # 2 + 10 (mm) + 2 = 14 tokens
mm_features=[mm_feature_1],
sampling_params=SamplingParams(),
pooling_params=None,
generator=None,
block_ids=([0],),
num_computed_tokens=14,
output_token_ids=[100], # Generated 1 token
)
runner.requests[req_id] = initial_req_state
# Add request to InputBatch
runner.input_batch.add_request(initial_req_state)
assert req_id in runner.input_batch.req_id_to_index
# Step 2: Create new request data with additional multimodal feature
# The scheduler has already set prompt_token_ids to the full sequence
# (original prompt + intermediate outputs + new prompt with new multimodal feature)
mm_feature_2 = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("audio"),
modality="audio",
identifier="audio_2",
mm_position=PlaceholderRange(offset=15, length=5),
)
new_req_data = Mock()
# Full sequence: [1, 2] + [0]*10 + [3, 4] + [100] + [0]*5 + [5] = 21 tokens
new_req_data.prompt_token_ids = [1, 2] + [0] * 10 + [3, 4, 100] + [0] * 5 + [5]
new_req_data.mm_features = [mm_feature_1, mm_feature_2]
new_req_data.prompt_embeds = None
new_req_data.sampling_params = SamplingParams(temperature=0.7, max_tokens=30)
new_req_data.pooling_params = None
new_req_data.block_ids = ([0, 1],)
new_req_data.num_computed_tokens = 14 # 14 tokens from initial request
# Step 3: Update the request
updated_req_state = GPUModelRunner._update_streaming_request(
runner, req_id, new_req_data
)
# Step 4: Verify the request state was updated correctly
# Verify multimodal features are preserved
assert len(updated_req_state.mm_features) == 2
assert updated_req_state.mm_features[0] == mm_feature_1
assert updated_req_state.mm_features[1] == mm_feature_2
# Verify prompt tokens include intermediate output (100) and new tokens
# Initial: 2 + 10 (mm1) + 2 = 14 tokens
# New: 2 + 10 (mm1) + 2 + 1 (output 100) + 5 (mm2) + 1 = 21 tokens
assert len(updated_req_state.prompt_token_ids) == 21
assert updated_req_state.prompt_token_ids == [1, 2] + [0] * 10 + [3, 4, 100] + [
0
] * 5 + [5]
# Verify output_token_ids were cleared
# (intermediate outputs are now in prompt_token_ids)
assert updated_req_state.output_token_ids == []
# Verify other parameters were updated
assert updated_req_state.num_computed_tokens == 14
assert updated_req_state.sampling_params.temperature == 0.7
assert updated_req_state.sampling_params.max_tokens == 30
assert updated_req_state.block_ids == ([0, 1],)
# Verify the same object is returned
assert runner.requests[req_id] is updated_req_state
# Verify request was removed from InputBatch during update (avoids duplication)
assert req_id not in runner.input_batch.req_id_to_index

@@ -0,0 +1,575 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import unittest
from unittest.mock import MagicMock
import torch
from vllm.config import DeviceConfig, VllmConfig
from vllm.multimodal.inputs import (
MultiModalFeatureSpec,
MultiModalKwargsItem,
PlaceholderRange,
)
from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.engine import FinishReason
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus, StreamingUpdate
from vllm.v1.structured_output import StructuredOutputManager
STOP_TOKEN = 128001
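# Session semantics exercised by these tests (summarized from the assertions
# below, not a normative spec): a request arriving with an already-known
# request_id is merged into that session via StreamingUpdate. A resumable
# request extends the prompt and moves the session back to WAITING; a
# non-resumable request signals stream completion and finishes the session.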
class DummyRequest(Request):
def __init__(
self,
request_id,
resumable=True,
prompt_token_ids=None,
mm_features: list[MultiModalFeatureSpec] | None = None,
max_tokens: int | None = 16,
):
super().__init__(
request_id=request_id,
prompt_token_ids=prompt_token_ids if prompt_token_ids is not None else [],
sampling_params=SamplingParams(
stop_token_ids=[STOP_TOKEN], max_tokens=max_tokens
),
pooling_params=None,
eos_token_id=None,
mm_features=mm_features,
resumable=resumable,
)
def create_scheduler() -> Scheduler:
vllm_config = VllmConfig(device_config=DeviceConfig("cpu"))
vllm_config.model_config = MagicMock()
vllm_config.model_config.skip_tokenizer_init = True
vllm_config.model_config.is_multimodal_model = False
vllm_config.model_config.max_model_len = 1024
vllm_config.model_config.enable_return_routed_experts = False
vllm_config.cache_config = MagicMock()
vllm_config.cache_config.num_gpu_blocks = 1000
vllm_config.cache_config.enable_prefix_caching = False
kv_cache_config = KVCacheConfig(
num_blocks=1000,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer"],
FullAttentionSpec(
block_size=16, num_kv_heads=1, head_size=1, dtype=torch.float32
),
)
],
)
return Scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
block_size=16,
)
class TestStreamingScheduler(unittest.TestCase):
def test_add_request(self):
scheduler = create_scheduler()
request = DummyRequest(
request_id="test_request",
resumable=True,
)
scheduler.add_request(request)
assert "test_request" in scheduler.requests
assert request.status == RequestStatus.WAITING
assert len(scheduler.waiting) == 1
next_request = DummyRequest(
request_id="test_request",
resumable=True,
)
scheduler.add_request(next_request)
assert next_request.status == RequestStatus.WAITING
assert len(scheduler.requests["test_request"].streaming_queue) == 1
def test_update_request_as_session_max_token(self):
scheduler = create_scheduler()
session = DummyRequest(
request_id="session",
prompt_token_ids=[1, 2, 3],
)
session.num_computed_tokens = len(session.prompt_token_ids)
session.max_tokens = 10 # Initial max_tokens
session._output_token_ids = [1] * 10 # reach max_tokens
new_request = DummyRequest(
request_id="session",
prompt_token_ids=[4, 5, 6],
)
new_request.sampling_params = SamplingParams(max_tokens=10)
new_request.max_tokens = 10 # Additional max_tokens from new request
update = StreamingUpdate.from_request(new_request)
scheduler._update_request_as_session(session, update)
assert session.sampling_params.max_tokens == 10
# _update_request_as_session clears output tokens first, so
# max_tokens = num_output_tokens (0) + update.max_tokens (10) = 10
assert session.max_tokens == 10
session.num_computed_tokens = len(session.prompt_token_ids)
# Simulate generating 5 more output tokens
session._output_token_ids = [1] * 5
new_request2 = DummyRequest(
request_id="session",
prompt_token_ids=[7, 8, 9],
)
new_request2.sampling_params = SamplingParams(max_tokens=10)
new_request2.max_tokens = 10
update2 = StreamingUpdate.from_request(new_request2)
scheduler._update_request_as_session(session, update2)
assert session.sampling_params.max_tokens == 10
# Again, output tokens are cleared first, so max_tokens = 0 + 10 = 10
assert session.max_tokens == 10
def test_update_request_as_session(self):
scheduler = create_scheduler()
session = DummyRequest(
request_id="session",
prompt_token_ids=[1, 2, 3],
)
session.num_computed_tokens = len(session.prompt_token_ids)
new_request = DummyRequest(
request_id="session",
prompt_token_ids=[4, 5, 6],
)
new_request.sampling_params = SamplingParams(max_tokens=10)
update = StreamingUpdate.from_request(new_request)
scheduler._update_request_as_session(session, update)
assert session.prompt_token_ids == [1, 2, 3, 4, 5, 6]
assert session._all_token_ids == [1, 2, 3, 4, 5, 6]
assert session.sampling_params.max_tokens == 10
assert session.status == RequestStatus.WAITING
def test_update_request_as_session_with_multimodal(self):
scheduler = create_scheduler()
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("audio"),
modality="audio",
identifier="",
mm_position=PlaceholderRange(offset=1, length=1),
)
session = DummyRequest(
request_id="session",
prompt_token_ids=[1, 2, 3],
mm_features=[mm_feature],
)
session.num_computed_tokens = len(session.prompt_token_ids)
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("audio"),
modality="audio",
identifier="",
mm_position=PlaceholderRange(offset=2, length=1),
)
new_request = DummyRequest(
request_id="session",
prompt_token_ids=[4, 5, 6, 7],
mm_features=[mm_feature],
)
update = StreamingUpdate.from_request(new_request)
scheduler._update_request_as_session(session, update)
assert len(session.mm_features) == 2
assert session.mm_features[0].mm_position.offset == 1
        # new offset (2) + len([1, 2, 3]) original prompt tokens = 5
assert session.mm_features[1].mm_position.offset == 5
def test_process_streaming_requests_with_finish_session(self):
"""Test that a non-resumable request signals stream completion.
With the new streaming API, completion is signaled by closing/finishing
the input generator. When a non-resumable request is added to a session
in WAITING_FOR_STREAMING_REQ state, the session is finished immediately
with FINISHED_ABORTED status.
"""
scheduler = create_scheduler()
session = DummyRequest(
request_id="session",
prompt_token_ids=[1, 2, 3],
resumable=True,
)
scheduler.add_request(session)
session.status = RequestStatus.WAITING_FOR_STREAMING_REQ
session.num_computed_tokens = len(session.prompt_token_ids)
# A non-resumable request signals stream completion
close_request = DummyRequest(
request_id="session",
prompt_token_ids=[0],
resumable=False,
max_tokens=1,
)
scheduler.add_request(close_request)
# The session should be immediately finished (stream completed)
assert session.status == RequestStatus.FINISHED_ABORTED
# The session should be removed from the scheduler
assert session.request_id not in scheduler.requests
def test_streaming_request_session_update(self):
"""Test that a resumable request updates a waiting session directly.
When a session is in WAITING_FOR_STREAMING_REQ state and a new resumable
request arrives, the update is applied directly via _update_request_as_session,
not queued.
"""
scheduler = create_scheduler()
session = DummyRequest(
request_id="session",
prompt_token_ids=[1, 2, 3],
resumable=True,
)
scheduler.add_request(session)
session.status = RequestStatus.WAITING_FOR_STREAMING_REQ
session.num_computed_tokens = len(session.prompt_token_ids)
next_request = DummyRequest(
request_id="session",
prompt_token_ids=[4, 5],
resumable=True,
)
scheduler.add_request(next_request)
# With the new behavior, when session is in WAITING_FOR_STREAMING_REQ,
# the update is applied directly (not queued), and session status
# becomes WAITING
assert session.status == RequestStatus.WAITING
assert session.prompt_token_ids == [1, 2, 3, 4, 5]
_ = scheduler.schedule()
assert session.status == RequestStatus.RUNNING
def test_update_request_as_session_with_output_tokens(self):
scheduler = create_scheduler()
session = DummyRequest(
request_id="session",
prompt_token_ids=[1, 2, 3], # 3 prompt tokens
)
session.append_output_token_ids([10, 11])
"""
The last output token (11) hasn't been "scheduled" yet, so `num_computed_tokens`
only includes: 3 prompt + 1 output (the 10) = 4
"""
session.num_computed_tokens = 4
new_request = DummyRequest(
request_id="session",
prompt_token_ids=[4, 5],
)
update = StreamingUpdate.from_request(new_request)
scheduler._update_request_as_session(session, update)
# _update_request_as_session keeps computed output tokens (they become
# part of the prompt) and only discards the final uncomputed sampled
# token. Computed output token 10 is kept, uncomputed token 11 is
# discarded.
assert session._all_token_ids == [1, 2, 3, 10, 4, 5]
assert session.prompt_token_ids == [1, 2, 3, 10, 4, 5]
# Output tokens list is cleared
assert session._output_token_ids == []
# num_computed_tokens is unchanged (KV cache still valid for computed
# tokens)
assert session.num_computed_tokens == 4
# Verify that the next schedule will only process the new prompt tokens
# num_new_tokens = num_tokens - num_computed_tokens = 6 - 4 = 2
num_new_tokens = session.num_tokens - session.num_computed_tokens
assert num_new_tokens == 2
def test_streaming_e2e_lifecycle(self):
"""
        Comprehensive integration test covering the complete streaming request
        lifecycle, including scheduler state management and aliasing-bug prevention.
FULL LIFECYCLE:
================
CYCLE 1 (Initial Decode):
1. Add streaming request (seq_id=0) with prompt tokens [1,2,3]
2. Schedule() creates NewRequestData with prompt_token_ids
3. Model runner caches this prompt_token_ids reference (simulated)
4. Model executes and generates output token 10
5. update_from_output() appends token 10 to request._all_token_ids
6. Request transitions to RUNNING state
CYCLE 2 (Continue Decode):
7. Schedule() again - request is now in scheduled_cached_reqs (not new)
8. Model runner uses CACHED state to calculate num_tokens
9. Model generates output token (STOP_TOKEN)
10. update_from_output() appends STOP_TOKEN to request._all_token_ids
11. Request transitions to WAITING_FOR_STREAMING_REQ
CYCLE 3 (New Streaming Request):
12. Add new streaming request (seq_id=1) with prompt tokens [4,5]
13. Scheduler merges into session, creates NewRequestData again
14. Model runner caches new prompt_token_ids reference
15. Verify cached state from Cycle 1 wasn't corrupted by mutations
CRITICAL BUG PREVENTION:
========================
Without .copy() in _create_new_request_data():
- Cycle 1 Step 3: cached_state["prompt_token_ids"] aliases
request._all_token_ids
- Cycle 1 Step 5: When appending token 10, cached state mutates:
[1,2,3] -> [1,2,3,10]
- Cycle 2 Step 8: num_tokens = len([1,2,3,10]) + len([10])
= 5 (WRONG! Should be 4)
- Cycle 2: Discard logic would see seq_lens=4 < num_tokens=5
-> INCORRECTLY DISCARDS
With .copy() in _create_new_request_data():
- Cycle 1 Step 3: cached_state["prompt_token_ids"] is independent copy
- Cycle 1 Step 5: Only request._all_token_ids mutates, cached stays [1,2,3]
- Cycle 2 Step 8: num_tokens = len([1,2,3]) + len([10]) = 4 (CORRECT)
- Cycle 2: Discard logic works correctly
"""
scheduler = create_scheduler()
# ═══════════════════════════════════════════════════════════════════
# CYCLE 1: Initial Request Scheduling and First Decode
# ═══════════════════════════════════════════════════════════════════
session = DummyRequest(
request_id="session",
prompt_token_ids=[1, 2, 3],
)
scheduler.add_request(session)
# Step 2: Schedule creates NewRequestData
scheduler_output_cycle1 = scheduler.schedule()
# Verify request is in scheduled_new_reqs (first time scheduling)
assert len(scheduler_output_cycle1.scheduled_new_reqs) == 1
new_req_data_cycle1 = scheduler_output_cycle1.scheduled_new_reqs[0]
assert new_req_data_cycle1.prompt_token_ids == [1, 2, 3]
assert (
scheduler_output_cycle1.num_scheduled_tokens[session.request_id] == 3
) # [1, 2, 3]
assert (
session.request_id
not in scheduler_output_cycle1.scheduled_cached_reqs.req_ids
)
# Step 3: Simulate model runner caching the prompt_token_ids
# This simulates gpu_model_runner.py:706-720 CachedRequestState creation
# The model runner makes a copy of prompt_token_ids when creating
# CachedRequestState
cached_state_cycle1 = {
"req_id": session.request_id,
"prompt_token_ids": list(
new_req_data_cycle1.prompt_token_ids
), # Explicit copy
"output_token_ids": [],
"num_computed_tokens": 0,
}
# Store original for verification
original_cached_prompt_cycle1 = cached_state_cycle1["prompt_token_ids"].copy()
# Step 4-5: Model execution generates token, scheduler updates request
output_token_1 = 10
cached_state_cycle1["output_token_ids"].append(output_token_1)
mro_cycle1 = ModelRunnerOutput(
req_ids=[session.request_id],
req_id_to_index={session.request_id: 0},
sampled_token_ids=[[output_token_1]],
logprobs=None,
prompt_logprobs_dict={session.request_id: None},
pooler_output=[],
)
session.num_computed_tokens = len(session.prompt_token_ids)
eco_dict_cycle1 = scheduler.update_from_output(
scheduler_output_cycle1, mro_cycle1
)
# Step 6: Verify request state after Cycle 1
eco_cycle1 = eco_dict_cycle1[session.client_index].outputs[0]
assert eco_cycle1.finish_reason is None # Not stopped yet
assert session.status == RequestStatus.RUNNING
assert session in scheduler.running
assert session._all_token_ids == [1, 2, 3, 10] # Mutation happened here
# CRITICAL ASSERTION: Cached prompt_token_ids must NOT have changed
assert (
cached_state_cycle1["prompt_token_ids"] == original_cached_prompt_cycle1
), (
f"ALIASING BUG DETECTED in Cycle 1! "
f"cached_state['prompt_token_ids'] was mutated from "
f"{original_cached_prompt_cycle1} to "
f"{cached_state_cycle1['prompt_token_ids']}. "
f"This means _create_new_request_data() didn't call .copy()!"
)
assert cached_state_cycle1["prompt_token_ids"] is not session._all_token_ids, (
"ALIASING BUG! cached_state['prompt_token_ids'] is the same object as "
"session._all_token_ids. They must be independent copies."
)
# ═══════════════════════════════════════════════════════════════════
# CYCLE 2: Continue Decoding (Using Cached State)
# ═══════════════════════════════════════════════════════════════════
# Step 7: Schedule again - now request uses cached state
scheduler_output_cycle2 = scheduler.schedule()
# Verify request is NOT in scheduled_new_reqs (already cached)
assert not scheduler_output_cycle2.scheduled_new_reqs
assert (
session.request_id in scheduler_output_cycle2.scheduled_cached_reqs.req_ids
)
assert (
scheduler_output_cycle2.num_scheduled_tokens[session.request_id] == 1
) # Only the output token [10]
# Step 8: Calculate num_tokens like gpu_model_runner.py:1284 does
# This is where the bug would manifest!
num_tokens_cycle2 = len(cached_state_cycle1["prompt_token_ids"]) + len(
cached_state_cycle1["output_token_ids"]
)
# CRITICAL ASSERTION: num_tokens must be correct (3 prompt + 1 output = 4)
# Without .copy(), cached_state["prompt_token_ids"] would be [1,2,3,10]
# and num_tokens would incorrectly be 5, causing the discard bug
expected_num_tokens_cycle2 = 4
assert num_tokens_cycle2 == expected_num_tokens_cycle2, (
f"DISCARD BUG WOULD TRIGGER! num_tokens calculation is wrong. "
f"Expected {expected_num_tokens_cycle2}, got {num_tokens_cycle2}. "
f"cached_state['prompt_token_ids'] = "
f"{cached_state_cycle1['prompt_token_ids']} (should be [1,2,3], not [1,2,3,"
f"10]). Without .copy(), this would be 5 = len([1,2,3,10]) + len([10]). "
f"Discard logic would see: seq_lens={session.num_computed_tokens} "
f"< num_tokens={num_tokens_cycle2}, triggering incorrect discard!"
)
# Step 9-10: Model generates STOP_TOKEN, scheduler updates
output_token_2 = STOP_TOKEN
cached_state_cycle1["output_token_ids"].append(output_token_2)
mro_cycle2 = ModelRunnerOutput(
req_ids=[session.request_id],
req_id_to_index={session.request_id: 0},
sampled_token_ids=[[output_token_2]],
logprobs=None,
prompt_logprobs_dict={session.request_id: None},
pooler_output=[],
)
eco_dict_cycle2 = scheduler.update_from_output(
scheduler_output_cycle2, mro_cycle2
)
# Step 11: Verify request transitioned to WAITING_FOR_STREAMING_REQ
eco_cycle2 = eco_dict_cycle2[session.client_index].outputs[0]
assert eco_cycle2.finish_reason == FinishReason.STOP
assert session.status == RequestStatus.WAITING_FOR_STREAMING_REQ
assert session in scheduler.waiting
assert session._all_token_ids == [1, 2, 3, 10, STOP_TOKEN]
# CRITICAL ASSERTION: Cached prompt_token_ids STILL must not have changed
assert cached_state_cycle1["prompt_token_ids"] == [1, 2, 3], (
f"ALIASING BUG DETECTED in Cycle 2! "
f"cached_state['prompt_token_ids'] = "
f"{cached_state_cycle1['prompt_token_ids']} (should still be [1,2,3]). "
f"Mutations from update_from_output() leaked through!"
)
# ═══════════════════════════════════════════════════════════════════
# CYCLE 3: New Streaming Request (Session Continuation)
# ═══════════════════════════════════════════════════════════════════
# Step 12: Add new streaming request with seq_id=1
new_request = DummyRequest(
request_id="session",
prompt_token_ids=[4, 5],
)
scheduler.add_request(new_request)
# With the new streaming API, when session is in WAITING_FOR_STREAMING_REQ,
# the update is applied directly via _update_request_as_session (not queued).
# The session status becomes WAITING after the update is applied.
assert session.status == RequestStatus.WAITING
# Step 13: Scheduler schedules the updated session
scheduler_output_cycle3 = scheduler.schedule()
# Verify scheduler created NewRequestData with merged prompt_token_ids
assert len(scheduler_output_cycle3.scheduled_new_reqs) == 1
assert (
scheduler_output_cycle3.scheduled_new_reqs[0].prompt_token_ids
== session.prompt_token_ids
)
assert (
scheduler_output_cycle3.num_scheduled_tokens[session.request_id] == 2
) # Only new tokens [4, 5]
# Computed output tokens are kept (become part of prompt), only the
# final uncomputed sampled token (STOP_TOKEN) is discarded
assert session._all_token_ids == [1, 2, 3, 10, 4, 5]
assert session.prompt_token_ids == [1, 2, 3, 10, 4, 5] # Includes kept output
assert session._output_token_ids == [] # Output tokens are cleared
# Step 14: Model runner caches NEW prompt_token_ids reference
# The model runner makes a copy of prompt_token_ids when creating
# CachedRequestState
new_req_data_cycle3 = scheduler_output_cycle3.scheduled_new_reqs[0]
cached_state_cycle3 = {
"req_id": session.request_id,
"prompt_token_ids": list(
new_req_data_cycle3.prompt_token_ids
), # Explicit copy
"output_token_ids": [],
"num_computed_tokens": session.num_computed_tokens,
}
# Step 15: FINAL CRITICAL VERIFICATION
# The old cached state from Cycle 1 must still be unchanged
assert cached_state_cycle1["prompt_token_ids"] == [1, 2, 3], (
f"PERSISTENT ALIASING BUG! Even after new scheduling cycle, "
f"old cached_state was mutated to "
f"{cached_state_cycle1['prompt_token_ids']}. This proves the aliasing bug "
f"exists!"
)
# The new cached state must be independent
assert cached_state_cycle3["prompt_token_ids"] is not session._all_token_ids, (
"ALIASING BUG in Cycle 3! Cached state is aliased to _all_token_ids."
)
# Both cached states must be independent of each other
assert (
cached_state_cycle1["prompt_token_ids"]
is not cached_state_cycle3["prompt_token_ids"]
), "Cached states from different cycles should be independent objects."