[BugFix] Fix multi-node offline data-parallel (#18981)

Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
Nick Hill
2025-05-31 08:34:52 -07:00
committed by GitHub
parent 8bf507d766
commit 9a1b9b99d7
2 changed files with 11 additions and 8 deletions

View File

@@ -97,10 +97,14 @@ def main(
# with DP, each rank should process different prompts.
# usually all the DP ranks process a full dataset,
# and each rank processes a different part of the dataset.
promts_per_rank = len(prompts) // dp_size
start = global_dp_rank * promts_per_rank
end = start + promts_per_rank
prompts = prompts[start:end]
floor = len(prompts) // dp_size
remainder = len(prompts) % dp_size
# Distribute prompts into even groups.
def start(rank):
return rank * floor + min(rank, remainder)
prompts = prompts[start(global_dp_rank) : start(global_dp_rank + 1)]
if len(prompts) == 0:
# if any rank has no prompts to process,
# we need to set a placeholder prompt