[Core] Add sleep level 0 mode with enqueue/wait pattern (#33195)

Signed-off-by: Jaewon Lee <jaewon@meta.com>
Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>
This commit is contained in:
Jaewon
2026-02-12 16:16:25 -08:00
committed by GitHub
parent be7370daf3
commit aa181c923b
5 changed files with 167 additions and 26 deletions

View File

@@ -458,6 +458,93 @@ class LLM:
return self.engine_class.validate_outputs(outputs, RequestOutput) return self.engine_class.validate_outputs(outputs, RequestOutput)
def enqueue(
    self,
    prompts: PromptType | Sequence[PromptType],
    sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
    lora_request: list[LoRARequest] | LoRARequest | None = None,
    priority: list[int] | None = None,
    use_tqdm: bool | Callable[..., tqdm] = True,
    tokenization_kwargs: dict[str, Any] | None = None,
) -> list[str]:
    """Queue prompts for generation and return without waiting for results.

    The requests are added to the engine queue but are not processed by
    this call. Call `wait_for_completion()` afterwards to process the
    queued requests and collect their outputs.

    Args:
        prompts: The prompts to the LLM. See generate() for details.
        sampling_params: The sampling parameters for text generation.
            When None, the model's default sampling parameters are used.
        lora_request: LoRA request to use for generation, if any.
        priority: The priority of the requests, if any.
        use_tqdm: If True, shows a tqdm progress bar while adding requests.
        tokenization_kwargs: Overrides for `tokenizer.encode`.

    Returns:
        A list of request IDs for the enqueued requests.

    Raises:
        ValueError: If the model's runner type is not "generate".
    """
    if self.model_config.runner_type != "generate":
        raise ValueError("LLM.enqueue() is only supported for generative models.")

    if sampling_params is None:
        sampling_params = self.get_default_sampling_params()

    # Same preprocessing as _run_completion.
    seq_prompts = prompt_to_seq(prompts)
    seq_params = self._params_to_seq(sampling_params, len(seq_prompts))

    needs_truncation = any(
        param.truncate_prompt_tokens is not None for param in seq_params
    )

    engine_prompts: Sequence[DictPrompt | TokPrompt]
    if not needs_truncation:
        # No per-request truncation: preprocess the whole batch in one call.
        engine_prompts = self._preprocess_completion(
            seq_prompts,
            tokenization_kwargs=tokenization_kwargs,
        )
    else:
        # Truncation is a per-request setting, so each prompt is preprocessed
        # on its own with its truncation limit merged into the kwargs.
        collected = []
        for prompt, param in zip(seq_prompts, seq_params):
            per_prompt_kwargs = merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
            )
            collected.extend(
                self._preprocess_completion(
                    [prompt],
                    tokenization_kwargs=per_prompt_kwargs,
                )
            )
        engine_prompts = collected

    lora_reqs = self._get_modality_specific_lora_reqs(engine_prompts, lora_request)
    return self._validate_and_add_requests(
        prompts=engine_prompts,
        params=seq_params,
        use_tqdm=use_tqdm,
        lora_request=lora_reqs,
        tokenization_kwargs=tokenization_kwargs,
        priority=priority,
    )
def wait_for_completion(
    self,
    use_tqdm: bool | Callable[..., tqdm] = True,
) -> list[RequestOutput]:
    """Process every queued request and return the finished outputs.

    Intended to be called after `enqueue()`: it runs the engine over the
    requests currently in the queue and returns their results.

    Args:
        use_tqdm: If True, shows a tqdm progress bar.

    Returns:
        A list of RequestOutput objects for all completed requests.
    """
    finished = self._run_engine(use_tqdm=use_tqdm)
    validate = self.engine_class.validate_outputs
    return validate(finished, RequestOutput)
def _get_modality_specific_lora_reqs( def _get_modality_specific_lora_reqs(
self, self,
prompts: Sequence[DictPrompt | TokPrompt], prompts: Sequence[DictPrompt | TokPrompt],
@@ -1618,18 +1705,21 @@ class LLM:
during the sleep period, before `wake_up` is called. during the sleep period, before `wake_up` is called.
Args: Args:
level: The sleep level. Level 1 sleep will offload the model level: The sleep level.
weights and discard the kv cache. The content of kv cache - Level 0: Pause scheduling but continue accepting requests.
is forgotten. Level 1 sleep is good for sleeping and waking Requests are queued but not processed.
up the engine to run the same model again. The model weights - Level 1: Offload model weights to CPU, discard KV cache.
are backed up in CPU memory. Please make sure there's enough The content of kv cache is forgotten. Good for
CPU memory to store the model weights. Level 2 sleep will sleeping and waking up the engine to run the same
discard both the model weights and the kv cache. The content model again. Please make sure there's enough CPU
of both the model weights and kv cache is forgotten. Level 2 memory to store the model weights.
sleep is good for sleeping and waking up the engine to run a - Level 2: Discard all GPU memory (weights + KV cache).
different model or update the model, where previous model Good for sleeping and waking up the engine to run
weights are not needed. It reduces CPU memory pressure. a different model or update the model, where
previous model weights are not needed. It reduces
CPU memory pressure.
""" """
if level > 0:
self.reset_prefix_cache() self.reset_prefix_cache()
self.llm_engine.sleep(level=level) self.llm_engine.sleep(level=level)
@@ -1641,9 +1731,10 @@ class LLM:
Args: Args:
tags: An optional list of tags to reallocate the engine memory tags: An optional list of tags to reallocate the engine memory
for specific memory allocations. Values must be in for specific memory allocations. Values must be in
`("weights", "kv_cache")`. If None, all memory is reallocated. `("weights", "kv_cache", "scheduling")`. If None, all memory
wake_up should be called with all tags (or None) before the is reallocated. wake_up should be called with all tags
engine is used again. (or None) before the engine is used again.
Use tags=["scheduling"] to resume from level 0 sleep.
""" """
self.llm_engine.wake_up(tags) self.llm_engine.wake_up(tags)
@@ -1810,7 +1901,7 @@ class LLM:
lora_request: Sequence[LoRARequest | None] | LoRARequest | None, lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
priority: list[int] | None = None, priority: list[int] | None = None,
) -> None: ) -> list[str]:
num_requests = len(prompts) num_requests = len(prompts)
seq_params = self._params_to_seq(params, num_requests) seq_params = self._params_to_seq(params, num_requests)
seq_lora_requests = self._lora_request_to_seq(lora_request, num_requests) seq_lora_requests = self._lora_request_to_seq(lora_request, num_requests)
@@ -1844,6 +1935,8 @@ class LLM:
self.llm_engine.abort_request(added_request_ids, internal=True) self.llm_engine.abort_request(added_request_ids, internal=True)
raise e raise e
return added_request_ids
def _add_request( def _add_request(
self, self,
prompt: PromptType | DictPrompt | TokPrompt, prompt: PromptType | DictPrompt | TokPrompt,
@@ -1895,7 +1988,9 @@ class LLM:
return engine_request.request_id return engine_request.request_id
def _run_engine( def _run_engine(
self, *, use_tqdm: bool | Callable[..., tqdm] = True self,
*,
use_tqdm: bool | Callable[..., tqdm] = True,
) -> list[RequestOutput | PoolingRequestOutput]: ) -> list[RequestOutput | PoolingRequestOutput]:
# Initialize tqdm. # Initialize tqdm.
if use_tqdm: if use_tqdm:

View File

@@ -938,6 +938,7 @@ class AsyncLLM(EngineClient):
await self.engine_core.reset_encoder_cache_async() await self.engine_core.reset_encoder_cache_async()
async def sleep(self, level: int = 1) -> None:
    """Put the engine core to sleep at the given level.

    The prefix cache is reset first for any level above 0; at level 0 the
    reset is skipped and the request goes straight to the engine core.

    Args:
        level: The sleep level forwarded to the engine core.
    """
    needs_cache_reset = level > 0
    if needs_cache_reset:
        await self.reset_prefix_cache()
    await self.engine_core.sleep_async(level)

View File

@@ -614,13 +614,43 @@ class EngineCore:
self.model_executor.reset_encoder_cache() self.model_executor.reset_encoder_cache()
def sleep(self, level: int = 1):
    """Put the engine to sleep at the requested level.

    Args:
        level: Sleep level.
            - Level 0: Pause scheduling only. Requests are still accepted
              but not processed. No GPU memory changes.
            - Level 1: Offload model weights to CPU, discard KV cache.
            - Level 2: Discard all GPU memory.
    """
    if level != 0:
        # Any non-zero level is GPU memory management, owned by the executor.
        self.model_executor.sleep(level)
        return

    # Level 0 only pauses the scheduler; the executor is left untouched.
    self.pause_scheduler()
def wake_up(self, tags: list[str] | None = None): def wake_up(self, tags: list[str] | None = None):
"""Wake up the engine from sleep.
Args:
tags: Tags to wake up. Use ["scheduling"] for level 0 wake up.
"""
if tags is not None and "scheduling" in tags:
# Level 0 wake up: Resume scheduling
self.resume_scheduler()
# Remove "scheduling" from tags if there are other tags to process
remaining_tags = [t for t in tags if t != "scheduling"]
if remaining_tags:
self.model_executor.wake_up(remaining_tags)
else:
# Full wake up
self.resume_scheduler()
self.model_executor.wake_up(tags) self.model_executor.wake_up(tags)
def is_sleeping(self) -> bool:
    """Report whether the engine is sleeping at any level.

    True when the scheduler is paused or when the model executor reports
    that it is sleeping.
    """
    scheduler_paused = self._scheduler_paused
    return scheduler_paused if scheduler_paused else self.model_executor.is_sleeping
def execute_dummy_batch(self): def execute_dummy_batch(self):
self.model_executor.execute_dummy_batch() self.model_executor.execute_dummy_batch()
@@ -1023,7 +1053,13 @@ class EngineCoreProc(EngineCore):
# 1) Poll the input queue until there is work to do. # 1) Poll the input queue until there is work to do.
self._process_input_queue() self._process_input_queue()
# 2) Step the engine core and return the outputs. # 2) Step the engine core and return the outputs.
# Skip if scheduling is paused (level 0 sleep)
if not self._scheduler_paused:
self._process_engine_step() self._process_engine_step()
else:
# When scheduling is paused, still need to check for wake up
# by processing any utility requests that might resume scheduling
pass
def _process_input_queue(self): def _process_input_queue(self):
"""Exits when an engine step needs to be performed.""" """Exits when an engine step needs to be performed."""
@@ -1031,7 +1067,7 @@ class EngineCoreProc(EngineCore):
waited = False waited = False
while ( while (
not self.engines_running not self.engines_running
and not self.scheduler.has_requests() and (not self.scheduler.has_requests() or self._scheduler_paused)
and not self.batch_queue and not self.batch_queue
and not self._scheduler_paused and not self._scheduler_paused
): ):
@@ -1414,11 +1450,15 @@ class DPEngineCoreProc(EngineCoreProc):
# 1) Poll the input queue until there is work to do. # 1) Poll the input queue until there is work to do.
self._process_input_queue() self._process_input_queue()
# Skip processing if scheduling is paused (level 0 sleep)
if self._scheduler_paused:
continue
# 2) Step the engine core. # 2) Step the engine core.
executed = self._process_engine_step() executed = self._process_engine_step()
self._maybe_publish_request_counts() self._maybe_publish_request_counts()
local_unfinished_reqs = self.scheduler.has_unfinished_requests() local_unfinished_reqs = self.scheduler.has_unfinished_requests()
if not executed: if not executed:
if not local_unfinished_reqs and not self.engines_running: if not local_unfinished_reqs and not self.engines_running:
# All engines are idle. # All engines are idle.

View File

@@ -194,7 +194,7 @@ class EngineCoreClient(ABC):
raise NotImplementedError raise NotImplementedError
def dp_engines_running(self) -> bool:
    """Return True if the data parallel engines are collectively in a
    running state.

    Not implemented here; calling it raises NotImplementedError.
    """
    raise NotImplementedError
@@ -724,6 +724,7 @@ class SyncMPClient(MPClient):
# it is forwarded to the outputs_queue so we can raise it # it is forwarded to the outputs_queue so we can raise it
# from this (run_output_handler) task to shut down the server. # from this (run_output_handler) task to shut down the server.
outputs = self.outputs_queue.get() outputs = self.outputs_queue.get()
if isinstance(outputs, Exception): if isinstance(outputs, Exception):
raise self._format_exception(outputs) from None raise self._format_exception(outputs) from None
if outputs.wave_complete is not None: if outputs.wave_complete is not None:

View File

@@ -312,7 +312,11 @@ class LLMEngine:
# 4) Record stats # 4) Record stats
with record_function_or_nullcontext("llm_engine step: record_stats"): with record_function_or_nullcontext("llm_engine step: record_stats"):
if self.logger_manager is not None and outputs.scheduler_stats is not None: if (
self.logger_manager is not None
and outputs.scheduler_stats is not None
and len(outputs.outputs) > 0
):
self.logger_manager.record( self.logger_manager.record(
scheduler_stats=outputs.scheduler_stats, scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats, iteration_stats=iteration_stats,