[Chore] Try remove init_cached_hf_modules (#31786)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -29,11 +29,7 @@ class RunaiDummyExecutor(UniProcExecutor):
|
||||
is_driver_worker=is_driver_worker,
|
||||
)
|
||||
|
||||
wrapper_kwargs = {
|
||||
"vllm_config": self.vllm_config,
|
||||
}
|
||||
|
||||
self.driver_worker = WorkerWrapperBase(**wrapper_kwargs)
|
||||
self.driver_worker = WorkerWrapperBase()
|
||||
|
||||
self.collective_rpc("init_worker", args=([worker_rpc_kwargs],))
|
||||
self.collective_rpc("init_device")
|
||||
|
||||
@@ -67,7 +67,7 @@ def assert_from_collective_rpc(engine: LLM, closure: Callable, closure_kwargs: d
|
||||
class DummyExecutor(UniProcExecutor):
|
||||
def _init_executor(self) -> None:
|
||||
"""Initialize the worker and load the model."""
|
||||
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0)
|
||||
self.driver_worker = WorkerWrapperBase(rpc_rank=0)
|
||||
distributed_init_method = get_distributed_init_method(get_ip(), get_open_port())
|
||||
local_rank = 0
|
||||
# set local rank as the device index if specified
|
||||
|
||||
@@ -23,17 +23,6 @@ from vllm.logger import init_logger
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# TODO: This function can be removed if transformer_modules classes are
|
||||
# serialized by value when communicating between processes
|
||||
def init_cached_hf_modules() -> None:
|
||||
"""
|
||||
Lazy initialization of the Hugging Face modules.
|
||||
"""
|
||||
from transformers.dynamic_module_utils import init_hf_modules
|
||||
|
||||
init_hf_modules()
|
||||
|
||||
|
||||
def import_pynvml():
|
||||
"""
|
||||
Historical comments:
|
||||
|
||||
@@ -519,9 +519,7 @@ class WorkerProc:
|
||||
shared_worker_lock: LockType,
|
||||
):
|
||||
self.rank = rank
|
||||
wrapper = WorkerWrapperBase(
|
||||
vllm_config=vllm_config, rpc_rank=local_rank, global_rank=rank
|
||||
)
|
||||
wrapper = WorkerWrapperBase(rpc_rank=local_rank, global_rank=rank)
|
||||
# TODO: move `init_worker` to executor level as a collective rpc call
|
||||
all_kwargs: list[dict] = [
|
||||
{} for _ in range(vllm_config.parallel_config.world_size)
|
||||
|
||||
@@ -208,9 +208,7 @@ class RayDistributedExecutor(Executor):
|
||||
num_gpus=num_gpus,
|
||||
scheduling_strategy=scheduling_strategy,
|
||||
**ray_remote_kwargs,
|
||||
)(RayWorkerWrapper).remote( # type: ignore[attr-defined]
|
||||
vllm_config=self.vllm_config, rpc_rank=rank
|
||||
)
|
||||
)(RayWorkerWrapper).remote(rpc_rank=rank)
|
||||
else:
|
||||
worker = ray.remote(
|
||||
num_cpus=0,
|
||||
@@ -218,9 +216,8 @@ class RayDistributedExecutor(Executor):
|
||||
resources={current_platform.ray_device_key: num_gpus},
|
||||
scheduling_strategy=scheduling_strategy,
|
||||
**ray_remote_kwargs,
|
||||
)(RayWorkerWrapper).remote( # type: ignore[attr-defined]
|
||||
vllm_config=self.vllm_config, rpc_rank=rank
|
||||
)
|
||||
)(RayWorkerWrapper).remote(rpc_rank=rank)
|
||||
|
||||
worker_metadata.append(RayWorkerMetaData(worker=worker, created_rank=rank))
|
||||
|
||||
worker_ips = ray.get(
|
||||
|
||||
@@ -26,7 +26,7 @@ logger = init_logger(__name__)
|
||||
class UniProcExecutor(Executor):
|
||||
def _init_executor(self) -> None:
|
||||
"""Initialize the worker and load the model."""
|
||||
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0)
|
||||
self.driver_worker = WorkerWrapperBase(rpc_rank=0)
|
||||
distributed_init_method, rank, local_rank = self._distributed_args()
|
||||
kwargs = dict(
|
||||
vllm_config=self.vllm_config,
|
||||
|
||||
@@ -85,12 +85,6 @@ class Worker(WorkerBase):
|
||||
precision = envs.VLLM_FLOAT32_MATMUL_PRECISION
|
||||
torch.set_float32_matmul_precision(precision)
|
||||
|
||||
if self.model_config.trust_remote_code:
|
||||
# note: lazy import to avoid importing torch before initializing
|
||||
from vllm.utils.import_utils import init_cached_hf_modules
|
||||
|
||||
init_cached_hf_modules()
|
||||
|
||||
# Buffers saved before sleep
|
||||
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
|
||||
|
||||
|
||||
@@ -85,12 +85,6 @@ class TPUWorker:
|
||||
else:
|
||||
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[self.cache_config.cache_dtype]
|
||||
|
||||
if self.model_config.trust_remote_code:
|
||||
# note: lazy import to avoid importing torch before initializing
|
||||
from vllm.utils.import_utils import init_cached_hf_modules
|
||||
|
||||
init_cached_hf_modules()
|
||||
|
||||
# Delay profiler initialization to the start of the profiling.
|
||||
# This is because in vLLM V1, MP runtime is initialized before the
|
||||
# TPU Worker is initialized. The profiler server needs to start after
|
||||
|
||||
@@ -178,7 +178,6 @@ class WorkerWrapperBase:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
rpc_rank: int = 0,
|
||||
global_rank: int | None = None,
|
||||
) -> None:
|
||||
@@ -194,21 +193,10 @@ class WorkerWrapperBase:
|
||||
"""
|
||||
self.rpc_rank = rpc_rank
|
||||
self.global_rank = self.rpc_rank if global_rank is None else global_rank
|
||||
self.worker: WorkerBase | None = None
|
||||
|
||||
# do not store this `vllm_config`, `init_worker` will set the final
|
||||
# one.
|
||||
# TODO: investigate if we can remove this field in `WorkerWrapperBase`,
|
||||
# `init_cached_hf_modules` should be unnecessary now.
|
||||
self.vllm_config: VllmConfig | None = None
|
||||
|
||||
# `model_config` can be None in tests
|
||||
model_config = vllm_config.model_config
|
||||
if model_config and model_config.trust_remote_code:
|
||||
# note: lazy import to avoid importing torch before initializing
|
||||
from vllm.utils.import_utils import init_cached_hf_modules
|
||||
|
||||
init_cached_hf_modules()
|
||||
# Initialized after init_worker is called
|
||||
self.worker: WorkerBase
|
||||
self.vllm_config: VllmConfig
|
||||
|
||||
def shutdown(self) -> None:
|
||||
if self.worker is not None:
|
||||
@@ -241,27 +229,34 @@ class WorkerWrapperBase:
|
||||
Arguments are passed to the worker class constructor.
|
||||
"""
|
||||
kwargs = all_kwargs[self.rpc_rank]
|
||||
self.vllm_config = kwargs.get("vllm_config")
|
||||
assert self.vllm_config is not None, (
|
||||
|
||||
vllm_config: VllmConfig | None = kwargs.get("vllm_config")
|
||||
assert vllm_config is not None, (
|
||||
"vllm_config is required to initialize the worker"
|
||||
)
|
||||
self.vllm_config.enable_trace_function_call_for_thread()
|
||||
self.vllm_config = vllm_config
|
||||
|
||||
vllm_config.enable_trace_function_call_for_thread()
|
||||
|
||||
from vllm.plugins import load_general_plugins
|
||||
|
||||
load_general_plugins()
|
||||
|
||||
if isinstance(self.vllm_config.parallel_config.worker_cls, str):
|
||||
worker_class = resolve_obj_by_qualname(
|
||||
self.vllm_config.parallel_config.worker_cls
|
||||
parallel_config = vllm_config.parallel_config
|
||||
if isinstance(parallel_config.worker_cls, str):
|
||||
worker_class: type[WorkerBase] = resolve_obj_by_qualname(
|
||||
parallel_config.worker_cls
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"passing worker_cls is no longer supported. Please pass keep the class in a separate module and pass the qualified name of the class as a string." # noqa: E501
|
||||
"passing worker_cls is no longer supported. "
|
||||
"Please pass keep the class in a separate module "
|
||||
"and pass the qualified name of the class as a string."
|
||||
)
|
||||
if self.vllm_config.parallel_config.worker_extension_cls:
|
||||
|
||||
if parallel_config.worker_extension_cls:
|
||||
worker_extension_cls = resolve_obj_by_qualname(
|
||||
self.vllm_config.parallel_config.worker_extension_cls
|
||||
parallel_config.worker_extension_cls
|
||||
)
|
||||
extended_calls = []
|
||||
if worker_extension_cls not in worker_class.__bases__:
|
||||
@@ -294,7 +289,7 @@ class WorkerWrapperBase:
|
||||
"This argument is needed for mm_processor_cache_type='shm'."
|
||||
)
|
||||
|
||||
mm_config = self.vllm_config.model_config.multimodal_config
|
||||
mm_config = vllm_config.model_config.multimodal_config
|
||||
if mm_config and mm_config.mm_processor_cache_type == "shm":
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
@@ -303,7 +298,7 @@ class WorkerWrapperBase:
|
||||
self.mm_receiver_cache = None
|
||||
else:
|
||||
self.mm_receiver_cache = worker_receiver_cache_from_config(
|
||||
self.vllm_config,
|
||||
vllm_config,
|
||||
MULTIMODAL_REGISTRY,
|
||||
shared_worker_lock,
|
||||
)
|
||||
@@ -311,7 +306,6 @@ class WorkerWrapperBase:
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
# To make vLLM config available during worker initialization
|
||||
self.worker = worker_class(**kwargs)
|
||||
assert self.worker is not None
|
||||
|
||||
def initialize_from_config(self, kv_cache_configs: list[Any]) -> None:
|
||||
kv_cache_config = kv_cache_configs[self.global_rank]
|
||||
@@ -358,20 +352,15 @@ class WorkerWrapperBase:
|
||||
)
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
*args,
|
||||
**kwargs,
|
||||
self, scheduler_output: SchedulerOutput
|
||||
) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
|
||||
self._apply_mm_cache(scheduler_output)
|
||||
|
||||
assert self.worker is not None
|
||||
return self.worker.execute_model(scheduler_output, *args, **kwargs)
|
||||
return self.worker.execute_model(scheduler_output)
|
||||
|
||||
def reset_mm_cache(self) -> None:
|
||||
mm_receiver_cache = self.mm_receiver_cache
|
||||
if mm_receiver_cache is not None:
|
||||
mm_receiver_cache.clear_cache()
|
||||
|
||||
assert self.worker is not None
|
||||
self.worker.reset_mm_cache()
|
||||
|
||||
Reference in New Issue
Block a user