[V1][Core] Generic mechanism for handling engine utility (#13060)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -5,9 +5,11 @@ import signal
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import Future
|
||||
from inspect import isclass, signature
|
||||
from multiprocessing.connection import Connection
|
||||
from typing import Any, List, Optional, Tuple, Type
|
||||
|
||||
import msgspec
|
||||
import psutil
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
@@ -21,7 +23,7 @@ from vllm.utils import get_exception_traceback, zmq_socket_ctx
|
||||
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
|
||||
from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
|
||||
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
|
||||
EngineCoreRequestType)
|
||||
EngineCoreRequestType, UtilityOutput)
|
||||
from vllm.v1.engine.mm_input_cache import MMInputCacheServer
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
@@ -330,19 +332,39 @@ class EngineCoreProc(EngineCore):
|
||||
self.add_request(request)
|
||||
elif request_type == EngineCoreRequestType.ABORT:
|
||||
self.abort_requests(request)
|
||||
elif request_type == EngineCoreRequestType.RESET_PREFIX_CACHE:
|
||||
self.reset_prefix_cache()
|
||||
elif request_type == EngineCoreRequestType.PROFILE:
|
||||
self.model_executor.profile(request)
|
||||
elif request_type == EngineCoreRequestType.ADD_LORA:
|
||||
self.model_executor.add_lora(request)
|
||||
elif request_type == EngineCoreRequestType.UTILITY:
|
||||
call_id, method_name, args = request
|
||||
output = UtilityOutput(call_id)
|
||||
try:
|
||||
method = getattr(self, method_name)
|
||||
output.result = method(
|
||||
*self._convert_msgspec_args(method, args))
|
||||
except BaseException as e:
|
||||
logger.exception("Invocation of %s method failed", method_name)
|
||||
output.failure_message = (f"Call to {method_name} method"
|
||||
f" failed: {str(e)}")
|
||||
self.output_queue.put_nowait(
|
||||
EngineCoreOutputs(utility_output=output))
|
||||
|
||||
@staticmethod
def _convert_msgspec_args(method, args):
    """Coerce decoded call arguments to the msgspec types ``method`` expects.

    Each positional value in ``args`` is paired, in order, with a parameter
    of ``method``. When that parameter is annotated with a
    ``msgspec.Struct`` subclass and the value is not already an instance of
    it, the value is passed through ``msgspec.convert``. Every other value
    is returned untouched.

    Args:
        method: Callable whose signature supplies the target annotations.
        args: Positional arguments intended for ``method``; may be empty.

    Returns:
        ``args`` itself when it is empty/falsy, otherwise a new tuple
        holding the (possibly converted) values in the same order.

    Raises:
        TypeError: If more arguments are supplied than ``method`` declares
            parameters.
    """
    if not args:
        return args

    params = tuple(signature(method).parameters.values())

    # Explicit check instead of `assert`: assertions are removed under
    # `python -O`, after which zip() below would silently drop the surplus
    # arguments and the method would be invoked with too few of them.
    if len(args) > len(params):
        method_name = getattr(method, "__name__", repr(method))
        raise TypeError(
            f"{method_name} accepts at most {len(params)} positional "
            f"argument(s) but {len(args)} were provided")

    def _coerce(value, annotation):
        # Only msgspec.Struct annotations trigger a conversion, and only
        # when the value is not already of that type.
        needs_conversion = (isclass(annotation)
                            and issubclass(annotation, msgspec.Struct)
                            and not isinstance(value, annotation))
        if needs_conversion:
            return msgspec.convert(value, type=annotation)
        return value

    return tuple(
        _coerce(value, param.annotation)
        for value, param in zip(args, params))
|
||||
|
||||
def process_input_socket(self, input_path: str):
|
||||
"""Input socket IO thread."""
|
||||
|
||||
# Msgpack serialization decoding.
|
||||
add_request_decoder = MsgpackDecoder(EngineCoreRequest)
|
||||
add_lora_decoder = MsgpackDecoder(LoRARequest)
|
||||
generic_decoder = MsgpackDecoder()
|
||||
|
||||
with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
|
||||
@@ -352,14 +374,9 @@ class EngineCoreProc(EngineCore):
|
||||
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
|
||||
|
||||
# Deserialize the request data.
|
||||
decoder = None
|
||||
if request_type == EngineCoreRequestType.ADD:
|
||||
decoder = add_request_decoder
|
||||
elif request_type == EngineCoreRequestType.ADD_LORA:
|
||||
decoder = add_lora_decoder
|
||||
else:
|
||||
decoder = generic_decoder
|
||||
|
||||
decoder = add_request_decoder if (
|
||||
request_type
|
||||
== EngineCoreRequestType.ADD) else generic_decoder
|
||||
request = decoder.decode(data_frame.buffer)
|
||||
|
||||
# Push to input queue for core busy loop.
|
||||
|
||||
Reference in New Issue
Block a user