Fix RAM OOM when loading large models in tensor parallel mode. (#1395)

Co-authored-by: ran_lin <rlin@thoughtworks.com>
This commit is contained in:
boydfd
2023-11-21 11:02:42 +08:00
committed by GitHub
parent 819b18e7ba
commit 4bb6b67188
4 changed files with 52 additions and 7 deletions

View File

@@ -143,6 +143,12 @@ class LLMEngine:
"init_model",
get_all_outputs=True,
)
self._run_workers(
"load_model",
get_all_outputs=True,
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers,
)
def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
@@ -182,6 +188,12 @@ class LLMEngine:
"init_model",
get_all_outputs=True,
)
self._run_workers(
"load_model",
get_all_outputs=True,
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers,
)
def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
@@ -682,16 +694,15 @@ class LLMEngine:
seq.status = SequenceStatus.FINISHED_STOPPED
return
def _run_workers(
def _run_workers_in_batch(
self,
workers,
method: str,
*args,
get_all_outputs: bool = False,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
):
all_outputs = []
for worker in self.workers:
for worker in workers:
if self.parallel_config.worker_use_ray:
executor = partial(worker.execute_method.remote, method)
else:
@@ -699,9 +710,31 @@ class LLMEngine:
output = executor(*args, **kwargs)
all_outputs.append(output)
if self.parallel_config.worker_use_ray:
all_outputs = ray.get(all_outputs)
return all_outputs
def _run_workers(
self,
method: str,
*args,
get_all_outputs: bool = False,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
all_outputs = []
if max_concurrent_workers:
work_groups = [
self.workers[i:i + max_concurrent_workers]
for i in range(0, len(self.workers), max_concurrent_workers)
]
else:
work_groups = [self.workers]
for workers in work_groups:
all_outputs.extend(
self._run_workers_in_batch(workers, method, *args, **kwargs))
if get_all_outputs:
return all_outputs