Support sequence parallelism combined with pipeline parallelism (#18243)

Signed-off-by: cascade812 <cascade812@outlook.com>
cascade
2025-05-17 15:47:25 -07:00
committed by GitHub
parent 66e63e86ec
commit 9ab2c02ff8
3 changed files with 74 additions and 27 deletions


@@ -1056,6 +1056,40 @@ class GPUModelRunner(LoRAModelRunnerMixin):
indices=out_indices,
)
def sync_and_slice_intermediate_tensors(
self, num_tokens: int, intermediate_tensors: IntermediateTensors,
sync_self: bool) -> IntermediateTensors:
assert self.intermediate_tensors is not None
tp = self.vllm_config.parallel_config.tensor_parallel_size
enabled_sp = self.vllm_config.compilation_config.pass_config. \
enable_sequence_parallelism
if enabled_sp:
# When sequence parallelism is enabled, we always pad num_tokens
# to be a multiple of tensor_parallel_size (tp) earlier
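            # (e.g. with tp=4, a 130-token batch has already been padded to
            # 132 tokens, so the assert below holds).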
assert num_tokens % tp == 0
is_residual_scattered = tp > 1 and enabled_sp \
and num_tokens % tp == 0
# When sequence parallelism is enabled, the "residual" tensor is sharded
# across tensor parallel ranks, so each rank only needs its own slice.
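        # Illustrative example: with tp=4 and num_tokens=128, "residual" is
        # copied and returned as v[:32] on each rank, while every other
        # tensor keeps the full v[:128] slice.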
if sync_self:
assert intermediate_tensors is not None
for k, v in intermediate_tensors.items():
                is_scattered = k == "residual" and is_residual_scattered
copy_len = num_tokens // tp if is_scattered else \
num_tokens
self.intermediate_tensors[k][:copy_len].copy_(
v[:copy_len], non_blocking=True)
return IntermediateTensors({
k:
v[:num_tokens // tp]
if k == "residual" and is_residual_scattered else v[:num_tokens]
for k, v in self.intermediate_tensors.items()
})
@torch.inference_mode()
def execute_model(
self,
@@ -1131,15 +1165,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if get_pp_group().is_first_rank:
intermediate_tensors = None
else:
assert intermediate_tensors is not None
assert self.intermediate_tensors is not None
for k, v in intermediate_tensors.items():
self.intermediate_tensors[k][:num_input_tokens].copy_(
v[:num_input_tokens], non_blocking=True)
intermediate_tensors = IntermediateTensors({
k: v[:num_input_tokens]
for k, v in self.intermediate_tensors.items()
})
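            # sync_self=True: copy the tensors received from the previous
            # pipeline-parallel rank into the persistent buffers before
            # slicing them.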
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_input_tokens, intermediate_tensors, True)
# Run the decoder.
# Use persistent buffers for CUDA graphs.
@@ -1658,10 +1685,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
batch_size=self.max_num_tokens,
dtype=self.model_config.dtype,
device=self.device))
intermediate_tensors = IntermediateTensors({
k: v[:num_tokens]
for k, v in self.intermediate_tensors.items()
})
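                # intermediate_tensors is None here (dummy run), so
                # sync_self=False only slices the persistent buffers without
                # copying from a previous pipeline-parallel rank.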
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_tokens, None, False)
with set_forward_context(attn_metadata,
self.vllm_config,