[Core] Add sleep level 0 mode with enqueue/wait pattern (#33195)
Signed-off-by: Jaewon Lee <jaewon@meta.com> Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>
This commit is contained in:
@@ -458,6 +458,93 @@ class LLM:
|
||||
|
||||
return self.engine_class.validate_outputs(outputs, RequestOutput)
|
||||
|
||||
def enqueue(
|
||||
self,
|
||||
prompts: PromptType | Sequence[PromptType],
|
||||
sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
|
||||
lora_request: list[LoRARequest] | LoRARequest | None = None,
|
||||
priority: list[int] | None = None,
|
||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
) -> list[str]:
|
||||
"""Enqueue prompts for generation without waiting for completion.
|
||||
|
||||
This method adds requests to the engine queue but does not start
|
||||
processing them. Use wait_for_completion() to process the queued
|
||||
requests and get results.
|
||||
|
||||
Args:
|
||||
prompts: The prompts to the LLM. See generate() for details.
|
||||
sampling_params: The sampling parameters for text generation.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
priority: The priority of the requests, if any.
|
||||
use_tqdm: If True, shows a tqdm progress bar while adding requests.
|
||||
tokenization_kwargs: Overrides for `tokenizer.encode`.
|
||||
|
||||
Returns:
|
||||
A list of request IDs for the enqueued requests.
|
||||
"""
|
||||
model_config = self.model_config
|
||||
runner_type = model_config.runner_type
|
||||
if runner_type != "generate":
|
||||
raise ValueError("LLM.enqueue() is only supported for generative models.")
|
||||
|
||||
if sampling_params is None:
|
||||
sampling_params = self.get_default_sampling_params()
|
||||
|
||||
# Use the same preprocessing as _run_completion
|
||||
seq_prompts = prompt_to_seq(prompts)
|
||||
seq_params = self._params_to_seq(sampling_params, len(seq_prompts))
|
||||
|
||||
if any(param.truncate_prompt_tokens is not None for param in seq_params):
|
||||
engine_prompts: Sequence[DictPrompt | TokPrompt] = [
|
||||
engine_prompt
|
||||
for prompt, param in zip(seq_prompts, seq_params)
|
||||
for engine_prompt in self._preprocess_completion(
|
||||
[prompt],
|
||||
tokenization_kwargs=merge_kwargs(
|
||||
tokenization_kwargs,
|
||||
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
|
||||
),
|
||||
)
|
||||
]
|
||||
else:
|
||||
engine_prompts = self._preprocess_completion(
|
||||
seq_prompts,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
request_ids = self._validate_and_add_requests(
|
||||
prompts=engine_prompts,
|
||||
params=seq_params,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=self._get_modality_specific_lora_reqs(
|
||||
engine_prompts, lora_request
|
||||
),
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
priority=priority,
|
||||
)
|
||||
|
||||
return request_ids
|
||||
|
||||
def wait_for_completion(
|
||||
self,
|
||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||
) -> list[RequestOutput]:
|
||||
"""Wait for all enqueued requests to complete and return results.
|
||||
|
||||
This method processes all requests currently in the engine queue
|
||||
and returns their outputs. Use after enqueue() to get results.
|
||||
|
||||
Args:
|
||||
use_tqdm: If True, shows a tqdm progress bar.
|
||||
|
||||
Returns:
|
||||
A list of RequestOutput objects for all completed requests.
|
||||
"""
|
||||
outputs = self._run_engine(use_tqdm=use_tqdm)
|
||||
return self.engine_class.validate_outputs(outputs, RequestOutput)
|
||||
|
||||
def _get_modality_specific_lora_reqs(
|
||||
self,
|
||||
prompts: Sequence[DictPrompt | TokPrompt],
|
||||
@@ -1618,19 +1705,22 @@ class LLM:
|
||||
during the sleep period, before `wake_up` is called.
|
||||
|
||||
Args:
|
||||
level: The sleep level. Level 1 sleep will offload the model
|
||||
weights and discard the kv cache. The content of kv cache
|
||||
is forgotten. Level 1 sleep is good for sleeping and waking
|
||||
up the engine to run the same model again. The model weights
|
||||
are backed up in CPU memory. Please make sure there's enough
|
||||
CPU memory to store the model weights. Level 2 sleep will
|
||||
discard both the model weights and the kv cache. The content
|
||||
of both the model weights and kv cache is forgotten. Level 2
|
||||
sleep is good for sleeping and waking up the engine to run a
|
||||
different model or update the model, where previous model
|
||||
weights are not needed. It reduces CPU memory pressure.
|
||||
level: The sleep level.
|
||||
- Level 0: Pause scheduling but continue accepting requests.
|
||||
Requests are queued but not processed.
|
||||
- Level 1: Offload model weights to CPU, discard KV cache.
|
||||
The content of kv cache is forgotten. Good for
|
||||
sleeping and waking up the engine to run the same
|
||||
model again. Please make sure there's enough CPU
|
||||
memory to store the model weights.
|
||||
- Level 2: Discard all GPU memory (weights + KV cache).
|
||||
Good for sleeping and waking up the engine to run
|
||||
a different model or update the model, where
|
||||
previous model weights are not needed. It reduces
|
||||
CPU memory pressure.
|
||||
"""
|
||||
self.reset_prefix_cache()
|
||||
if level > 0:
|
||||
self.reset_prefix_cache()
|
||||
self.llm_engine.sleep(level=level)
|
||||
|
||||
def wake_up(self, tags: list[str] | None = None):
|
||||
@@ -1641,9 +1731,10 @@ class LLM:
|
||||
Args:
|
||||
tags: An optional list of tags to reallocate the engine memory
|
||||
for specific memory allocations. Values must be in
|
||||
`("weights", "kv_cache")`. If None, all memory is reallocated.
|
||||
wake_up should be called with all tags (or None) before the
|
||||
engine is used again.
|
||||
`("weights", "kv_cache", "scheduling")`. If None, all memory
|
||||
is reallocated. wake_up should be called with all tags
|
||||
(or None) before the engine is used again.
|
||||
Use tags=["scheduling"] to resume from level 0 sleep.
|
||||
"""
|
||||
self.llm_engine.wake_up(tags)
|
||||
|
||||
@@ -1810,7 +1901,7 @@ class LLM:
|
||||
lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
priority: list[int] | None = None,
|
||||
) -> None:
|
||||
) -> list[str]:
|
||||
num_requests = len(prompts)
|
||||
seq_params = self._params_to_seq(params, num_requests)
|
||||
seq_lora_requests = self._lora_request_to_seq(lora_request, num_requests)
|
||||
@@ -1844,6 +1935,8 @@ class LLM:
|
||||
self.llm_engine.abort_request(added_request_ids, internal=True)
|
||||
raise e
|
||||
|
||||
return added_request_ids
|
||||
|
||||
def _add_request(
|
||||
self,
|
||||
prompt: PromptType | DictPrompt | TokPrompt,
|
||||
@@ -1895,7 +1988,9 @@ class LLM:
|
||||
return engine_request.request_id
|
||||
|
||||
def _run_engine(
|
||||
self, *, use_tqdm: bool | Callable[..., tqdm] = True
|
||||
self,
|
||||
*,
|
||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||
) -> list[RequestOutput | PoolingRequestOutput]:
|
||||
# Initialize tqdm.
|
||||
if use_tqdm:
|
||||
|
||||
@@ -938,7 +938,8 @@ class AsyncLLM(EngineClient):
|
||||
await self.engine_core.reset_encoder_cache_async()
|
||||
|
||||
async def sleep(self, level: int = 1) -> None:
|
||||
await self.reset_prefix_cache()
|
||||
if level > 0:
|
||||
await self.reset_prefix_cache()
|
||||
await self.engine_core.sleep_async(level)
|
||||
|
||||
if self.logger_manager is not None:
|
||||
|
||||
@@ -614,13 +614,43 @@ class EngineCore:
|
||||
self.model_executor.reset_encoder_cache()
|
||||
|
||||
def sleep(self, level: int = 1):
|
||||
self.model_executor.sleep(level)
|
||||
"""Put the engine to sleep at the specified level.
|
||||
|
||||
Args:
|
||||
level: Sleep level.
|
||||
- Level 0: Pause scheduling only. Requests are still accepted
|
||||
but not processed. No GPU memory changes.
|
||||
- Level 1: Offload model weights to CPU, discard KV cache.
|
||||
- Level 2: Discard all GPU memory.
|
||||
"""
|
||||
if level == 0:
|
||||
# Level 0: Just pause scheduling, don't touch GPU
|
||||
self.pause_scheduler()
|
||||
else:
|
||||
# Level 1+: Delegate to executor for GPU memory management
|
||||
self.model_executor.sleep(level)
|
||||
|
||||
def wake_up(self, tags: list[str] | None = None):
|
||||
self.model_executor.wake_up(tags)
|
||||
"""Wake up the engine from sleep.
|
||||
|
||||
Args:
|
||||
tags: Tags to wake up. Use ["scheduling"] for level 0 wake up.
|
||||
"""
|
||||
if tags is not None and "scheduling" in tags:
|
||||
# Level 0 wake up: Resume scheduling
|
||||
self.resume_scheduler()
|
||||
# Remove "scheduling" from tags if there are other tags to process
|
||||
remaining_tags = [t for t in tags if t != "scheduling"]
|
||||
if remaining_tags:
|
||||
self.model_executor.wake_up(remaining_tags)
|
||||
else:
|
||||
# Full wake up
|
||||
self.resume_scheduler()
|
||||
self.model_executor.wake_up(tags)
|
||||
|
||||
def is_sleeping(self) -> bool:
|
||||
return self.model_executor.is_sleeping
|
||||
"""Check if engine is sleeping at any level."""
|
||||
return self._scheduler_paused or self.model_executor.is_sleeping
|
||||
|
||||
def execute_dummy_batch(self):
|
||||
self.model_executor.execute_dummy_batch()
|
||||
@@ -1023,7 +1053,13 @@ class EngineCoreProc(EngineCore):
|
||||
# 1) Poll the input queue until there is work to do.
|
||||
self._process_input_queue()
|
||||
# 2) Step the engine core and return the outputs.
|
||||
self._process_engine_step()
|
||||
# Skip if scheduling is paused (level 0 sleep)
|
||||
if not self._scheduler_paused:
|
||||
self._process_engine_step()
|
||||
else:
|
||||
# When scheduling is paused, still need to check for wake up
|
||||
# by processing any utility requests that might resume scheduling
|
||||
pass
|
||||
|
||||
def _process_input_queue(self):
|
||||
"""Exits when an engine step needs to be performed."""
|
||||
@@ -1031,7 +1067,7 @@ class EngineCoreProc(EngineCore):
|
||||
waited = False
|
||||
while (
|
||||
not self.engines_running
|
||||
and not self.scheduler.has_requests()
|
||||
and (not self.scheduler.has_requests() or self._scheduler_paused)
|
||||
and not self.batch_queue
|
||||
and not self._scheduler_paused
|
||||
):
|
||||
@@ -1414,11 +1450,15 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
# 1) Poll the input queue until there is work to do.
|
||||
self._process_input_queue()
|
||||
|
||||
# Skip processing if scheduling is paused (level 0 sleep)
|
||||
if self._scheduler_paused:
|
||||
continue
|
||||
|
||||
# 2) Step the engine core.
|
||||
executed = self._process_engine_step()
|
||||
self._maybe_publish_request_counts()
|
||||
|
||||
local_unfinished_reqs = self.scheduler.has_unfinished_requests()
|
||||
|
||||
if not executed:
|
||||
if not local_unfinished_reqs and not self.engines_running:
|
||||
# All engines are idle.
|
||||
|
||||
@@ -194,7 +194,7 @@ class EngineCoreClient(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
def dp_engines_running(self) -> bool:
|
||||
"""Returns True id data parallel engines are collectively in a
|
||||
"""Returns True if data parallel engines are collectively in a
|
||||
running state."""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -724,6 +724,7 @@ class SyncMPClient(MPClient):
|
||||
# it is forwarded to the outputs_queue so we can raise it
|
||||
# from this (run_output_handler) task to shut down the server.
|
||||
outputs = self.outputs_queue.get()
|
||||
|
||||
if isinstance(outputs, Exception):
|
||||
raise self._format_exception(outputs) from None
|
||||
if outputs.wave_complete is not None:
|
||||
|
||||
@@ -312,7 +312,11 @@ class LLMEngine:
|
||||
|
||||
# 4) Record stats
|
||||
with record_function_or_nullcontext("llm_engine step: record_stats"):
|
||||
if self.logger_manager is not None and outputs.scheduler_stats is not None:
|
||||
if (
|
||||
self.logger_manager is not None
|
||||
and outputs.scheduler_stats is not None
|
||||
and len(outputs.outputs) > 0
|
||||
):
|
||||
self.logger_manager.record(
|
||||
scheduler_stats=outputs.scheduler_stats,
|
||||
iteration_stats=iteration_stats,
|
||||
|
||||
Reference in New Issue
Block a user