[Metrics] Refactor LoRA state tracking (#26801)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
@@ -15,12 +15,19 @@ from tests.v1.engine.utils import (
|
||||
)
|
||||
from vllm import PoolingParams
|
||||
from vllm.logprobs import PromptLogprobs, SampleLogprobs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine import (
|
||||
EngineCoreEvent,
|
||||
EngineCoreEventType,
|
||||
EngineCoreOutputs,
|
||||
EngineCoreRequest,
|
||||
FinishReason,
|
||||
)
|
||||
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
|
||||
from vllm.v1.metrics.stats import IterationStats
|
||||
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
|
||||
|
||||
|
||||
def _ref_convert_id_to_token(
|
||||
@@ -895,6 +902,170 @@ def test_iteration_stats(dummy_test_vectors):
|
||||
assert iteration_stats.num_generation_tokens == num_active
|
||||
|
||||
|
||||
@pytest.mark.parametrize("log_stats", [True, False])
def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
    """Check LoRA adapter bookkeeping as requests move through their lifecycle.

    Requests built from the fixture prompts are tagged with ``lora-1``,
    ``lora-2`` and no adapter respectively, pushed through QUEUED and
    SCHEDULED events, and then finished one at a time. With ``log_stats``
    enabled, the waiting/running adapter counts written into the scheduler
    stats and the processor's ``lora_states.requests`` are asserted after
    each step; with it disabled, ``lora_states.requests`` must stay empty.
    """
    output_processor = OutputProcessor(
        dummy_test_vectors.tokenizer, log_stats=log_stats
    )
    engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
    now = time.monotonic()

    def run_step(event_type=None, finish_request_id=None):
        """Fetch one batch from the mock core, decorate it, and process it.

        ``event_type`` (if given) attaches a fresh event of that type to every
        output; ``finish_request_id`` (if given) marks the first output with
        that request id as finished with ``FinishReason.LENGTH``. Returns the
        batch's scheduler stats and the iteration stats used for the step.
        """
        core_outputs = engine_core.get_outputs()
        batch = EngineCoreOutputs(
            outputs=core_outputs, scheduler_stats=SchedulerStats()
        )

        if event_type is not None:
            for item in batch.outputs:
                item.events = [EngineCoreEvent.new_event(event_type, now)]

        if finish_request_id is not None:
            target = next(
                (
                    item
                    for item in batch.outputs
                    if item.request_id == finish_request_id
                ),
                None,
            )
            if target is not None:
                target.finish_reason = FinishReason.LENGTH

        step_stats = IterationStats() if log_stats else None
        output_processor.process_outputs(batch.outputs, now, step_stats)
        output_processor.update_scheduler_stats(batch.scheduler_stats)
        return batch.scheduler_stats, step_stats

    # Two distinct adapters; the third slot deliberately has no LoRA.
    first_adapter = LoRARequest(
        lora_name="lora-1", lora_int_id=1, lora_path="/path/to/lora1"
    )
    second_adapter = LoRARequest(
        lora_name="lora-2", lora_int_id=2, lora_path="/path/to/lora2"
    )
    adapter_by_index = [first_adapter, second_adapter, None]
    adapter_names = ("lora-1", "lora-2")

    # request-0 -> lora-1, request-1 -> lora-2, request-2 -> no adapter.
    engine_requests = []
    for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens):
        engine_requests.append(
            EngineCoreRequest(
                request_id=f"request-{idx}",
                prompt_token_ids=prompt_tokens,
                mm_features=None,
                eos_token_id=None,
                arrival_time=0,
                lora_request=adapter_by_index[idx],
                cache_salt=None,
                data_parallel_rank=None,
                sampling_params=SamplingParams(),
                pooling_params=None,
            )
        )

    # Register every request with the OutputProcessor before stepping.
    for engine_request in engine_requests:
        output_processor.add_request(engine_request, None)

    # Step 1: all outputs carry a QUEUED event -> adapters are waiting.
    scheduler_stats, iteration_stats = run_step(
        event_type=EngineCoreEventType.QUEUED
    )
    if not log_stats:
        # Without stats logging nothing may be tracked.
        assert iteration_stats is None
        assert len(output_processor.lora_states.requests) == 0
    else:
        for adapter_name in adapter_names:
            assert scheduler_stats.waiting_lora_adapters.get(adapter_name) == 1
        for adapter_name in adapter_names:
            assert scheduler_stats.running_lora_adapters.get(adapter_name) == 0
        # Internal state holds exactly the two adapters.
        assert len(output_processor.lora_states.requests) == 2
        for adapter_name in adapter_names:
            assert adapter_name in output_processor.lora_states.requests

    # Step 2: all outputs carry a SCHEDULED event -> adapters are running.
    scheduler_stats, iteration_stats = run_step(
        event_type=EngineCoreEventType.SCHEDULED
    )
    if not log_stats:
        assert iteration_stats is None
        assert len(output_processor.lora_states.requests) == 0
    else:
        for adapter_name in adapter_names:
            assert scheduler_stats.waiting_lora_adapters.get(adapter_name) == 0
        for adapter_name in adapter_names:
            assert scheduler_stats.running_lora_adapters.get(adapter_name) == 1

    # Step 3: request-0 (lora-1) finishes.
    scheduler_stats, iteration_stats = run_step(finish_request_id="request-0")
    if not log_stats:
        assert len(output_processor.lora_states.requests) == 0
    else:
        # lora-1 has no remaining requests, so its entry must be gone ...
        assert "lora-1" not in output_processor.lora_states.requests
        # ... while lora-2 is still reported as running.
        assert scheduler_stats.running_lora_adapters.get("lora-2") == 1
        assert len(output_processor.lora_states.requests) == 1

    # Step 4: request-1 (lora-2) finishes.
    scheduler_stats, iteration_stats = run_step(finish_request_id="request-1")
    if not log_stats:
        assert len(output_processor.lora_states.requests) == 0
    else:
        # lora-2 has no remaining requests either, leaving nothing tracked.
        assert "lora-2" not in output_processor.lora_states.requests
        assert len(scheduler_stats.running_lora_adapters) == 0
        assert len(output_processor.lora_states.requests) == 0

    # Step 5: the adapter-less request-2 finishes.
    run_step(finish_request_id="request-2")

    # Nothing should be left in flight.
    assert output_processor.get_num_unfinished_requests() == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_output_collector():
|
||||
NUM_REQS = 3
|
||||
|
||||
Reference in New Issue
Block a user