# SPDX-License-Identifier: Apache-2.0

import os
import queue
import signal
import sys
import threading
import time
from concurrent.futures import Future
from inspect import isclass, signature
from logging import DEBUG
from typing import Any, Callable, Optional, TypeVar, Union

import msgspec
import psutil
import zmq
import zmq.asyncio

from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.executor.multiproc_worker_utils import _add_prefix
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname,
                        zmq_socket_ctx)
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
                                         unify_kv_cache_configs)
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                            EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.mm_input_cache import MMInputCacheServer
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

POLLING_TIMEOUT_S = 2.5

_R = TypeVar('_R')  # Return type for collective_rpc


class EngineCore:
    """Inner loop of vLLM's Engine."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
    ):
        assert vllm_config.model_config.runner_type != "pooling"

        logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
                    VLLM_VERSION, vllm_config)

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = \
            self._initialize_kv_caches(vllm_config)
        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler.
        if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                vllm_config.scheduler_config.scheduler_cls)
        else:
            Scheduler = vllm_config.scheduler_config.scheduler_cls

        # This warning can be removed once the V1 Scheduler interface is
        # finalized and we can maintain support for scheduler classes that
        # implement it.
        if Scheduler is not V1Scheduler:
            logger.warning(
                "Using configured V1 scheduler class %s. "
                "This scheduler interface is not public and "
                "compatibility may not be maintained.",
                vllm_config.scheduler_config.scheduler_cls)

        self.scheduler: SchedulerInterface = Scheduler(
            scheduler_config=vllm_config.scheduler_config,
            model_config=vllm_config.model_config,
            cache_config=vllm_config.cache_config,
            lora_config=vllm_config.lora_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=vllm_config.parallel_config.data_parallel_size
            > 1,
            log_stats=self.log_stats,
        )

        # Set up the multimodal (MM) input cache server.
        self.mm_input_cache_server = MMInputCacheServer(
            vllm_config.model_config)

        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput],
                                                      SchedulerOutput]]] = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d",
                        self.batch_queue_size)
            self.batch_queue = queue.Queue(self.batch_queue_size)

    def _initialize_kv_caches(
            self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
        start = time.time()

        # Get all kv cache specs needed by the model.
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        # Profiles the peak memory usage of the model to determine how much
        # memory can be allocated for kv cache.
        available_gpu_memory = self.model_executor.determine_available_memory()

        assert len(kv_cache_specs) == len(available_gpu_memory)
        # Get the kv cache config for each worker.
        kv_cache_configs = [
            get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
                                available_gpu_memory_one_worker)
            for kv_cache_spec_one_worker, available_gpu_memory_one_worker in
            zip(kv_cache_specs, available_gpu_memory)
        ]

        # Since we use a shared centralized controller, we need the
        # `kv_cache_config` to be consistent across all workers to make sure
        # all the memory operators can be applied to all workers.
        unify_kv_cache_configs(kv_cache_configs)

        # All workers have the same kv_cache_config except layer names, so use
        # an arbitrary one to initialize the scheduler.
        assert all([
            cfg.num_blocks == kv_cache_configs[0].num_blocks
            for cfg in kv_cache_configs
        ])
        num_gpu_blocks = kv_cache_configs[0].num_blocks
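        # V1 does not swap KV-cache blocks to CPU memory, so zero CPU blocks
        # are reported here; the value only fills in CacheConfig.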
        num_cpu_blocks = 0
        scheduler_kv_cache_config = kv_cache_configs[0]

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.time() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

    def add_request(self, request: EngineCoreRequest):
        """Add request to the scheduler."""

        if request.mm_hashes is not None:
            # Here, if hash exists for a multimodal input, then it will be
            # fetched from the cache, else it will be added to the cache.
            # Note that the cache here is mirrored with the client cache, so
            # anything that has a hash must have a HIT cache entry here
            # as well.
            assert request.mm_inputs is not None
            request.mm_inputs = self.mm_input_cache_server.get_and_update(
                request.mm_inputs, request.mm_hashes)

        req = Request.from_engine_core_request(request)
        if req.use_structured_output:
            # Start grammar compilation asynchronously
            self.structured_output_manager.grammar_init(req)

        self.scheduler.add_request(req)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids,
                                       RequestStatus.FINISHED_ABORTED)

    def step(self) -> EngineCoreOutputs:
        """Schedule, execute, and make output."""

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return EngineCoreOutputs(
                outputs=[],
                scheduler_stats=self.scheduler.make_stats(),
            )
        scheduler_output = self.scheduler.schedule()
        output = self.model_executor.execute_model(scheduler_output)
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, output)  # type: ignore

        return engine_core_outputs

    def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]:
        """Schedule and execute batches with the batch queue.
        Note that if there is nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if there are unscheduled requests
        and the job queue is not full. If a new batch is scheduled, directly
        return an empty engine core output. In other words, we won't check
        and return model outputs before the batch queue is full.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the job queue is finished.
        3. Update the scheduler from the output.
        """
        assert self.batch_queue is not None

        engine_core_outputs = None
        scheduler_output = None
        # If there are unscheduled requests and the job queue
        # is not full, schedule a new batch. Note that this is not blocking.
        if (self.scheduler.get_num_unscheduled_requests() > 0
                and not self.batch_queue.full()):
            scheduler_output = self.scheduler.schedule()
            if scheduler_output.total_num_scheduled_tokens > 0:
                future = self.model_executor.execute_model(scheduler_output)
                self.batch_queue.put_nowait(
                    (future, scheduler_output))  # type: ignore

        scheduled_batch = (scheduler_output is not None
                           and scheduler_output.total_num_scheduled_tokens > 0)

        # If no more requests can be scheduled and the job queue is not empty,
        # block until the first batch in the job queue is finished.
        if not scheduled_batch and not self.batch_queue.empty():
            future, scheduler_output = self.batch_queue.get_nowait()
            # Blocking until the first result is available.
            model_output = future.result()
            self.batch_queue.task_done()
            engine_core_outputs = self.scheduler.update_from_output(
                scheduler_output, model_output)

        return engine_core_outputs

    def shutdown(self):
        self.model_executor.shutdown()

    def profile(self, is_start: bool = True):
        self.model_executor.profile(is_start)

    def reset_prefix_cache(self):
        self.scheduler.reset_prefix_cache()

    def sleep(self, level: int = 1):
        self.model_executor.sleep(level)

    def wake_up(self, tags: Optional[list[str]] = None):
        self.model_executor.wake_up(tags)

    def is_sleeping(self) -> bool:
        return self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        self.model_executor.collective_rpc("execute_dummy_batch")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        return self.model_executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        self.model_executor.save_sharded_state(path=path,
                                               pattern=pattern,
                                               max_size=max_size)

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        return self.model_executor.collective_rpc(method, timeout, args,
                                                  kwargs)
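
    # Illustrative only (not part of the original file): a minimal sketch of
    # how an in-process EngineCore can be driven, assuming the caller has
    # already built `vllm_config`, an `executor_cls`, and an EngineCoreRequest
    # `request`:
    #
    #     core = EngineCore(vllm_config, executor_cls, log_stats=False)
    #     core.add_request(request)
    #     while core.scheduler.has_requests():
    #         outputs = core.step()
    #     core.shutdown()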


class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

    def __init__(
        self,
        input_path: str,
        output_path: str,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        engine_index: int = 0,
    ):
        super().__init__(vllm_config, executor_class, log_stats)
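
        # The batch queue exists only when the executor supports more than
        # one in-flight batch (pipeline parallelism); in that case use the
        # pipelined step variant, otherwise the plain step().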
        self.step_fn = (self.step if self.batch_queue is None else
                        self.step_with_batch_queue)

        self.global_unfinished_reqs = False

        # Background Threads and Queues for IO. These enable us to
        # overlap ZMQ socket IO with GPU since they release the GIL,
        # and to overlap some serialization/deserialization with the
        # model forward pass.
        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
        self.input_queue: queue.Queue[tuple[EngineCoreRequestType,
                                            Any]] = queue.Queue()
        self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
        threading.Thread(target=self.process_input_socket,
                         args=(input_path, engine_index),
                         daemon=True).start()
        threading.Thread(target=self.process_output_socket,
                         args=(output_path, engine_index),
                         daemon=True).start()

    @staticmethod
    def run_engine_core(*args,
                        dp_rank: int = 0,
                        local_dp_rank: int = 0,
                        **kwargs):
        """Launch EngineCore busy loop in background process."""

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        parent_process = psutil.Process().parent()
        engine_core: Optional[EngineCoreProc] = None
        try:
            parallel_config: ParallelConfig = kwargs[
                "vllm_config"].parallel_config
            if parallel_config.data_parallel_size > 1:
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                parallel_config.data_parallel_rank_local = local_dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)
            else:
                engine_core = EngineCoreProc(*args, **kwargs)

            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore interrupted.")

        except Exception:
            traceback = get_exception_traceback()
            logger.error("EngineCore hit an exception: %s", traceback)
            parent_process.send_signal(signal.SIGUSR1)

        finally:
            if engine_core is not None:
                engine_core.shutdown()

    def run_busy_loop(self):
        """Core busy loop of the EngineCore."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()
            # 2) Step the engine core and return the outputs.
            self._process_engine_step()

    def _process_input_queue(self):
        """Exits when an engine step needs to be performed."""

        waited = False
        while not self.global_unfinished_reqs and not (
                self.scheduler.has_requests()):
            if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
                logger.debug("EngineCore waiting for work.")
                waited = True
            req = self.input_queue.get()
            self._handle_client_request(*req)

        if waited:
            logger.debug(
                "EngineCore loop active - local unfinished: %s, finished: %s.",
                self.scheduler.has_unfinished_requests(),
                self.scheduler.has_finished_requests())

        # Handle any more client requests.
        while not self.input_queue.empty():
            req = self.input_queue.get_nowait()
            self._handle_client_request(*req)

    def _process_engine_step(self):
        """Called only when there are unfinished local requests."""

        # Step the engine core.
        outputs = self.step_fn()
        # Put EngineCoreOutputs into the output queue.
        if outputs is not None:
            self.output_queue.put_nowait(outputs)

    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        """Dispatch request from client."""

        if request_type == EngineCoreRequestType.ADD:
            self.add_request(request)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.START_DP:
            if not self.global_unfinished_reqs:
                logger.debug("EngineCore starting idle loop.")
                self.global_unfinished_reqs = True
        elif request_type == EngineCoreRequestType.UTILITY:
            call_id, method_name, args = request
            output = UtilityOutput(call_id)
            try:
                method = getattr(self, method_name)
                output.result = method(
                    *self._convert_msgspec_args(method, args))
            except BaseException as e:
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = (f"Call to {method_name} method"
                                          f" failed: {str(e)}")
            self.output_queue.put_nowait(
                EngineCoreOutputs(utility_output=output))

    @staticmethod
    def _convert_msgspec_args(method, args):
        """If a provided arg type doesn't match corresponding target method
        arg type, try converting to msgspec object."""
        if not args:
            return args
        arg_types = signature(method).parameters.values()
        assert len(args) <= len(arg_types)
        return tuple(
            msgspec.convert(v, type=p.annotation) if isclass(p.annotation)
            and issubclass(p.annotation, msgspec.Struct)
            and not isinstance(v, p.annotation) else v
            for v, p in zip(args, arg_types))

    def process_input_socket(self, input_path: str, engine_index: int):
        """Input socket IO thread."""

        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()
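        # The 2-byte little-endian engine index below is used as this DEALER
        # socket's ZMQ identity, so the front-end can address a specific
        # engine when several engines share the same input socket.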
        identity = engine_index.to_bytes(length=2, byteorder="little")

        with zmq_socket_ctx(input_path,
                            zmq.DEALER,
                            identity=identity,
                            bind=False) as socket:

            # Send ready message to front-end once input socket is connected.
            socket.send(b'READY')

            while True:
                # (RequestType, RequestData)
                type_frame, data_frame = socket.recv_multipart(copy=False)
                request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                # Deserialize the request data.
                decoder = add_request_decoder if (
                    request_type
                    == EngineCoreRequestType.ADD) else generic_decoder
                request = decoder.decode(data_frame.buffer)

                # Push to input queue for core busy loop.
                self.input_queue.put_nowait((request_type, request))

    def process_output_socket(self, output_path: str, engine_index: int):
        """Output socket IO thread."""

        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()
        # Reuse send buffer.
        buffer = bytearray()

        with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket:
            while True:
                outputs = self.output_queue.get()
                outputs.engine_index = engine_index
                encoder.encode_into(outputs, buffer)
                socket.send(buffer, copy=False)


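# Sentinel sent to the client when all data-parallel engines have gone idle;
# it tells the front-end that the engine loop is paused until new work arrives.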
ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True)


class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context."""

    def __init__(
        self,
        input_path: str,
        output_path: str,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
    ):
        # Add process-specific prefix to stdout and stderr before
        # we initialize the engine.
        from multiprocessing import current_process
        process_name = current_process().name
        pid = os.getpid()
        _add_prefix(sys.stdout, process_name, pid)
        _add_prefix(sys.stderr, process_name, pid)

        dp_size = vllm_config.parallel_config.data_parallel_size
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        from vllm.platforms import current_platform
        if current_platform.is_cuda_alike():
            from vllm.platforms.cuda import device_id_to_physical_device_id
            tp_size = vllm_config.parallel_config.tensor_parallel_size
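            # Each local DP rank claims its own contiguous block of tp_size
            # devices; e.g. with tp_size=2, local rank 0 sees devices 0-1 and
            # local rank 1 sees devices 2-3 (given the default device order).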
            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
                str(device_id_to_physical_device_id(i))
                for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) *
                               tp_size))

        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

        # Initialize the engine after setting up environment.
        super().__init__(input_path, output_path, vllm_config, executor_class,
                         log_stats, dp_rank)

        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.counter = 0

    def shutdown(self):
        super().shutdown()
        if dp_group := getattr(self, "dp_group", None):
            stateless_destroy_torch_distributed_process_group(dp_group)

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()

            if local_unfinished_reqs:
                # 2) Step the engine core.
                self._process_engine_step()

                # Check if we have now finished all requests.
                local_unfinished_reqs = (
                    self.scheduler.has_unfinished_requests())
            else:
                if self.scheduler.has_finished_requests():
                    # There are no unfinished requests, but there are some
                    # finished requests remaining to be removed from the
                    # batch state. This engine step won't perform a forward
                    # pass but will flush the finished requests to ensure
                    # up-to-date state is returned in the engine outputs.
                    self._process_engine_step()

                if not self.global_unfinished_reqs:
                    # All engines are idle.
                    continue

                # There must be unfinished requests in DP peers, run a
                # dummy forward pass.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.global_unfinished_reqs = self._has_global_unfinished_reqs(
                local_unfinished_reqs)

            if not self.global_unfinished_reqs:
                # Notify client that we are pausing the loop.
                self.output_queue.put_nowait(ENGINE_PAUSED_OUTPUTS)

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:

        # Optimization - only perform finish-sync all-reduce every 16 steps.
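        # Between syncs we conservatively report that there is still global
        # unfinished work, so the busy loop keeps stepping; the real
        # all-reduce across the DP group happens only on every 16th call.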
        self.counter += 1
        if self.counter != 16:
            return True
        self.counter = 0

        return ParallelConfig.has_unfinished_dp(self.dp_group,
                                                local_unfinished)