[V1] Optimize handling of sampling metadata and req_ids list (#13244)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -5,6 +5,7 @@ from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData,
|
||||
SchedulerOutput)
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
|
||||
@@ -82,14 +83,21 @@ def _is_req_added(model_runner, req_id: str) -> bool:
|
||||
return req_id in model_runner.requests
|
||||
|
||||
|
||||
def _is_sampling_metadata_changed(model_runner,
                                  sampling_metadata_before: SamplingMetadata):
    """Tell whether the runner's input batch now holds a different
    sampling-metadata object than the one captured earlier.

    The comparison is by object identity (``is not``), not by equality:
    the result is True only when
    ``model_runner.input_batch.sampling_metadata`` is a different object
    from ``sampling_metadata_before``.
    """
    current_metadata = model_runner.input_batch.sampling_metadata
    return current_metadata is not sampling_metadata_before
|
||||
|
||||
|
||||
def test_update_states_new_request(model_runner):
    """Check ``_update_states`` when a brand-new request is scheduled.

    After a single update the test expects three things: the input batch's
    ``sampling_metadata`` is a different object than before the update, the
    request is registered on the runner, and the request is scheduled.
    """
    req_id = "req_0"

    # new req
    scheduler_output = _schedule_new_request(req_id)

    # Snapshot the metadata object *before* the one and only state update so
    # the identity check compares pre-update against post-update state.
    metadata_before = model_runner.input_batch.sampling_metadata
    model_runner._update_states(scheduler_output)
    assert _is_sampling_metadata_changed(model_runner, metadata_before)
    assert _is_req_added(model_runner, req_id)
    assert _is_req_scheduled(model_runner, req_id)
|
||||
|
||||
@@ -117,8 +125,9 @@ def test_update_states_request_finished(model_runner):
|
||||
free_encoder_input_ids=[],
|
||||
)
|
||||
|
||||
batch_changed = model_runner._update_states(scheduler_output)
|
||||
assert batch_changed is True
|
||||
metadata_before = model_runner.input_batch.sampling_metadata
|
||||
model_runner._update_states(scheduler_output)
|
||||
assert _is_sampling_metadata_changed(model_runner, metadata_before)
|
||||
assert not _is_req_added(model_runner, req_id)
|
||||
assert not _is_req_scheduled(model_runner, req_id)
|
||||
|
||||
@@ -142,7 +151,7 @@ def test_update_states_request_resumed(model_runner):
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids={},
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[],
|
||||
)
|
||||
|
||||
@@ -171,8 +180,9 @@ def test_update_states_request_resumed(model_runner):
|
||||
free_encoder_input_ids=[],
|
||||
)
|
||||
|
||||
batch_changed = model_runner._update_states(scheduler_output)
|
||||
assert batch_changed is True
|
||||
metadata_before = model_runner.input_batch.sampling_metadata
|
||||
model_runner._update_states(scheduler_output)
|
||||
assert _is_sampling_metadata_changed(model_runner, metadata_before)
|
||||
assert _is_req_added(model_runner, req_id)
|
||||
assert _is_req_scheduled(model_runner, req_id)
|
||||
|
||||
@@ -200,8 +210,9 @@ def test_update_states_no_changes(model_runner):
|
||||
free_encoder_input_ids=[],
|
||||
)
|
||||
|
||||
batch_changed = model_runner._update_states(scheduler_output)
|
||||
assert batch_changed is False
|
||||
metadata_before = model_runner.input_batch.sampling_metadata
|
||||
model_runner._update_states(scheduler_output)
|
||||
assert not _is_sampling_metadata_changed(model_runner, metadata_before)
|
||||
assert _is_req_added(model_runner, req_id)
|
||||
assert _is_req_scheduled(model_runner, req_id)
|
||||
|
||||
@@ -233,8 +244,8 @@ def test_update_states_request_unscheduled(model_runner):
|
||||
free_encoder_input_ids=[],
|
||||
)
|
||||
|
||||
batch_changed = model_runner._update_states(scheduler_output)
|
||||
assert batch_changed is True
|
||||
metadata_before = model_runner._update_states(scheduler_output)
|
||||
assert _is_sampling_metadata_changed(model_runner, metadata_before)
|
||||
|
||||
assert _is_req_added(model_runner, req_ids[0])
|
||||
assert _is_req_scheduled(model_runner, req_ids[0])
|
||||
|
||||
Reference in New Issue
Block a user