[V1][Core] Generic mechanism for handling engine utility (#13060)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-02-19 01:09:22 -08:00
committed by GitHub
parent f525c0be8b
commit caf7ff4456
5 changed files with 197 additions and 56 deletions

View File

@@ -5,9 +5,11 @@ import signal
import threading
import time
from concurrent.futures import Future
from inspect import isclass, signature
from multiprocessing.connection import Connection
from typing import Any, List, Optional, Tuple, Type
import msgspec
import psutil
import zmq
import zmq.asyncio
@@ -21,7 +23,7 @@ from vllm.utils import get_exception_traceback, zmq_socket_ctx
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType)
EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.mm_input_cache import MMInputCacheServer
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput
@@ -330,19 +332,39 @@ class EngineCoreProc(EngineCore):
self.add_request(request)
elif request_type == EngineCoreRequestType.ABORT:
self.abort_requests(request)
elif request_type == EngineCoreRequestType.RESET_PREFIX_CACHE:
self.reset_prefix_cache()
elif request_type == EngineCoreRequestType.PROFILE:
self.model_executor.profile(request)
elif request_type == EngineCoreRequestType.ADD_LORA:
self.model_executor.add_lora(request)
elif request_type == EngineCoreRequestType.UTILITY:
call_id, method_name, args = request
output = UtilityOutput(call_id)
try:
method = getattr(self, method_name)
output.result = method(
*self._convert_msgspec_args(method, args))
except BaseException as e:
logger.exception("Invocation of %s method failed", method_name)
output.failure_message = (f"Call to {method_name} method"
f" failed: {str(e)}")
self.output_queue.put_nowait(
EngineCoreOutputs(utility_output=output))
@staticmethod
def _convert_msgspec_args(method, args):
"""If a provided arg type doesn't match corresponding target method
arg type, try converting to msgspec object."""
if not args:
return args
arg_types = signature(method).parameters.values()
assert len(args) <= len(arg_types)
return tuple(
msgspec.convert(v, type=p.annotation) if isclass(p.annotation)
and issubclass(p.annotation, msgspec.Struct)
and not isinstance(v, p.annotation) else v
for v, p in zip(args, arg_types))
def process_input_socket(self, input_path: str):
"""Input socket IO thread."""
# Msgpack serialization decoding.
add_request_decoder = MsgpackDecoder(EngineCoreRequest)
add_lora_decoder = MsgpackDecoder(LoRARequest)
generic_decoder = MsgpackDecoder()
with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
@@ -352,14 +374,9 @@ class EngineCoreProc(EngineCore):
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
# Deserialize the request data.
decoder = None
if request_type == EngineCoreRequestType.ADD:
decoder = add_request_decoder
elif request_type == EngineCoreRequestType.ADD_LORA:
decoder = add_lora_decoder
else:
decoder = generic_decoder
decoder = add_request_decoder if (
request_type
== EngineCoreRequestType.ADD) else generic_decoder
request = decoder.decode(data_frame.buffer)
# Push to input queue for core busy loop.