[BugFix] Improve internal DP load balancing (#21617)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -928,7 +928,7 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
):
|
||||
# Counts forward-passes of the model so that we can synchronize
|
||||
# finished with DP peers every N steps.
|
||||
self.counter = 0
|
||||
self.step_counter = 0
|
||||
self.current_wave = 0
|
||||
self.last_counts = (0, 0)
|
||||
|
||||
@@ -999,7 +999,9 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
counts = self.scheduler.get_request_counts()
|
||||
if counts != self.last_counts:
|
||||
self.last_counts = counts
|
||||
stats = SchedulerStats(*counts)
|
||||
stats = SchedulerStats(*counts,
|
||||
step_counter=self.step_counter,
|
||||
current_wave=self.current_wave)
|
||||
self.output_queue.put_nowait(
|
||||
(-1, EngineCoreOutputs(scheduler_stats=stats)))
|
||||
|
||||
@@ -1041,15 +1043,16 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
self.output_queue.put_nowait(
|
||||
(client_index,
|
||||
EngineCoreOutputs(wave_complete=self.current_wave)))
|
||||
# Increment wave count and reset step counter.
|
||||
self.current_wave += 1
|
||||
self.step_counter = 0
|
||||
|
||||
def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
|
||||
|
||||
# Optimization - only perform finish-sync all-reduce every 32 steps.
|
||||
self.counter += 1
|
||||
if self.counter != 32:
|
||||
self.step_counter += 1
|
||||
if self.step_counter % 32 != 0:
|
||||
return True
|
||||
self.counter = 0
|
||||
|
||||
return ParallelConfig.has_unfinished_dp(self.dp_group,
|
||||
local_unfinished)
|
||||
|
||||
Reference in New Issue
Block a user