[Core] Cleanup engine pause/sleep logic (#34528)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -3,8 +3,10 @@
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from contextlib import ExitStack
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -187,24 +189,33 @@ async def test_load(
|
||||
# =============================================================================
|
||||
# DP Pause/Resume Tests
|
||||
# =============================================================================
|
||||
# When expert_parallel=False: uses non-MoE model (DP replicas as separate engines).
|
||||
# When expert_parallel=True: uses MoE model + EP (DPEngineCoreProc, sync pause path).
|
||||
|
||||
DP_PAUSE_MODEL = "hmellor/tiny-random-LlamaForCausalLM"
|
||||
DP_PAUSE_MODEL_MOE = "ibm-research/PowerMoE-3b"
|
||||
DP_PAUSE_PROMPT = "This is a test of data parallel pause"
|
||||
|
||||
|
||||
def _get_dp_pause_engine_args(expert_parallel: bool) -> AsyncEngineArgs:
    """Build the engine arguments shared by the DP pause tests.

    Args:
        expert_parallel: When true, the MoE checkpoint is selected and expert
            parallelism is enabled; otherwise the small Llama checkpoint is
            used with expert parallelism disabled.

    Returns:
        ``AsyncEngineArgs`` using eager mode, the ``mp`` data-parallel backend,
        ``DP_SIZE`` data-parallel replicas, and a tensor-parallel size taken
        from the ``TP_SIZE`` environment variable (1 when unset).
    """
    if expert_parallel:
        checkpoint = DP_PAUSE_MODEL_MOE
    else:
        checkpoint = DP_PAUSE_MODEL

    # TP_SIZE lets the same tests be run with different tensor-parallel sizes.
    tp_size = int(os.getenv("TP_SIZE", 1))

    return AsyncEngineArgs(
        model=checkpoint,
        enforce_eager=True,
        tensor_parallel_size=tp_size,
        data_parallel_size=DP_SIZE,
        data_parallel_backend="mp",
        enable_expert_parallel=expert_parallel,
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dp_pause_resume_basic():
|
||||
@pytest.mark.parametrize("expert_parallel", [False, True])
|
||||
async def test_dp_pause_resume_basic(expert_parallel: bool):
|
||||
"""Pausing from the client (one call) pauses all DP ranks; resume clears it."""
|
||||
if current_platform.is_rocm():
|
||||
pytest.skip("DP pause tests use mp backend only")
|
||||
with ExitStack() as after:
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=DP_PAUSE_MODEL,
|
||||
enforce_eager=True,
|
||||
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
|
||||
data_parallel_size=DP_SIZE,
|
||||
data_parallel_backend="mp",
|
||||
)
|
||||
engine_args = _get_dp_pause_engine_args(expert_parallel)
|
||||
engine = AsyncLLM.from_engine_args(engine_args)
|
||||
after.callback(engine.shutdown)
|
||||
|
||||
@@ -226,18 +237,11 @@ async def test_dp_pause_resume_basic():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dp_pause_abort():
|
||||
@pytest.mark.parametrize("expert_parallel", [False, True])
|
||||
async def test_dp_pause_abort(expert_parallel: bool):
|
||||
"""Pause with abort from one client aborts in-flight requests on all DP ranks."""
|
||||
if current_platform.is_rocm():
|
||||
pytest.skip("DP pause tests use mp backend only")
|
||||
with ExitStack() as after:
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=DP_PAUSE_MODEL,
|
||||
enforce_eager=True,
|
||||
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
|
||||
data_parallel_size=DP_SIZE,
|
||||
data_parallel_backend="mp",
|
||||
)
|
||||
engine_args = _get_dp_pause_engine_args(expert_parallel)
|
||||
engine = AsyncLLM.from_engine_args(engine_args)
|
||||
after.callback(engine.shutdown)
|
||||
|
||||
@@ -286,41 +290,111 @@ async def test_dp_pause_abort():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dp_pause_keep_then_resume():
|
||||
"""Pause with keep queues new requests; resume allows them to run."""
|
||||
if current_platform.is_rocm():
|
||||
pytest.skip("DP pause tests use mp backend only")
|
||||
@pytest.mark.parametrize("expert_parallel", [False, True])
|
||||
async def test_dp_pause_keep_then_resume(expert_parallel: bool):
|
||||
"""Start generation, pause after a few tokens (keep mode), resume; verify gap."""
|
||||
|
||||
pause_duration = 2.0
|
||||
min_tokens_before_pause = 3
|
||||
|
||||
with ExitStack() as after:
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=DP_PAUSE_MODEL,
|
||||
enforce_eager=True,
|
||||
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
|
||||
data_parallel_size=DP_SIZE,
|
||||
data_parallel_backend="mp",
|
||||
)
|
||||
engine_args = _get_dp_pause_engine_args(expert_parallel)
|
||||
engine = AsyncLLM.from_engine_args(engine_args)
|
||||
after.callback(engine.shutdown)
|
||||
|
||||
await engine.pause_generation(mode="keep")
|
||||
assert await engine.is_paused()
|
||||
sampling_params = SamplingParams(max_tokens=15, ignore_eos=True)
|
||||
token_times: list[tuple[int, float]] = []
|
||||
pause_token_idx = 0
|
||||
|
||||
request_done = asyncio.Event()
|
||||
|
||||
async def gen():
|
||||
async for out in engine.generate(
|
||||
request_id="queued-keep",
|
||||
async def generator_task():
|
||||
nonlocal pause_token_idx
|
||||
out = None
|
||||
async for output in engine.generate(
|
||||
request_id="keep-resume-req",
|
||||
prompt=DP_PAUSE_PROMPT,
|
||||
sampling_params=SamplingParams(max_tokens=5),
|
||||
sampling_params=sampling_params,
|
||||
):
|
||||
pass
|
||||
request_done.set()
|
||||
token_count = len(output.outputs[0].token_ids)
|
||||
token_times.append((token_count, time.monotonic()))
|
||||
out = output
|
||||
return out
|
||||
|
||||
task = asyncio.create_task(gen())
|
||||
await asyncio.sleep(0.2)
|
||||
assert not request_done.is_set()
|
||||
async def controller_task():
|
||||
nonlocal pause_token_idx
|
||||
while len(token_times) < min_tokens_before_pause:
|
||||
await asyncio.sleep(0.01)
|
||||
await engine.pause_generation(mode="keep")
|
||||
await asyncio.sleep(pause_duration)
|
||||
pause_token_idx = len(token_times)
|
||||
await engine.resume_generation()
|
||||
|
||||
gen_task = asyncio.create_task(generator_task())
|
||||
ctrl_task = asyncio.create_task(controller_task())
|
||||
final_output, _ = await asyncio.gather(gen_task, ctrl_task)
|
||||
|
||||
assert final_output is not None and final_output.finished
|
||||
assert await engine.is_paused() is False
|
||||
assert pause_token_idx >= min_tokens_before_pause
|
||||
if pause_token_idx > 0 and pause_token_idx < len(token_times):
|
||||
pause_gap = (
|
||||
token_times[pause_token_idx][1] - token_times[pause_token_idx - 1][1]
|
||||
)
|
||||
assert pause_gap >= pause_duration * 0.8, (
|
||||
f"Expected gap ~{pause_duration}s after pause, got {pause_gap:.3f}s"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dp_pause_keep_race_staggered_engines():
|
||||
"""Race: send pause(keep) to engine 0, then add two requests,
|
||||
then pause(keep) to engine 1. Ensures no deadlock when pause
|
||||
requests are staggered and requests arrive in between."""
|
||||
if DP_SIZE != 2:
|
||||
pytest.skip("test_dp_pause_keep_race_staggered_engines requires DP_SIZE=2")
|
||||
|
||||
with ExitStack() as after:
|
||||
engine_args = _get_dp_pause_engine_args(expert_parallel=True)
|
||||
engine = AsyncLLM.from_engine_args(engine_args)
|
||||
after.callback(engine.shutdown)
|
||||
|
||||
client = engine.engine_core
|
||||
|
||||
original_call_utility = client.call_utility_async
|
||||
mid_pause_tasks: list[asyncio.Task] = []
|
||||
|
||||
async def staggered_pause_keep(method: str, *args) -> Any:
|
||||
if method != "pause_scheduler" or not args or args[0] != "keep":
|
||||
return await original_call_utility(method, *args)
|
||||
# Send pause(keep) to engine 0 first
|
||||
await client._call_utility_async(
|
||||
method, *args, engine=client.core_engines[0]
|
||||
)
|
||||
# In the middle: send two requests (race window)
|
||||
sp = SamplingParams(max_tokens=5, ignore_eos=True)
|
||||
|
||||
async def consume_gen(req_id: str) -> None:
|
||||
async for _ in engine.generate(
|
||||
request_id=req_id,
|
||||
prompt=DP_PAUSE_PROMPT,
|
||||
sampling_params=sp,
|
||||
):
|
||||
pass
|
||||
|
||||
t1 = asyncio.create_task(consume_gen("race-1"))
|
||||
t2 = asyncio.create_task(consume_gen("race-2"))
|
||||
mid_pause_tasks.extend([t1, t2])
|
||||
await asyncio.sleep(3)
|
||||
# Then send pause(keep) to engine 1
|
||||
result = await client._call_utility_async(
|
||||
method, *args, engine=client.core_engines[1]
|
||||
)
|
||||
return result
|
||||
|
||||
client.call_utility_async = staggered_pause_keep
|
||||
|
||||
await engine.pause_generation(mode="keep")
|
||||
assert await engine.is_paused()
|
||||
await engine.resume_generation()
|
||||
final = await asyncio.wait_for(task, timeout=10.0)
|
||||
assert final.finished
|
||||
assert not await engine.is_paused()
|
||||
# Let the two requests we sent mid-pause complete
|
||||
await asyncio.gather(*mid_pause_tasks)
|
||||
|
||||
@@ -280,20 +280,15 @@ def echo_dc_nested(
|
||||
|
||||
|
||||
def future_echo(self, value: Any, num_wait_loops: int = 2) -> Future:
|
||||
"""Utility that returns a Future completed by a per_step_hook after
|
||||
num_wait_loops engine steps (tests deferred utility path).
|
||||
"""Utility that returns a Future completed once the engine is idle
|
||||
(tests deferred utility path).
|
||||
"""
|
||||
future: Future = Future()
|
||||
remaining = [num_wait_loops]
|
||||
|
||||
def _step(engine: EngineCore) -> bool:
|
||||
remaining[0] -= 1
|
||||
if remaining[0] <= 0:
|
||||
future.set_result(value)
|
||||
return True # remove hook
|
||||
return False
|
||||
def idle(engine: EngineCore):
|
||||
future.set_result(value)
|
||||
|
||||
self.per_step_hooks.add(_step)
|
||||
self._idle_state_callbacks.append(idle)
|
||||
return future
|
||||
|
||||
|
||||
@@ -832,8 +827,8 @@ async def test_engine_core_client_future_utility_async(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
subprocess_future_echo_patch,
|
||||
):
|
||||
"""Test that a utility returning a Future (completed by a per_step_hook
|
||||
after N steps) completes when the future is done (engine uses add_done_callback).
|
||||
"""Test that a utility returning a Future completes when the future is done
|
||||
(engine uses add_done_callback).
|
||||
"""
|
||||
with monkeypatch.context() as m:
|
||||
m.setattr(EngineCore, "future_echo", future_echo, raising=False)
|
||||
|
||||
@@ -148,7 +148,7 @@ class EngineClient(ABC):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def sleep(self, level: int = 1) -> None:
|
||||
async def sleep(self, level: int = 1, mode: "PauseMode" = "abort") -> None:
|
||||
"""Sleep the engine"""
|
||||
...
|
||||
|
||||
|
||||
@@ -87,6 +87,7 @@ from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils.counter import Counter
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
from vllm.utils.tqdm_utils import maybe_tqdm
|
||||
from vllm.v1.engine import PauseMode
|
||||
from vllm.v1.engine.llm_engine import LLMEngine
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessor
|
||||
|
||||
@@ -441,8 +442,7 @@ class LLM:
|
||||
A list of `RequestOutput` objects containing the
|
||||
generated completions in the same order as the input prompts.
|
||||
"""
|
||||
model_config = self.model_config
|
||||
runner_type = model_config.runner_type
|
||||
runner_type = self.model_config.runner_type
|
||||
if runner_type != "generate":
|
||||
raise ValueError(
|
||||
"LLM.generate() is only supported for generative models. "
|
||||
@@ -489,46 +489,22 @@ class LLM:
|
||||
Returns:
|
||||
A list of request IDs for the enqueued requests.
|
||||
"""
|
||||
model_config = self.model_config
|
||||
runner_type = model_config.runner_type
|
||||
runner_type = self.model_config.runner_type
|
||||
if runner_type != "generate":
|
||||
raise ValueError("LLM.enqueue() is only supported for generative models.")
|
||||
|
||||
if sampling_params is None:
|
||||
sampling_params = self.get_default_sampling_params()
|
||||
|
||||
# Use the same preprocessing as _run_completion
|
||||
seq_prompts = prompt_to_seq(prompts)
|
||||
seq_params = self._params_to_seq(sampling_params, len(seq_prompts))
|
||||
seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_prompts))
|
||||
seq_tok_kwargs = [
|
||||
merge_kwargs(
|
||||
tokenization_kwargs,
|
||||
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
|
||||
)
|
||||
for param in seq_params
|
||||
]
|
||||
seq_priority = self._priority_to_seq(priority, len(prompts))
|
||||
|
||||
request_ids = self._render_and_add_requests(
|
||||
prompts=(
|
||||
self._preprocess_cmpl_one(prompt, tok_kwargs)
|
||||
for prompt, tok_kwargs in zip(
|
||||
maybe_tqdm(
|
||||
seq_prompts,
|
||||
use_tqdm=use_tqdm,
|
||||
desc="Rendering prompts",
|
||||
),
|
||||
seq_tok_kwargs,
|
||||
)
|
||||
),
|
||||
params=seq_params,
|
||||
lora_requests=seq_lora_requests,
|
||||
priorities=seq_priority,
|
||||
return self._add_completion_requests(
|
||||
prompts=prompts,
|
||||
params=sampling_params,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
priority=priority,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
return request_ids
|
||||
|
||||
@overload
|
||||
def wait_for_completion(
|
||||
self,
|
||||
@@ -1659,7 +1635,7 @@ class LLM:
|
||||
reset_running_requests, reset_connector
|
||||
)
|
||||
|
||||
def sleep(self, level: int = 1):
|
||||
def sleep(self, level: int = 1, mode: PauseMode = "abort"):
|
||||
"""
|
||||
Put the engine to sleep. The engine should not process any requests.
|
||||
The caller should guarantee that no requests are being processed
|
||||
@@ -1679,10 +1655,10 @@ class LLM:
|
||||
a different model or update the model, where
|
||||
previous model weights are not needed. It reduces
|
||||
CPU memory pressure.
|
||||
mode: How to handle any existing requests, can be "abort", "wait",
|
||||
or "keep".
|
||||
"""
|
||||
if level > 0:
|
||||
self.reset_prefix_cache()
|
||||
self.llm_engine.sleep(level=level)
|
||||
self.llm_engine.sleep(level=level, mode=mode)
|
||||
|
||||
def wake_up(self, tags: list[str] | None = None):
|
||||
"""
|
||||
@@ -1759,6 +1735,45 @@ class LLM:
|
||||
|
||||
return [0] * num_requests
|
||||
|
||||
def _add_completion_requests(
    self,
    prompts: PromptType | Sequence[PromptType],
    params: SamplingParams
    | PoolingParams
    | Sequence[SamplingParams | PoolingParams],
    *,
    use_tqdm: bool | Callable[..., tqdm] = True,
    lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
    priority: list[int] | None = None,
    tokenization_kwargs: dict[str, Any] | None = None,
) -> list[str]:
    """Expand completion inputs to per-prompt sequences and add the requests.

    Args:
        prompts: A single prompt or a sequence of prompts.
        params: One params object, or one per prompt.
        use_tqdm: Forwarded to ``maybe_tqdm`` for the "Rendering prompts"
            progress bar.
        lora_request: One LoRA request, one per prompt, or ``None``.
        priority: Optional per-prompt priorities.
        tokenization_kwargs: Base tokenization kwargs; each prompt's
            ``truncate_prompt_tokens`` (taken from its params) is merged in.

    Returns:
        Whatever ``_render_and_add_requests`` returns for the added requests
        (annotated as a list of request ID strings).
    """
    seq_prompts = prompt_to_seq(prompts)
    num_prompts = len(seq_prompts)

    seq_params = self._params_to_seq(params, num_prompts)
    seq_lora_requests = self._lora_request_to_seq(lora_request, num_prompts)
    seq_tok_kwargs = [
        merge_kwargs(
            tokenization_kwargs,
            dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
        )
        for param in seq_params
    ]
    # Size priorities by the normalized prompt count, like params and LoRA
    # requests above. `prompts` may be a single non-sequence prompt (see its
    # annotation), in which case len(prompts) is not the number of requests.
    seq_priority = self._priority_to_seq(priority, num_prompts)

    # Prompts are rendered lazily, one at a time, as the generator is consumed.
    rendered_prompts = (
        self._preprocess_cmpl_one(prompt, tok_kwargs)
        for prompt, tok_kwargs in zip(
            maybe_tqdm(seq_prompts, use_tqdm=use_tqdm, desc="Rendering prompts"),
            seq_tok_kwargs,
        )
    )

    return self._render_and_add_requests(
        prompts=rendered_prompts,
        params=seq_params,
        lora_requests=seq_lora_requests,
        priorities=seq_priority,
    )
|
||||
|
||||
def _run_completion(
|
||||
self,
|
||||
prompts: PromptType | Sequence[PromptType],
|
||||
@@ -1772,36 +1787,15 @@ class LLM:
|
||||
priority: list[int] | None = None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
):
|
||||
seq_prompts = prompt_to_seq(prompts)
|
||||
seq_params = self._params_to_seq(params, len(seq_prompts))
|
||||
seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_prompts))
|
||||
seq_tok_kwargs = [
|
||||
merge_kwargs(
|
||||
tokenization_kwargs,
|
||||
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
|
||||
)
|
||||
for param in seq_params
|
||||
]
|
||||
seq_priority = self._priority_to_seq(priority, len(prompts))
|
||||
|
||||
return self._render_and_run_requests(
|
||||
prompts=(
|
||||
self._preprocess_cmpl_one(prompt, tok_kwargs)
|
||||
for prompt, tok_kwargs in zip(
|
||||
maybe_tqdm(
|
||||
seq_prompts,
|
||||
use_tqdm=use_tqdm,
|
||||
desc="Rendering prompts",
|
||||
),
|
||||
seq_tok_kwargs,
|
||||
)
|
||||
),
|
||||
params=seq_params,
|
||||
output_type=output_type,
|
||||
self._add_completion_requests(
|
||||
prompts=prompts,
|
||||
params=params,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_requests=seq_lora_requests,
|
||||
priorities=seq_priority,
|
||||
lora_request=lora_request,
|
||||
priority=priority,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
return self._run_engine(use_tqdm=use_tqdm, output_type=output_type)
|
||||
|
||||
def _run_chat(
|
||||
self,
|
||||
|
||||
@@ -23,7 +23,8 @@ router = APIRouter()
|
||||
async def sleep(raw_request: Request):
|
||||
# get POST params
|
||||
level = raw_request.query_params.get("level", "1")
|
||||
await engine_client(raw_request).sleep(int(level))
|
||||
mode = raw_request.query_params.get("mode", "abort")
|
||||
await engine_client(raw_request).sleep(int(level), mode)
|
||||
# FIXME: in v0 with frontend multiprocessing, the sleep command
|
||||
# is sent but does not finish yet when we return a response.
|
||||
return Response(status_code=200)
|
||||
|
||||
@@ -753,6 +753,13 @@ class AsyncLLM(EngineClient):
|
||||
)
|
||||
mode = "wait"
|
||||
await self.engine_core.pause_scheduler_async(mode=mode, clear_cache=clear_cache)
|
||||
# Small sleep to help ensure that final outputs from any in-flight requests are
|
||||
# returned prior to this method returning. These outputs come out of the engine
|
||||
# prior to the wait-for-idle completion event, but involve additional async
|
||||
# tasks in output processing.
|
||||
# Note that this is not required for correctness, just more intuitive ordering
|
||||
# of events from caller's pov.
|
||||
await asyncio.sleep(0.02)
|
||||
|
||||
async def resume_generation(self) -> None:
|
||||
"""Resume generation after :meth:`pause_generation`."""
|
||||
@@ -890,10 +897,8 @@ class AsyncLLM(EngineClient):
|
||||
async def reset_encoder_cache(self) -> None:
|
||||
await self.engine_core.reset_encoder_cache_async()
|
||||
|
||||
async def sleep(self, level: int = 1) -> None:
|
||||
if level > 0:
|
||||
await self.reset_prefix_cache()
|
||||
await self.engine_core.sleep_async(level)
|
||||
async def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None:
|
||||
await self.engine_core.sleep_async(level, mode)
|
||||
|
||||
if self.logger_manager is not None:
|
||||
self.logger_manager.record_sleep_state(1, level)
|
||||
|
||||
@@ -9,6 +9,7 @@ from collections import defaultdict, deque
|
||||
from collections.abc import Callable, Generator
|
||||
from concurrent.futures import Future
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from functools import partial
|
||||
from inspect import isclass, signature
|
||||
from logging import DEBUG
|
||||
from typing import Any, TypeVar, cast
|
||||
@@ -211,7 +212,7 @@ class EngineCore:
|
||||
|
||||
self.aborts_queue = queue.Queue[list[str]]()
|
||||
|
||||
self.per_step_hooks: set[Callable] = set()
|
||||
self._idle_state_callbacks: list[Callable] = []
|
||||
|
||||
# Mark the startup heap as static so that it's ignored by GC.
|
||||
# Reduces pause times of oldest generation collections.
|
||||
@@ -592,21 +593,51 @@ class EngineCore:
|
||||
# Reset the GPU model runner's encoder cache (physical storage)
|
||||
self.model_executor.reset_encoder_cache()
|
||||
|
||||
def _reset_caches(self, reset_running_requests=True) -> None:
    """Reset the prefix cache, then the multimodal and encoder caches.

    Args:
        reset_running_requests: Forwarded unchanged to ``reset_prefix_cache``.
    """
    self.reset_prefix_cache(reset_running_requests=reset_running_requests)
    # The remaining two resets take no arguments; run them in order.
    for reset_name in ("reset_mm_cache", "reset_encoder_cache"):
        getattr(self, reset_name)()
|
||||
|
||||
def pause_scheduler(
|
||||
self, mode: PauseMode = "abort", clear_cache: bool = True
|
||||
) -> Future[Any] | None:
|
||||
"""Pause scheduling. No-op in base EngineCore; overridden in EngineCoreProc."""
|
||||
) -> Future | None:
|
||||
"""Pause generation; behavior depends on mode.
|
||||
|
||||
All pause modes queue new adds -- "abort" and "keep" skip step();
|
||||
"wait" allows step() so in-flight requests can drain.
|
||||
|
||||
- ``abort``: Set PAUSED_NEW, abort all requests, wait for abort
|
||||
outputs to be sent (when running with output_queue), optionally
|
||||
clear caches, then complete the returned Future.
|
||||
- ``wait``: Set PAUSED_NEW (queue adds, keep stepping); when drained,
|
||||
optionally clear caches, then complete the returned Future.
|
||||
- ``keep``: Set PAUSED_ALL; return a Future that completes when the
|
||||
output queue is empty.
|
||||
"""
|
||||
if mode not in ("keep", "abort", "wait"):
|
||||
raise ValueError(f"Invalid pause mode: {mode}")
|
||||
if mode == "wait":
|
||||
raise ValueError("'wait' mode can't be used in inproc-engine mode")
|
||||
|
||||
if mode == "abort":
|
||||
self.scheduler.finish_requests(None, RequestStatus.FINISHED_ABORTED)
|
||||
|
||||
pause_state = PauseState.PAUSED_ALL if mode == "keep" else PauseState.PAUSED_NEW
|
||||
self.scheduler.set_pause_state(pause_state)
|
||||
if clear_cache:
|
||||
self._reset_caches()
|
||||
|
||||
return None
|
||||
|
||||
def resume_scheduler(self) -> None:
|
||||
"""Resume scheduling. No-op in base EngineCore; overridden in EngineCoreProc."""
|
||||
"""Resume the scheduler and flush any requests queued while paused."""
|
||||
self.scheduler.set_pause_state(PauseState.UNPAUSED)
|
||||
|
||||
def is_scheduler_paused(self) -> bool:
|
||||
"""Return whether the scheduler is in any pause state. False in base EngineCore
|
||||
and overridden in EngineCoreProc."""
|
||||
return False
|
||||
"""Return whether the scheduler is in any pause state."""
|
||||
return self.scheduler.pause_state != PauseState.UNPAUSED
|
||||
|
||||
def sleep(self, level: int = 1):
|
||||
def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None | Future:
|
||||
"""Put the engine to sleep at the specified level.
|
||||
|
||||
Args:
|
||||
@@ -615,13 +646,34 @@ class EngineCore:
|
||||
but not processed. No GPU memory changes.
|
||||
- Level 1: Offload model weights to CPU, discard KV cache.
|
||||
- Level 2: Discard all GPU memory.
|
||||
mode: Pause mode - how to deal with any existing requests, see
|
||||
documentation of pause_scheduler method.
|
||||
"""
|
||||
if level == 0:
|
||||
# Level 0: Just pause scheduling, don't touch GPU
|
||||
self.pause_scheduler()
|
||||
else:
|
||||
# Level 1+: Delegate to executor for GPU memory management
|
||||
self.model_executor.sleep(level)
|
||||
|
||||
# Pause scheduler before sleeping.
|
||||
clear_prefix_cache = level >= 1
|
||||
pause_future = self.pause_scheduler(mode=mode, clear_cache=clear_prefix_cache)
|
||||
if level < 1:
|
||||
return pause_future
|
||||
|
||||
# Level 1+: Delegate to executor for GPU memory management
|
||||
model_executor = self.model_executor
|
||||
if pause_future is None:
|
||||
model_executor.sleep(level)
|
||||
return None
|
||||
|
||||
future = Future[Any]()
|
||||
|
||||
def pause_complete(f: Future):
|
||||
try:
|
||||
f.result() # propagate any exception
|
||||
future.set_result(model_executor.sleep(level))
|
||||
except Exception as e:
|
||||
future.set_exception(e)
|
||||
|
||||
logger.info("Waiting for in-flight requests to complete before sleeping...")
|
||||
pause_future.add_done_callback(pause_complete)
|
||||
return future
|
||||
|
||||
def wake_up(self, tags: list[str] | None = None):
|
||||
"""Wake up the engine from sleep.
|
||||
@@ -630,17 +682,15 @@ class EngineCore:
|
||||
tags: Tags to wake up. Use ["scheduling"] for level 0 wake up.
|
||||
"""
|
||||
if tags is not None and "scheduling" in tags:
|
||||
# Level 0 wake up: Resume scheduling
|
||||
self.resume_scheduler()
|
||||
# Remove "scheduling" from tags if there are other tags to process
|
||||
remaining_tags = [t for t in tags if t != "scheduling"]
|
||||
if remaining_tags:
|
||||
self.model_executor.wake_up(remaining_tags)
|
||||
else:
|
||||
# Full wake up
|
||||
self.resume_scheduler()
|
||||
# Remove "scheduling" from tags if there are other tags to process.
|
||||
tags = [t for t in tags if t != "scheduling"]
|
||||
|
||||
if tags is None or tags:
|
||||
self.model_executor.wake_up(tags)
|
||||
|
||||
# Resume scheduling (applies to all levels)
|
||||
self.resume_scheduler()
|
||||
|
||||
def is_sleeping(self) -> bool:
|
||||
"""Check if engine is sleeping at any level."""
|
||||
return self.is_scheduler_paused() or self.model_executor.is_sleeping
|
||||
@@ -1038,6 +1088,14 @@ class EngineCoreProc(EngineCore):
|
||||
def _init_data_parallel(self, vllm_config: VllmConfig):
|
||||
pass
|
||||
|
||||
def has_work(self) -> bool:
    """Report whether the engine should be stepped.

    The engine has work when ``engines_running`` is set, when the scheduler
    reports requests, or when the batch queue is non-empty. The checks are
    made in that order and stop at the first one that holds.
    """
    running = self.engines_running
    if running:
        return running

    pending = self.scheduler.has_requests()
    if pending:
        return pending

    # batch_queue may be None or an empty container; either means no work.
    return bool(self.batch_queue)
|
||||
|
||||
def run_busy_loop(self):
|
||||
"""Core busy loop of the EngineCore."""
|
||||
|
||||
@@ -1047,19 +1105,14 @@ class EngineCoreProc(EngineCore):
|
||||
self._process_input_queue()
|
||||
# 2) Step the engine core and return the outputs.
|
||||
self._process_engine_step()
|
||||
# 3) Run any per-step hooks.
|
||||
self._process_per_step_hooks()
|
||||
|
||||
def _process_input_queue(self):
|
||||
"""Exits when an engine step needs to be performed."""
|
||||
|
||||
waited = False
|
||||
while (
|
||||
not self.engines_running
|
||||
and not self.scheduler.has_requests()
|
||||
and not self.batch_queue
|
||||
and not self.per_step_hooks
|
||||
):
|
||||
while not self.has_work():
|
||||
# Notify callbacks waiting for engine to become idle.
|
||||
self._notify_idle_state_callbacks()
|
||||
if self.input_queue.empty():
|
||||
# Drain aborts queue; all aborts are also processed via input_queue.
|
||||
with self.aborts_queue.mutex:
|
||||
@@ -1098,12 +1151,10 @@ class EngineCoreProc(EngineCore):
|
||||
|
||||
return model_executed
|
||||
|
||||
def _process_per_step_hooks(self) -> None:
|
||||
if self.per_step_hooks:
|
||||
for hook in list(self.per_step_hooks):
|
||||
finished = hook(self)
|
||||
if finished:
|
||||
self.per_step_hooks.discard(hook)
|
||||
def _notify_idle_state_callbacks(self) -> None:
    """Run and discard every registered idle-state callback.

    Callbacks are popped from the end of ``_idle_state_callbacks``, so the
    most recently appended one runs first, and each is called with this
    engine as its only argument. The list is re-read on every iteration, so
    a callback appended while draining is also run before this returns.
    """
    while pending := self._idle_state_callbacks:
        pending.pop()(self)
|
||||
|
||||
def _handle_client_request(
|
||||
self, request_type: EngineCoreRequestType, request: Any
|
||||
@@ -1377,19 +1428,10 @@ class EngineCoreProc(EngineCore):
|
||||
if mode not in ("keep", "abort", "wait"):
|
||||
raise ValueError(f"Invalid pause mode: {mode}")
|
||||
|
||||
future: Future[Any] = Future()
|
||||
|
||||
def wait_until_idle(engine: "EngineCoreProc") -> bool:
|
||||
scheduler = engine.scheduler
|
||||
out_queue = engine.output_queue
|
||||
if scheduler.has_requests() or engine.batch_queue or not out_queue.empty():
|
||||
return False
|
||||
def engine_idle_callback(engine: "EngineCoreProc", future: Future[Any]) -> None:
|
||||
if clear_cache:
|
||||
engine.reset_prefix_cache(reset_running_requests=True)
|
||||
engine.reset_mm_cache()
|
||||
engine.reset_encoder_cache()
|
||||
engine._reset_caches()
|
||||
future.set_result(None)
|
||||
return True
|
||||
|
||||
if mode == "abort":
|
||||
aborted_reqs = self.scheduler.finish_requests(
|
||||
@@ -1399,12 +1441,17 @@ class EngineCoreProc(EngineCore):
|
||||
|
||||
pause_state = PauseState.PAUSED_ALL if mode == "keep" else PauseState.PAUSED_NEW
|
||||
self.scheduler.set_pause_state(pause_state)
|
||||
if not wait_until_idle(self):
|
||||
self.per_step_hooks.add(wait_until_idle)
|
||||
return future
|
||||
return None
|
||||
if not self.has_work():
|
||||
if clear_cache:
|
||||
self._reset_caches()
|
||||
return None
|
||||
|
||||
future = Future[Any]()
|
||||
self._idle_state_callbacks.append(partial(engine_idle_callback, future=future))
|
||||
return future
|
||||
|
||||
def _send_abort_outputs(self, aborted_reqs: list[tuple[str, int]]) -> None:
|
||||
# TODO(nick) this will be moved inside the scheduler
|
||||
if aborted_reqs:
|
||||
# Map client_index to list of request_ids that belong to that client.
|
||||
by_client = defaultdict[int, set[str]](set)
|
||||
@@ -1418,14 +1465,6 @@ class EngineCoreProc(EngineCore):
|
||||
eco = EngineCoreOutputs(finished_requests=req_ids, outputs=outputs)
|
||||
self.output_queue.put_nowait((client_index, eco))
|
||||
|
||||
def resume_scheduler(self) -> None:
|
||||
"""Resume the scheduler and flush any requests queued while paused."""
|
||||
self.scheduler.set_pause_state(PauseState.UNPAUSED)
|
||||
|
||||
def is_scheduler_paused(self) -> bool:
|
||||
"""Return whether the scheduler is in any pause state."""
|
||||
return self.scheduler.pause_state != PauseState.UNPAUSED
|
||||
|
||||
|
||||
class DPEngineCoreProc(EngineCoreProc):
|
||||
"""ZMQ-wrapper for running EngineCore in background process
|
||||
@@ -1481,6 +1520,7 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
stateless_destroy_torch_distributed_process_group(dp_group)
|
||||
|
||||
def add_request(self, request: Request, request_wave: int = 0):
|
||||
super().add_request(request, request_wave)
|
||||
if self.has_coordinator and request_wave != self.current_wave:
|
||||
if request_wave > self.current_wave:
|
||||
self.current_wave = request_wave
|
||||
@@ -1491,7 +1531,13 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
(-1, EngineCoreOutputs(start_wave=self.current_wave))
|
||||
)
|
||||
|
||||
super().add_request(request, request_wave)
|
||||
def resume_scheduler(self):
    """Resume scheduling and, if needed, wake the other DP engines.

    After the parent class has resumed the scheduler, a ``start_wave``
    notification for the current wave is put on the output queue. This
    happens only when the engines are not already running and this engine's
    scheduler still has unfinished requests.
    """
    super().resume_scheduler()

    if self.engines_running or not self.scheduler.has_unfinished_requests():
        return

    # Wake up other DP engines.
    # NOTE(review): the -1 client index mirrors the start_wave message sent
    # from add_request; presumably it addresses the coordinator -- confirm.
    wake_msg = EngineCoreOutputs(start_wave=self.current_wave)
    self.output_queue.put_nowait((-1, wake_msg))
|
||||
|
||||
def _handle_client_request(
|
||||
self, request_type: EngineCoreRequestType, request: Any
|
||||
@@ -1532,8 +1578,8 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
# 2) Step the engine core.
|
||||
executed = self._process_engine_step()
|
||||
self._maybe_publish_request_counts()
|
||||
local_unfinished_reqs = self.scheduler.has_unfinished_requests()
|
||||
|
||||
local_unfinished_reqs = self.scheduler.has_unfinished_requests()
|
||||
if not executed:
|
||||
if not local_unfinished_reqs and not self.engines_running:
|
||||
# All engines are idle.
|
||||
|
||||
@@ -150,7 +150,7 @@ class EngineCoreClient(ABC):
|
||||
def reset_encoder_cache(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def sleep(self, level: int = 1) -> None:
|
||||
def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def wake_up(self, tags: list[str] | None = None) -> None:
|
||||
@@ -227,7 +227,7 @@ class EngineCoreClient(ABC):
|
||||
async def reset_encoder_cache_async(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def sleep_async(self, level: int = 1) -> None:
|
||||
async def sleep_async(self, level: int = 1, mode: PauseMode = "abort") -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def wake_up_async(self, tags: list[str] | None = None) -> None:
|
||||
@@ -314,8 +314,11 @@ class InprocClient(EngineCoreClient):
|
||||
def reset_encoder_cache(self) -> None:
|
||||
self.engine_core.reset_encoder_cache()
|
||||
|
||||
def sleep(self, level: int = 1) -> None:
|
||||
self.engine_core.sleep(level)
|
||||
def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None:
|
||||
if mode == "wait":
|
||||
raise ValueError("'wait' pause mode is not supported in inproc-engine mode")
|
||||
result = self.engine_core.sleep(level, mode)
|
||||
assert result is None
|
||||
|
||||
def wake_up(self, tags: list[str] | None = None) -> None:
|
||||
self.engine_core.wake_up(tags)
|
||||
@@ -796,8 +799,8 @@ class SyncMPClient(MPClient):
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
return self.call_utility("pin_lora", lora_id)
|
||||
|
||||
def sleep(self, level: int = 1) -> None:
|
||||
self.call_utility("sleep", level)
|
||||
def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None:
|
||||
self.call_utility("sleep", level, mode)
|
||||
|
||||
def wake_up(self, tags: list[str] | None = None) -> None:
|
||||
self.call_utility("wake_up", tags)
|
||||
@@ -1009,8 +1012,8 @@ class AsyncMPClient(MPClient):
|
||||
async def reset_encoder_cache_async(self) -> None:
|
||||
await self.call_utility_async("reset_encoder_cache")
|
||||
|
||||
async def sleep_async(self, level: int = 1) -> None:
|
||||
await self.call_utility_async("sleep", level)
|
||||
async def sleep_async(self, level: int = 1, mode: PauseMode = "abort") -> None:
|
||||
await self.call_utility_async("sleep", level, mode)
|
||||
|
||||
async def wake_up_async(self, tags: list[str] | None = None) -> None:
|
||||
await self.call_utility_async("wake_up", tags)
|
||||
|
||||
@@ -28,7 +28,7 @@ from vllm.tasks import SupportedTask
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tracing import init_tracer
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine import EngineCoreRequest, PauseMode
|
||||
from vllm.v1.engine.core_client import EngineCoreClient
|
||||
from vllm.v1.engine.input_processor import InputProcessor
|
||||
from vllm.v1.engine.output_processor import OutputProcessor
|
||||
@@ -355,8 +355,8 @@ class LLMEngine:
|
||||
"""
|
||||
self.engine_core.reset_encoder_cache()
|
||||
|
||||
def sleep(self, level: int = 1):
|
||||
self.engine_core.sleep(level)
|
||||
def sleep(self, level: int = 1, mode: PauseMode = "abort"):
|
||||
self.engine_core.sleep(level, mode)
|
||||
|
||||
if self.logger_manager is not None:
|
||||
self.logger_manager.record_sleep_state(1, level)
|
||||
|
||||
@@ -429,8 +429,6 @@ class OutputProcessor:
|
||||
self.external_req_ids: defaultdict[str, list[str]] = defaultdict(list)
|
||||
self.lora_states = LoRARequestStates(log_stats)
|
||||
self.tracing_enabled = tracing_enabled
|
||||
self._requests_drained = asyncio.Event()
|
||||
self._requests_drained.set()
|
||||
|
||||
def get_num_unfinished_requests(self):
|
||||
return len(self.request_states)
|
||||
@@ -438,11 +436,6 @@ class OutputProcessor:
|
||||
def has_unfinished_requests(self) -> bool:
|
||||
return len(self.request_states) > 0
|
||||
|
||||
async def wait_for_requests_to_drain(self) -> None:
|
||||
if not self.request_states:
|
||||
return
|
||||
await self._requests_drained.wait()
|
||||
|
||||
def propagate_error(self, e: Exception):
|
||||
"""Propagate error to all generate() tasks."""
|
||||
|
||||
@@ -510,8 +503,6 @@ class OutputProcessor:
|
||||
child_reqs = self.abort_requests(child_reqs, internal=True)
|
||||
request_ids_to_abort.extend(child_reqs)
|
||||
self.parent_requests.pop(request_id, None)
|
||||
if not self.request_states:
|
||||
self._requests_drained.set()
|
||||
return request_ids_to_abort
|
||||
|
||||
def add_request(
|
||||
@@ -538,8 +529,6 @@ class OutputProcessor:
|
||||
log_stats=self.log_stats,
|
||||
stream_interval=self.stream_interval,
|
||||
)
|
||||
if self._requests_drained.is_set():
|
||||
self._requests_drained.clear()
|
||||
self.request_states[request_id] = req_state
|
||||
if parent_req:
|
||||
self.parent_requests[parent_req.request_id] = parent_req
|
||||
@@ -706,9 +695,6 @@ class OutputProcessor:
|
||||
if parent_req and not parent_req.child_requests:
|
||||
self.parent_requests.pop(parent_req.request_id, None)
|
||||
|
||||
if not self.request_states:
|
||||
self._requests_drained.set()
|
||||
|
||||
def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
|
||||
self.lora_states.update_scheduler_stats(scheduler_stats)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user