fix RAM OOM when load large models in tensor parallel mode. (#1395)
Co-authored-by: ran_lin <rlin@thoughtworks.com>
This commit is contained in:
@@ -143,6 +143,12 @@ class LLMEngine:
|
||||
"init_model",
|
||||
get_all_outputs=True,
|
||||
)
|
||||
self._run_workers(
|
||||
"load_model",
|
||||
get_all_outputs=True,
|
||||
max_concurrent_workers=self.parallel_config.
|
||||
max_parallel_loading_workers,
|
||||
)
|
||||
|
||||
def _init_workers_ray(self, placement_group: "PlacementGroup",
|
||||
**ray_remote_kwargs):
|
||||
@@ -182,6 +188,12 @@ class LLMEngine:
|
||||
"init_model",
|
||||
get_all_outputs=True,
|
||||
)
|
||||
self._run_workers(
|
||||
"load_model",
|
||||
get_all_outputs=True,
|
||||
max_concurrent_workers=self.parallel_config.
|
||||
max_parallel_loading_workers,
|
||||
)
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
self.model_config.verify_with_parallel_config(self.parallel_config)
|
||||
@@ -682,16 +694,15 @@ class LLMEngine:
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
return
|
||||
|
||||
def _run_workers(
|
||||
def _run_workers_in_batch(
|
||||
self,
|
||||
workers,
|
||||
method: str,
|
||||
*args,
|
||||
get_all_outputs: bool = False,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Runs the given method on all workers."""
|
||||
):
|
||||
all_outputs = []
|
||||
for worker in self.workers:
|
||||
for worker in workers:
|
||||
if self.parallel_config.worker_use_ray:
|
||||
executor = partial(worker.execute_method.remote, method)
|
||||
else:
|
||||
@@ -699,9 +710,31 @@ class LLMEngine:
|
||||
|
||||
output = executor(*args, **kwargs)
|
||||
all_outputs.append(output)
|
||||
|
||||
if self.parallel_config.worker_use_ray:
|
||||
all_outputs = ray.get(all_outputs)
|
||||
return all_outputs
|
||||
|
||||
def _run_workers(
|
||||
self,
|
||||
method: str,
|
||||
*args,
|
||||
get_all_outputs: bool = False,
|
||||
max_concurrent_workers: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Runs the given method on all workers."""
|
||||
all_outputs = []
|
||||
if max_concurrent_workers:
|
||||
work_groups = [
|
||||
self.workers[i:i + max_concurrent_workers]
|
||||
for i in range(0, len(self.workers), max_concurrent_workers)
|
||||
]
|
||||
else:
|
||||
work_groups = [self.workers]
|
||||
|
||||
for workers in work_groups:
|
||||
all_outputs.extend(
|
||||
self._run_workers_in_batch(workers, method, *args, **kwargs))
|
||||
|
||||
if get_all_outputs:
|
||||
return all_outputs
|
||||
|
||||
Reference in New Issue
Block a user