Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00
Committed by: GitHub
Parent: 17edd8a807
Commit: d6953beb91
1508 changed files with 115244 additions and 94146 deletions


@@ -10,9 +10,9 @@ import torch.distributed as dist
from vllm.config import VllmConfig
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.uniproc_executor import ( # noqa
ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0)
from vllm.executor.uniproc_executor import ( # noqa
UniProcExecutor as UniProcExecutorV0)
ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0,
)
from vllm.executor.uniproc_executor import UniProcExecutor as UniProcExecutorV0 # noqa
from vllm.utils import resolve_obj_by_qualname
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
@@ -30,21 +30,24 @@ class Executor(ExecutorBase):
def get_class(vllm_config: VllmConfig) -> type["Executor"]:
executor_class: type[Executor]
parallel_config = vllm_config.parallel_config
distributed_executor_backend = (
parallel_config.distributed_executor_backend)
distributed_executor_backend = parallel_config.distributed_executor_backend
# distributed_executor_backend must be set in VllmConfig.__post_init__
if isinstance(distributed_executor_backend, type):
if not issubclass(distributed_executor_backend, ExecutorBase):
raise TypeError(
"distributed_executor_backend must be a subclass of "
f"ExecutorBase. Got {distributed_executor_backend}.")
f"ExecutorBase. Got {distributed_executor_backend}."
)
executor_class = distributed_executor_backend
elif distributed_executor_backend == "ray":
from vllm.v1.executor.ray_distributed_executor import ( # noqa
RayDistributedExecutor)
RayDistributedExecutor,
)
executor_class = RayDistributedExecutor
elif distributed_executor_backend == "mp":
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
executor_class = MultiprocExecutor
elif distributed_executor_backend == "uni":
executor_class = UniProcExecutor
@@ -53,25 +56,24 @@ class Executor(ExecutorBase):
# to support external launcher
executor_class = ExecutorWithExternalLauncher
elif isinstance(distributed_executor_backend, str):
executor_class = resolve_obj_by_qualname(
distributed_executor_backend)
executor_class = resolve_obj_by_qualname(distributed_executor_backend)
if not issubclass(executor_class, ExecutorBase):
raise TypeError(
"distributed_executor_backend must be a subclass of "
f"ExecutorBase. Got {executor_class}.")
f"ExecutorBase. Got {executor_class}."
)
else:
raise ValueError("Unknown distributed executor backend: "
f"{distributed_executor_backend}")
raise ValueError(
f"Unknown distributed executor backend: {distributed_executor_backend}"
)
return executor_class
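Editorial note: the branch above dispatches on distributed_executor_backend — an ExecutorBase subclass is used directly, "ray"/"mp"/"uni"/"external_launcher" map to built-in executors, and any other string is resolved as a dotted path via resolve_obj_by_qualname and then validated with issubclass. A stdlib-only approximation of that last resolution step (illustrative helper name and path, not vLLM's implementation):

    import importlib

    def resolve_by_qualname_sketch(qualname: str) -> type:
        # "mypkg.executors.MyExecutor" -> import "mypkg.executors", fetch "MyExecutor"
        module_name, _, attr = qualname.rpartition(".")
        obj = getattr(importlib.import_module(module_name), attr)
        if not isinstance(obj, type):
            raise TypeError(f"{qualname} does not name a class")
        return obj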
def initialize_from_config(self,
kv_cache_configs: list[KVCacheConfig]) -> None:
def initialize_from_config(self, kv_cache_configs: list[KVCacheConfig]) -> None:
"""
Initialize the KV caches and begin the model execution loop of the
underlying workers.
"""
self.collective_rpc("initialize_from_config",
args=(kv_cache_configs, ))
self.collective_rpc("initialize_from_config", args=(kv_cache_configs,))
self.collective_rpc("compile_or_warm_up_model")
def register_failure_callback(self, callback: FailureCallback):
@@ -87,12 +89,14 @@ class Executor(ExecutorBase):
def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
return self.collective_rpc("get_kv_cache_spec")
def collective_rpc(self,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict] = None,
non_block: bool = False) -> list[Any]:
def collective_rpc(
self,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict] = None,
non_block: bool = False,
) -> list[Any]:
raise NotImplementedError
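Editorial note: collective_rpc is the contract every concrete executor fills in — call `method` (a name or a callable) on every worker and return the per-worker results as a list. A toy in-process version, purely to illustrate the expected shape (the worker objects here are assumptions, not vLLM types):

    from typing import Any, Callable, Optional, Union

    def collective_rpc_sketch(
        workers,                      # any sequence of worker-like objects
        method: Union[str, Callable],
        args: tuple = (),
        kwargs: Optional[dict] = None,
    ) -> list[Any]:
        kwargs = kwargs or {}
        if isinstance(method, str):
            call = lambda w: getattr(w, method)(*args, **kwargs)
        else:
            call = lambda w: method(w, *args, **kwargs)
        return [call(w) for w in workers]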
def execute_model(
@@ -100,9 +104,9 @@ class Executor(ExecutorBase):
scheduler_output: SchedulerOutput,
non_block: bool = False,
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
output = self.collective_rpc("execute_model",
args=(scheduler_output, ),
non_block=non_block)
output = self.collective_rpc(
"execute_model", args=(scheduler_output,), non_block=non_block
)
return output[0]
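Editorial note: with non_block=True the RPC hands back Futures, so a caller of execute_model receives either a concrete ModelRunnerOutput or a concurrent.futures.Future resolving to one. A hedged caller-side sketch of handling that union:

    from concurrent.futures import Future

    def resolve_output(output):
        # Works for both the blocking and the non-blocking path above.
        return output.result() if isinstance(output, Future) else output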
def execute_dummy_batch(self) -> None:
@@ -117,7 +121,7 @@ class Executor(ExecutorBase):
return 1
def profile(self, is_start: bool = True):
self.collective_rpc("profile", args=(is_start, ))
self.collective_rpc("profile", args=(is_start,))
class UniProcExecutor(UniProcExecutorV0, Executor):
@@ -125,12 +129,12 @@ class UniProcExecutor(UniProcExecutorV0, Executor):
class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
def determine_available_memory(self) -> list[int]: # in bytes
# same as determine_num_available_blocks in v0,
# we need to get the min across all ranks.
memory = super().determine_available_memory()
from vllm.distributed.parallel_state import get_world_group
cpu_group = get_world_group().cpu_group
memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64)
dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN)


@@ -24,30 +24,36 @@ import torch
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel)
from vllm.distributed.device_communicators.shm_broadcast import (Handle,
MessageQueue)
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
get_pp_group, get_tp_group)
from vllm.distributed import destroy_distributed_environment, destroy_model_parallel
from vllm.distributed.device_communicators.shm_broadcast import Handle, MessageQueue
from vllm.distributed.parallel_state import (
get_dp_group,
get_ep_group,
get_pp_group,
get_tp_group,
)
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import worker_receiver_cache_from_config
from vllm.utils import (_maybe_force_spawn, decorate_logs,
get_distributed_init_method, get_loopback_ip,
get_mp_context, get_open_port, set_process_title)
from vllm.utils import (
_maybe_force_spawn,
decorate_logs,
get_distributed_init_method,
get_loopback_ip,
get_mp_context,
get_open_port,
set_process_title,
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.executor.abstract import Executor, FailureCallback
from vllm.v1.executor.utils import get_and_update_mm_cache
from vllm.v1.outputs import (AsyncModelRunnerOutput, DraftTokenIds,
ModelRunnerOutput)
from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
class MultiprocExecutor(Executor):
supports_pp: bool = True
def _init_executor(self) -> None:
@@ -65,7 +71,8 @@ class MultiprocExecutor(Executor):
assert self.world_size == tensor_parallel_size * pp_parallel_size, (
f"world_size ({self.world_size}) must be equal to the "
f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
f"_parallel_size ({pp_parallel_size}). ")
f"_parallel_size ({pp_parallel_size}). "
)
# Set multiprocessing envs
set_multiprocessing_worker_envs()
@@ -74,14 +81,15 @@ class MultiprocExecutor(Executor):
# Since it only works for single node, we can use the loopback address
# get_loopback_ip() for communication.
distributed_init_method = get_distributed_init_method(
get_loopback_ip(), get_open_port())
get_loopback_ip(), get_open_port()
)
# Initialize worker and set up message queues for SchedulerOutputs
# and ModelRunnerOutputs
max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
self.rpc_broadcast_mq = MessageQueue(self.world_size,
self.world_size,
max_chunk_bytes=max_chunk_bytes)
self.rpc_broadcast_mq = MessageQueue(
self.world_size, self.world_size, max_chunk_bytes=max_chunk_bytes
)
scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
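Editorial note: the queue layout being set up is one broadcast channel from the executor to all workers (rpc_broadcast_mq, whose shared-memory handle is exported so child processes can attach to it) plus, later, one response queue per worker. A rough stdlib-only analogy of that topology (not the shm_broadcast implementation):

    import multiprocessing as mp

    world_size = 4                                           # assumed for the sketch
    to_workers = [mp.Queue() for _ in range(world_size)]     # executor -> worker i
    from_workers = [mp.Queue() for _ in range(world_size)]   # worker i -> executor

    def broadcast(msg):
        for q in to_workers:
            q.put(msg)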
# Create workers
@@ -99,7 +107,8 @@ class MultiprocExecutor(Executor):
distributed_init_method=distributed_init_method,
input_shm_handle=scheduler_output_handle,
shared_worker_lock=shared_worker_lock,
))
)
)
# Workers must be created before wait_for_ready to avoid
# deadlock, since worker.init_device() does a device sync.
@@ -120,8 +129,7 @@ class MultiprocExecutor(Executor):
for uw in unready_workers:
if uw.death_writer is not None:
uw.death_writer.close()
self._ensure_worker_termination(
[uw.proc for uw in unready_workers])
self._ensure_worker_termination([uw.proc for uw in unready_workers])
# For pipeline parallel, we use a thread pool for asynchronous
# execute_model.
@@ -130,7 +138,8 @@ class MultiprocExecutor(Executor):
# from the response queue
# _async_aggregate_workers_output also assumes a single IO thread
self.io_thread_pool = ThreadPoolExecutor(
max_workers=1, thread_name_prefix="mp_exec_io")
max_workers=1, thread_name_prefix="mp_exec_io"
)
self.output_rank = self._get_output_rank()
self.has_connector = self.vllm_config.kv_transfer_config is not None
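Editorial note: as the comments above state, worker_response_mq must be consumed from a single thread, so a one-worker ThreadPoolExecutor is used — it serializes the dequeues while still handing the caller a Future for the non-blocking path. Minimal illustration of that pattern:

    from concurrent.futures import ThreadPoolExecutor

    pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mp_exec_io")
    fut = pool.submit(lambda: "worker response")   # stand-in for get_response(...)
    print(fut.result())                            # blocks only when the caller asks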
@@ -146,23 +155,22 @@ class MultiprocExecutor(Executor):
sentinels = [h.proc.sentinel for h in workers]
died = multiprocessing.connection.wait(sentinels)
_self = self_ref()
if not _self or getattr(_self, 'shutting_down', False):
if not _self or getattr(_self, "shutting_down", False):
return
_self.is_failed = True
proc_name = next(h.proc.name for h in workers
if h.proc.sentinel == died[0])
proc_name = next(h.proc.name for h in workers if h.proc.sentinel == died[0])
logger.error(
"Worker proc %s died unexpectedly, "
"shutting down executor.", proc_name)
"Worker proc %s died unexpectedly, shutting down executor.", proc_name
)
_self.shutdown()
callback = _self.failure_callback
if callback is not None:
_self.failure_callback = None
callback()
Thread(target=monitor_workers,
daemon=True,
name="MultiprocWorkerMonitor").start()
Thread(
target=monitor_workers, daemon=True, name="MultiprocWorkerMonitor"
).start()
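Editorial note: the monitor thread relies on process sentinels — multiprocessing.connection.wait blocks until at least one child's sentinel becomes ready, meaning that process has exited. A stripped-down version of the same idea (the process list and what to do afterwards are assumed):

    import multiprocessing
    import multiprocessing.connection

    def first_dead_worker(procs: list[multiprocessing.Process]) -> multiprocessing.Process:
        sentinels = [p.sentinel for p in procs]
        died = multiprocessing.connection.wait(sentinels)   # blocks until an exit
        return next(p for p in procs if p.sentinel == died[0])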
def register_failure_callback(self, callback: FailureCallback):
if self.is_failed:
@@ -175,47 +183,49 @@ class MultiprocExecutor(Executor):
scheduler_output: SchedulerOutput,
non_block: bool = False,
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
if not self.has_connector:
# get output only from a single worker (output_rank)
(output, ) = self.collective_rpc(
(output,) = self.collective_rpc(
"execute_model",
args=(scheduler_output, ),
args=(scheduler_output,),
unique_reply_rank=self.output_rank,
non_block=non_block,
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
)
return output
# get output from all workers
outputs = self.collective_rpc(
"execute_model",
args=(scheduler_output, ),
args=(scheduler_output,),
non_block=non_block,
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
)
# aggregate all workers output to a single output
if non_block:
return self.kv_output_aggregator.async_aggregate(
outputs, self.output_rank)
return self.kv_output_aggregator.async_aggregate(outputs, self.output_rank)
return self.kv_output_aggregator.aggregate(outputs, self.output_rank)
def execute_dummy_batch(self) -> None:
self.collective_rpc("execute_dummy_batch",
unique_reply_rank=self.output_rank)
self.collective_rpc("execute_dummy_batch", unique_reply_rank=self.output_rank)
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
# OPTIMIZATION: Get output only from a single worker (output_rank)
outputs = self.collective_rpc("take_draft_token_ids",
unique_reply_rank=self.output_rank)
outputs = self.collective_rpc(
"take_draft_token_ids", unique_reply_rank=self.output_rank
)
return outputs[0]
def collective_rpc(self,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict] = None,
non_block: bool = False,
unique_reply_rank: Optional[int] = None) -> list[Any]:
def collective_rpc(
self,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict] = None,
non_block: bool = False,
unique_reply_rank: Optional[int] = None,
) -> list[Any]:
if self.is_failed:
raise RuntimeError("Executor failed.")
@@ -230,42 +240,53 @@ class MultiprocExecutor(Executor):
send_method = method
else:
send_method = cloudpickle.dumps(
method, protocol=pickle.HIGHEST_PROTOCOL)
method, protocol=pickle.HIGHEST_PROTOCOL
)
self.rpc_broadcast_mq.enqueue(
(send_method, args, kwargs, unique_reply_rank))
(send_method, args, kwargs, unique_reply_rank)
)
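Editorial note: the reason cloudpickle appears here is that a method is shipped to the workers either as a plain string name or, when it is an arbitrary callable (including lambdas and closures that stdlib pickle rejects), as cloudpickle bytes the worker later loads and binds to itself. A tiny self-contained demonstration of that round trip:

    import pickle
    import cloudpickle

    fn = lambda worker: f"ran on {worker}"   # pickle.dumps(fn) would raise PicklingError
    payload = cloudpickle.dumps(fn, protocol=pickle.HIGHEST_PROTOCOL)
    print(cloudpickle.loads(payload)("worker-0"))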
workers = (self.workers[unique_reply_rank],
) if unique_reply_rank is not None else self.workers
workers = (
(self.workers[unique_reply_rank],)
if unique_reply_rank is not None
else self.workers
)
responses = []
def get_response(w: WorkerProcHandle,
dequeue_timeout: Optional[float] = None,
cancel_event: Optional[threading.Event] = None):
def get_response(
w: WorkerProcHandle,
dequeue_timeout: Optional[float] = None,
cancel_event: Optional[threading.Event] = None,
):
status, result = w.worker_response_mq.dequeue(
timeout=dequeue_timeout, cancel=cancel_event)
timeout=dequeue_timeout, cancel=cancel_event
)
if status != WorkerProc.ResponseStatus.SUCCESS:
raise RuntimeError(
f"Worker failed with error '{result}', please check the"
" stack trace above for the root cause")
" stack trace above for the root cause"
)
return result
for w in workers:
dequeue_timeout = None if deadline is None else (
deadline - time.monotonic())
dequeue_timeout = (
None if deadline is None else (deadline - time.monotonic())
)
if self.io_thread_pool is not None:
# We must consume worker_response_mq from a single thread.
result = self.io_thread_pool.submit( # type: ignore
get_response, w, dequeue_timeout, self.shutdown_event)
get_response, w, dequeue_timeout, self.shutdown_event
)
if not non_block:
result = result.result()
elif not non_block:
result = get_response(w, dequeue_timeout,
self.shutdown_event)
result = get_response(w, dequeue_timeout, self.shutdown_event)
else:
raise RuntimeError("non_block can only be used when"
" max_concurrent_batches > 1")
raise RuntimeError(
"non_block can only be used when max_concurrent_batches > 1"
)
responses.append(result)
return responses
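Editorial note: the timeout handling above uses one shared deadline rather than a per-worker timeout — the deadline is fixed once from the caller's timeout, and each subsequent dequeue only gets whatever time remains. Sketch of that computation in isolation:

    import time
    from typing import Optional

    def make_deadline(timeout: Optional[float]) -> Optional[float]:
        return None if timeout is None else time.monotonic() + timeout

    def remaining(deadline: Optional[float]) -> Optional[float]:
        # May go negative once the deadline has passed.
        return None if deadline is None else deadline - time.monotonic()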
@@ -302,11 +323,11 @@ class MultiprocExecutor(Executor):
def shutdown(self):
"""Properly shut down the executor and its workers"""
if not getattr(self, 'shutting_down', False):
if not getattr(self, "shutting_down", False):
self.shutting_down = True
# Make sure all the worker processes are terminated first.
if workers := getattr(self, 'workers', None):
if workers := getattr(self, "workers", None):
for w in workers:
# Close death_writer to signal child processes to exit
if w.death_writer is not None:
@@ -348,6 +369,7 @@ class MultiprocExecutor(Executor):
@dataclass
class UnreadyWorkerProcHandle:
"""WorkerProcess handle before READY."""
proc: BaseProcess
rank: int
ready_pipe: Connection
@@ -363,8 +385,8 @@ class WorkerProcHandle:
@classmethod
def from_unready_handle(
cls, unready_handle: UnreadyWorkerProcHandle,
worker_response_mq: MessageQueue) -> "WorkerProcHandle":
cls, unready_handle: UnreadyWorkerProcHandle, worker_response_mq: MessageQueue
) -> "WorkerProcHandle":
return cls(
proc=unready_handle.proc,
rank=unready_handle.rank,
@@ -393,8 +415,7 @@ class WorkerProc:
all_kwargs: list[dict] = [
{} for _ in range(vllm_config.parallel_config.world_size)
]
is_driver_worker = (
rank % vllm_config.parallel_config.tensor_parallel_size == 0)
is_driver_worker = rank % vllm_config.parallel_config.tensor_parallel_size == 0
all_kwargs[rank] = {
"vllm_config": vllm_config,
"local_rank": local_rank,
@@ -407,7 +428,8 @@ class WorkerProc:
# Initialize MessageQueue for receiving SchedulerOutput
self.rpc_broadcast_mq = MessageQueue.create_from_handle(
input_shm_handle, self.worker.rank)
input_shm_handle, self.worker.rank
)
# Initializes a message queue for sending the model output
self.worker_response_mq = MessageQueue(1, 1)
@@ -419,19 +441,22 @@ class WorkerProc:
self.async_output_copy_thread = Thread(
target=self.async_output_busy_loop,
daemon=True,
name="WorkerAsyncOutputCopy")
name="WorkerAsyncOutputCopy",
)
self.async_output_copy_thread.start()
# Initialize multimodal receiver cache if needed
self.mm_receiver_cache = worker_receiver_cache_from_config(
vllm_config, MULTIMODAL_REGISTRY, shared_worker_lock)
vllm_config, MULTIMODAL_REGISTRY, shared_worker_lock
)
# Initialize device
self.worker.init_device()
# Set process title and log prefix
self.setup_proc_title_and_log_prefix(
enable_ep=vllm_config.parallel_config.enable_expert_parallel)
enable_ep=vllm_config.parallel_config.enable_expert_parallel
)
# Load model
self.worker.load_model()
@@ -463,10 +488,12 @@ class WorkerProc:
"shared_worker_lock": shared_worker_lock,
}
# Run EngineCore busy loop in background process.
proc = context.Process(target=WorkerProc.worker_main,
kwargs=process_kwargs,
name=f"VllmWorker-{rank}",
daemon=True)
proc = context.Process(
target=WorkerProc.worker_main,
kwargs=process_kwargs,
name=f"VllmWorker-{rank}",
daemon=True,
)
proc.start()
writer.close()
@@ -476,16 +503,18 @@ class WorkerProc:
@staticmethod
def wait_for_ready(
unready_proc_handles: list[UnreadyWorkerProcHandle]
unready_proc_handles: list[UnreadyWorkerProcHandle],
) -> list[WorkerProcHandle]:
e = Exception("WorkerProc initialization failed due to "
"an exception in a background process. "
"See stack trace for root cause.")
e = Exception(
"WorkerProc initialization failed due to "
"an exception in a background process. "
"See stack trace for root cause."
)
pipes = {handle.ready_pipe: handle for handle in unready_proc_handles}
ready_proc_handles: list[Optional[WorkerProcHandle]] = (
[None] * len(unready_proc_handles))
ready_proc_handles: list[Optional[WorkerProcHandle]] = [None] * len(
unready_proc_handles
)
while pipes:
ready = multiprocessing.connection.wait(pipes.keys())
for pipe in ready:
@@ -499,10 +528,13 @@ class WorkerProc:
# Extract the message queue handle.
worker_response_mq = MessageQueue.create_from_handle(
response["handle"], 0)
response["handle"], 0
)
ready_proc_handles[unready_proc_handle.rank] = (
WorkerProcHandle.from_unready_handle(
unready_proc_handle, worker_response_mq))
unready_proc_handle, worker_response_mq
)
)
except EOFError:
e.__suppress_context__ = True
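Editorial note: wait_for_ready is the parent half of a readiness handshake — each worker sends a single dict over its ready pipe once initialization finishes, and the parent waits on all read ends at once. A reduced sketch of the same loop (message shape and names assumed, error handling elided):

    import multiprocessing.connection

    def wait_all_ready(ready_pipes):            # read ends of per-worker Pipes
        pending = set(ready_pipes)
        while pending:
            for pipe in multiprocessing.connection.wait(list(pending)):
                msg = pipe.recv()               # e.g. {"status": "READY", "handle": ...}
                if msg["status"] != "READY":
                    raise RuntimeError("worker failed during initialization")
                pending.discard(pipe)
                pipe.close()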
@@ -523,8 +555,8 @@ class WorkerProc:
@staticmethod
def worker_main(*args, **kwargs):
""" Worker initialization and execution loops.
This runs a background process """
"""Worker initialization and execution loops.
This runs a background process"""
# Signal handler used for graceful termination.
# SystemExit exception is only raised once to allow this and worker
@@ -561,9 +593,9 @@ class WorkerProc:
except Exception as e:
logger.warning("Death monitoring error: %s", e)
death_monitor = Thread(target=monitor_parent_death,
daemon=True,
name="WorkerDeathMonitor")
death_monitor = Thread(
target=monitor_parent_death, daemon=True, name="WorkerDeathMonitor"
)
death_monitor.start()
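Editorial note: monitor_parent_death pairs with the death pipe created by the parent — the worker blocks on the read end and treats EOF (the parent exiting or explicitly closing its write end) as the signal to shut down. A self-contained sketch of that mechanism, using a spawn context so the child only holds the read end (names are illustrative):

    import multiprocessing as mp

    def _watch_parent(death_reader):
        try:
            death_reader.recv()                 # nothing is ever sent on this pipe
        except EOFError:
            print("parent is gone, shutting down")

    if __name__ == "__main__":
        ctx = mp.get_context("spawn")
        death_reader, death_writer = ctx.Pipe(duplex=False)
        proc = ctx.Process(target=_watch_parent, args=(death_reader,), daemon=True)
        proc.start()
        death_reader.close()                    # parent keeps only the write end
        death_writer.close()                    # simulates parent exit for the demo
        proc.join()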
try:
@@ -571,12 +603,12 @@ class WorkerProc:
worker = WorkerProc(*args, **kwargs)
# Send READY once we know everything is loaded
ready_writer.send({
"status":
WorkerProc.READY_STR,
"handle":
worker.worker_response_mq.export_handle(),
})
ready_writer.send(
{
"status": WorkerProc.READY_STR,
"handle": worker.worker_response_mq.export_handle(),
}
)
# Ensure message queues are ready. Will deadlock if re-ordered.
# Must be kept consistent with the Executor
@@ -653,15 +685,18 @@ class WorkerProc:
"""Main busy loop for Multiprocessing Workers"""
while True:
method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue(
cancel=cancel, indefinite=True)
cancel=cancel, indefinite=True
)
try:
if isinstance(method, str):
func = getattr(self.worker, method)
elif isinstance(method, bytes):
func = partial(cloudpickle.loads(method), self.worker)
# retrieve from shm cache if available
if self.mm_receiver_cache is not None \
and func.__name__ == "execute_model":
if (
self.mm_receiver_cache is not None
and func.__name__ == "execute_model"
):
get_and_update_mm_cache(self.mm_receiver_cache, args)
output = func(*args, **kwargs)
except Exception as e:
@@ -701,7 +736,7 @@ class WorkerProc:
def set_multiprocessing_worker_envs():
""" Set up environment variables that should be used when there are workers
"""Set up environment variables that should be used when there are workers
in a multiprocessing environment. This should be called by the parent
process before worker processes are created"""
@@ -714,13 +749,16 @@ def set_multiprocessing_worker_envs():
# impact on performance. The contention is amplified when running in a
# container where CPU limits can cause throttling.
default_omp_num_threads = 1
if "OMP_NUM_THREADS" not in os.environ and (
current_parallelism :=
torch.get_num_threads()) > default_omp_num_threads:
if (
"OMP_NUM_THREADS" not in os.environ
and (current_parallelism := torch.get_num_threads()) > default_omp_num_threads
):
logger.warning(
"Reducing Torch parallelism from %d threads to %d to avoid "
"unnecessary CPU contention. Set OMP_NUM_THREADS in the "
"external environment to tune this value as needed.",
current_parallelism, default_omp_num_threads)
current_parallelism,
default_omp_num_threads,
)
os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads)
torch.set_num_threads(default_omp_num_threads)
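Editorial note: the guard above only acts when OMP_NUM_THREADS is unset, and the walrus expression captures the current Torch thread count inside the condition so the warning can report the value being overridden. The behavior boils down to the following (illustrative summary, not the vLLM function):

    import os
    import torch

    def clamp_torch_threads(limit: int = 1) -> None:
        if "OMP_NUM_THREADS" in os.environ:
            return                              # user already chose a value; respect it
        if (current := torch.get_num_threads()) > limit:
            print(f"Reducing Torch parallelism from {current} to {limit}")
            os.environ["OMP_NUM_THREADS"] = str(limit)
            torch.set_num_threads(limit)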


@@ -6,7 +6,8 @@ from typing import Optional, Union
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.executor.ray_distributed_executor import ( # noqa
RayDistributedExecutor as RayDistributedExecutorV0)
RayDistributedExecutor as RayDistributedExecutorV0,
)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
@@ -18,10 +19,10 @@ logger = init_logger(__name__)
class FutureWrapper(Future):
"""A wrapper around Ray output reference to meet the interface
of .execute_model(): The top level (core busy loop) expects .result() api
of .execute_model(): The top level (core busy loop) expects .result() api
to block and return a single output.
If aggregator is provided, the outputs from all workers are aggregated upon
If aggregator is provided, the outputs from all workers are aggregated upon
the result() call. If not only the first worker's output is returned.
"""
@@ -101,8 +102,11 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
return FutureWrapper(refs, self.kv_output_aggregator)
def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest) -> None:
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
self._run_workers("reinitialize_distributed", reconfig_request)
if reconfig_request.new_data_parallel_rank == \
ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
if (
reconfig_request.new_data_parallel_rank
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
):
self.shutdown()


@@ -20,4 +20,5 @@ def get_and_update_mm_cache(
scheduler_output = args[0]
for request_data in scheduler_output.scheduled_new_reqs:
request_data.mm_features = receiver_cache.get_and_update_features(
request_data.mm_features)
request_data.mm_features
)
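Editorial note: get_and_update_mm_cache rewrites each new request's mm_features through the worker-side receiver cache, so multimodal payloads that were already shipped once can be recovered from the shared cache instead of being re-sent. A generic sketch of a get-or-store cache of that flavor (the real vLLM cache API and semantics may differ):

    class ReceiverCacheSketch:
        """Toy stand-in: store payloads by key, return the cached copy on a hit."""

        def __init__(self):
            self._store: dict[str, object] = {}

        def get_and_update(self, key: str, payload=None):
            if payload is not None:
                self._store[key] = payload      # full payload arrived: remember it
            return self._store.get(key)         # otherwise serve the cached copy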