[Core] Move EngineCoreRequest to Request conversion out of EngineCore (#21627)
Signed-off-by: linzebing <linzebing1995@gmail.com>

@@ -205,8 +205,12 @@ class EngineCore:
     def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
         return self.model_executor.supported_tasks

-    def add_request(self, request: EngineCoreRequest):
-        """Add request to the scheduler."""
+    def add_request(self, request: Request, request_wave: int = 0):
+        """Add request to the scheduler.
+
+        `request_wave`: indicates which wave of requests this is expected
+        to belong to in the DP case.
+        """
         # Validate the request_id type.
         if not isinstance(request.request_id, str):
             raise TypeError(
@@ -222,27 +226,12 @@ class EngineCore:
             raise ValueError(f"Unsupported task: {pooling_params.task!r} "
                              f"Supported tasks: {supported_pooling_tasks}")

-        if request.mm_hashes is not None:
-            # Here, if hash exists for a multimodal input, then it will be
-            # fetched from the cache, else it will be added to the cache.
-            # Note that the cache here is mirrored with the client cache, so
-            # anything that has a hash must have a HIT cache entry here
-            # as well.
-            assert request.mm_inputs is not None
-            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
-                request.mm_inputs, request.mm_hashes)
-
-        req = Request.from_engine_core_request(request)
-        if req.use_structured_output:
-            # Start grammar compilation asynchronously
-            self.structured_output_manager.grammar_init(req)
-
-        if req.kv_transfer_params is not None and (
+        if request.kv_transfer_params is not None and (
                 not self.scheduler.get_kv_connector()):
             logger.warning("Got kv_transfer_params, but no KVConnector found. "
                            "Disabling KVTransfer for this request.")

-        self.scheduler.add_request(req)
+        self.scheduler.add_request(request)

     def abort_requests(self, request_ids: list[str]):
         """Abort requests from the scheduler."""
@@ -414,6 +403,31 @@ class EngineCore:
         self.model_executor.save_tensorized_model(
             tensorizer_config=tensorizer_config, )

+    def preprocess_add_request(
+            self, request: EngineCoreRequest) -> tuple[Request, int]:
+        """Preprocess the request.
+
+        This function can be used directly in the input processing thread,
+        allowing request initialization to run in parallel with the model
+        forward pass.
+        """
+        if request.mm_hashes is not None:
+            assert request.mm_inputs is not None
+            # Note on thread safety: no race condition.
+            # `mm_input_cache_server` is reset at the end of LLMEngine init,
+            # and will only be accessed in the input processing thread
+            # afterwards.
+            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
+                request.mm_inputs, request.mm_hashes)
+
+        req = Request.from_engine_core_request(request)
+        if req.use_structured_output:
+            # Note on thread safety: no race condition.
+            # `grammar_init` is only invoked in the input processing thread.
+            # For `structured_output_manager`, each request is independent and
+            # grammar compilation is async. The scheduler always checks grammar
+            # compilation status before scheduling a request.
+            self.structured_output_manager.grammar_init(req)
+        return req, request.current_wave
+
+
 class EngineCoreProc(EngineCore):
     """ZMQ-wrapper for running EngineCore in background process."""
@@ -707,7 +721,8 @@ class EngineCoreProc(EngineCore):
        """Dispatch request from client."""

         if request_type == EngineCoreRequestType.ADD:
-            self.add_request(request)
+            req, request_wave = request
+            self.add_request(req, request_wave)
         elif request_type == EngineCoreRequestType.ABORT:
             self.abort_requests(request)
         elif request_type == EngineCoreRequestType.UTILITY:
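
(Note: for ADD, the `request` payload arriving here is no longer a raw EngineCoreRequest; it is the `(Request, request_wave)` tuple produced by `preprocess_add_request` on the input thread, as the decode hunk below shows, hence the tuple unpack.)
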
@@ -806,10 +821,11 @@ class EngineCoreProc(EngineCore):
                 bytes(type_frame.buffer))

             # Deserialize the request data.
-            decoder = add_request_decoder if (
-                request_type
-                == EngineCoreRequestType.ADD) else generic_decoder
-            request = decoder.decode(data_frames)
+            if request_type == EngineCoreRequestType.ADD:
+                request = add_request_decoder.decode(data_frames)
+                request = self.preprocess_add_request(request)
+            else:
+                request = generic_decoder.decode(data_frames)

             # Push to input queue for core busy loop.
             self.input_queue.put_nowait((request_type, request))
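
For context on the decode path: vLLM's engine messages are (de)serialized with msgspec, and the hunk above replaces the conditional-expression decoder pick with an explicit branch so ADD payloads are both decoded with the typed decoder and pre-processed on the input thread. A tiny standalone sketch of the typed-vs-generic decoder split; `AddPayload` is a made-up stand-in, not the real EngineCoreRequest schema:

# Sketch of typed vs. generic msgpack decoders (AddPayload is hypothetical).
import msgspec

class AddPayload(msgspec.Struct):
    request_id: str
    prompt: str

add_request_decoder = msgspec.msgpack.Decoder(AddPayload)
generic_decoder = msgspec.msgpack.Decoder()     # untyped fallback

data = msgspec.msgpack.Encoder().encode(AddPayload("r1", "hello"))
typed = add_request_decoder.decode(data)        # -> AddPayload instance
plain = generic_decoder.decode(data)            # -> plain dict
print(typed, plain)
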
@@ -939,17 +955,17 @@ class DPEngineCoreProc(EngineCoreProc):
         if dp_group := getattr(self, "dp_group", None):
             stateless_destroy_torch_distributed_process_group(dp_group)

-    def add_request(self, request: EngineCoreRequest):
-        if self.has_coordinator and request.current_wave != self.current_wave:
-            if request.current_wave > self.current_wave:
-                self.current_wave = request.current_wave
+    def add_request(self, request: Request, request_wave: int = 0):
+        if self.has_coordinator and request_wave != self.current_wave:
+            if request_wave > self.current_wave:
+                self.current_wave = request_wave
             elif not self.engines_running:
                 # Request received for an already-completed wave, notify
                 # front-end that we need to start the next one.
                 self.output_queue.put_nowait(
                     (-1, EngineCoreOutputs(start_wave=self.current_wave)))

-        super().add_request(request)
+        super().add_request(request, request_wave)

     def _handle_client_request(self, request_type: EngineCoreRequestType,
                                request: Any) -> None:
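
Finally, a simplified model of the wave bookkeeping above (illustrative only, not the actual DP coordinator protocol): a request tagged with a newer wave advances the engine's counter, while a request from an already-finished wave arriving when engines are idle makes the engine ask the front-end to start the next wave.

# Simplified model of the DP wave bookkeeping (not vLLM code).
class WaveTracker:
    def __init__(self):
        self.current_wave = 0
        self.engines_running = False

    def on_add(self, request_wave: int) -> str:
        if request_wave != self.current_wave:
            if request_wave > self.current_wave:
                self.current_wave = request_wave      # join the newer wave
            elif not self.engines_running:
                # Stale wave while idle: tell the front-end to start the
                # next wave instead of silently scheduling.
                return f"notify: start_wave={self.current_wave}"
        return "scheduled"

tracker = WaveTracker()
print(tracker.on_add(1))   # newer wave -> counter advances, "scheduled"
print(tracker.on_add(0))   # stale wave while idle -> notify front-end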