[Chore] Try remove init_cached_hf_modules (#31786)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-07 12:34:04 +08:00
committed by GitHub
parent 0a2c2dc3f1
commit aafd4d2354
9 changed files with 30 additions and 73 deletions

View File

@@ -29,11 +29,7 @@ class RunaiDummyExecutor(UniProcExecutor):
is_driver_worker=is_driver_worker,
)
wrapper_kwargs = {
"vllm_config": self.vllm_config,
}
self.driver_worker = WorkerWrapperBase(**wrapper_kwargs)
self.driver_worker = WorkerWrapperBase()
self.collective_rpc("init_worker", args=([worker_rpc_kwargs],))
self.collective_rpc("init_device")

View File

@@ -67,7 +67,7 @@ def assert_from_collective_rpc(engine: LLM, closure: Callable, closure_kwargs: d
class DummyExecutor(UniProcExecutor):
def _init_executor(self) -> None:
"""Initialize the worker and load the model."""
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0)
self.driver_worker = WorkerWrapperBase(rpc_rank=0)
distributed_init_method = get_distributed_init_method(get_ip(), get_open_port())
local_rank = 0
# set local rank as the device index if specified

View File

@@ -23,17 +23,6 @@ from vllm.logger import init_logger
logger = init_logger(__name__)
# TODO: This function can be removed if transformer_modules classes are
# serialized by value when communicating between processes
def init_cached_hf_modules() -> None:
"""
Lazy initialization of the Hugging Face modules.
"""
from transformers.dynamic_module_utils import init_hf_modules
init_hf_modules()
def import_pynvml():
"""
Historical comments:

View File

@@ -519,9 +519,7 @@ class WorkerProc:
shared_worker_lock: LockType,
):
self.rank = rank
wrapper = WorkerWrapperBase(
vllm_config=vllm_config, rpc_rank=local_rank, global_rank=rank
)
wrapper = WorkerWrapperBase(rpc_rank=local_rank, global_rank=rank)
# TODO: move `init_worker` to executor level as a collective rpc call
all_kwargs: list[dict] = [
{} for _ in range(vllm_config.parallel_config.world_size)

View File

@@ -208,9 +208,7 @@ class RayDistributedExecutor(Executor):
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote( # type: ignore[attr-defined]
vllm_config=self.vllm_config, rpc_rank=rank
)
)(RayWorkerWrapper).remote(rpc_rank=rank)
else:
worker = ray.remote(
num_cpus=0,
@@ -218,9 +216,8 @@ class RayDistributedExecutor(Executor):
resources={current_platform.ray_device_key: num_gpus},
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote( # type: ignore[attr-defined]
vllm_config=self.vllm_config, rpc_rank=rank
)
)(RayWorkerWrapper).remote(rpc_rank=rank)
worker_metadata.append(RayWorkerMetaData(worker=worker, created_rank=rank))
worker_ips = ray.get(

View File

@@ -26,7 +26,7 @@ logger = init_logger(__name__)
class UniProcExecutor(Executor):
def _init_executor(self) -> None:
"""Initialize the worker and load the model."""
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0)
self.driver_worker = WorkerWrapperBase(rpc_rank=0)
distributed_init_method, rank, local_rank = self._distributed_args()
kwargs = dict(
vllm_config=self.vllm_config,

View File

@@ -85,12 +85,6 @@ class Worker(WorkerBase):
precision = envs.VLLM_FLOAT32_MATMUL_PRECISION
torch.set_float32_matmul_precision(precision)
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils.import_utils import init_cached_hf_modules
init_cached_hf_modules()
# Buffers saved before sleep
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

View File

@@ -85,12 +85,6 @@ class TPUWorker:
else:
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[self.cache_config.cache_dtype]
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils.import_utils import init_cached_hf_modules
init_cached_hf_modules()
# Delay profiler initialization to the start of the profiling.
# This is because in vLLM V1, MP runtime is initialized before the
# TPU Worker is initialized. The profiler server needs to start after

View File

@@ -178,7 +178,6 @@ class WorkerWrapperBase:
def __init__(
self,
vllm_config: VllmConfig,
rpc_rank: int = 0,
global_rank: int | None = None,
) -> None:
@@ -194,21 +193,10 @@ class WorkerWrapperBase:
"""
self.rpc_rank = rpc_rank
self.global_rank = self.rpc_rank if global_rank is None else global_rank
self.worker: WorkerBase | None = None
# do not store this `vllm_config`, `init_worker` will set the final
# one.
# TODO: investigate if we can remove this field in `WorkerWrapperBase`,
# `init_cached_hf_modules` should be unnecessary now.
self.vllm_config: VllmConfig | None = None
# `model_config` can be None in tests
model_config = vllm_config.model_config
if model_config and model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils.import_utils import init_cached_hf_modules
init_cached_hf_modules()
# Initialized after init_worker is called
self.worker: WorkerBase
self.vllm_config: VllmConfig
def shutdown(self) -> None:
if self.worker is not None:
@@ -241,27 +229,34 @@ class WorkerWrapperBase:
Arguments are passed to the worker class constructor.
"""
kwargs = all_kwargs[self.rpc_rank]
self.vllm_config = kwargs.get("vllm_config")
assert self.vllm_config is not None, (
vllm_config: VllmConfig | None = kwargs.get("vllm_config")
assert vllm_config is not None, (
"vllm_config is required to initialize the worker"
)
self.vllm_config.enable_trace_function_call_for_thread()
self.vllm_config = vllm_config
vllm_config.enable_trace_function_call_for_thread()
from vllm.plugins import load_general_plugins
load_general_plugins()
if isinstance(self.vllm_config.parallel_config.worker_cls, str):
worker_class = resolve_obj_by_qualname(
self.vllm_config.parallel_config.worker_cls
parallel_config = vllm_config.parallel_config
if isinstance(parallel_config.worker_cls, str):
worker_class: type[WorkerBase] = resolve_obj_by_qualname(
parallel_config.worker_cls
)
else:
raise ValueError(
"passing worker_cls is no longer supported. Please pass keep the class in a separate module and pass the qualified name of the class as a string." # noqa: E501
"passing worker_cls is no longer supported. "
"Please pass keep the class in a separate module "
"and pass the qualified name of the class as a string."
)
if self.vllm_config.parallel_config.worker_extension_cls:
if parallel_config.worker_extension_cls:
worker_extension_cls = resolve_obj_by_qualname(
self.vllm_config.parallel_config.worker_extension_cls
parallel_config.worker_extension_cls
)
extended_calls = []
if worker_extension_cls not in worker_class.__bases__:
@@ -294,7 +289,7 @@ class WorkerWrapperBase:
"This argument is needed for mm_processor_cache_type='shm'."
)
mm_config = self.vllm_config.model_config.multimodal_config
mm_config = vllm_config.model_config.multimodal_config
if mm_config and mm_config.mm_processor_cache_type == "shm":
raise ValueError(msg)
else:
@@ -303,7 +298,7 @@ class WorkerWrapperBase:
self.mm_receiver_cache = None
else:
self.mm_receiver_cache = worker_receiver_cache_from_config(
self.vllm_config,
vllm_config,
MULTIMODAL_REGISTRY,
shared_worker_lock,
)
@@ -311,7 +306,6 @@ class WorkerWrapperBase:
with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during worker initialization
self.worker = worker_class(**kwargs)
assert self.worker is not None
def initialize_from_config(self, kv_cache_configs: list[Any]) -> None:
kv_cache_config = kv_cache_configs[self.global_rank]
@@ -358,20 +352,15 @@ class WorkerWrapperBase:
)
def execute_model(
self,
scheduler_output: SchedulerOutput,
*args,
**kwargs,
self, scheduler_output: SchedulerOutput
) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
self._apply_mm_cache(scheduler_output)
assert self.worker is not None
return self.worker.execute_model(scheduler_output, *args, **kwargs)
return self.worker.execute_model(scheduler_output)
def reset_mm_cache(self) -> None:
mm_receiver_cache = self.mm_receiver_cache
if mm_receiver_cache is not None:
mm_receiver_cache.clear_cache()
assert self.worker is not None
self.worker.reset_mm_cache()