Support sequence parallelism combined with pipeline parallelism (#18243)
Signed-off-by: cascade812 <cascade812@outlook.com>
@@ -1056,6 +1056,40 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             indices=out_indices,
         )
 
+    def sync_and_slice_intermediate_tensors(
+            self, num_tokens: int, intermediate_tensors: IntermediateTensors,
+            sync_self: bool) -> IntermediateTensors:
+
+        assert self.intermediate_tensors is not None
+
+        tp = self.vllm_config.parallel_config.tensor_parallel_size
+        enabled_sp = self.vllm_config.compilation_config.pass_config. \
+            enable_sequence_parallelism
+        if enabled_sp:
+            # When sequence parallelism is enabled, we always pad num_tokens
+            # to be a multiple of tensor_parallel_size (tp) earlier.
+            assert num_tokens % tp == 0
+        is_residual_scattered = tp > 1 and enabled_sp \
+            and num_tokens % tp == 0
+
+        # When sequence parallelism is enabled, the "residual" tensor is
+        # sharded across tensor parallel ranks, so each rank only needs
+        # its own slice.
+        if sync_self:
+            assert intermediate_tensors is not None
+            for k, v in intermediate_tensors.items():
is_scattered = "residual" and is_residual_scattered
+                copy_len = num_tokens // tp if is_scattered else \
+                    num_tokens
+                self.intermediate_tensors[k][:copy_len].copy_(
+                    v[:copy_len], non_blocking=True)
+
+        return IntermediateTensors({
+            k:
+            v[:num_tokens // tp]
+            if k == "residual" and is_residual_scattered else v[:num_tokens]
+            for k, v in self.intermediate_tensors.items()
+        })
+
     @torch.inference_mode()
     def execute_model(
         self,
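The helper added above centralizes the sequence-parallel slicing rule: when sequence parallelism is on, num_tokens is padded to a multiple of the tensor-parallel size, and only the "residual" tensor is scattered across TP ranks, so it is sliced to num_tokens // tp while every other intermediate keeps all num_tokens rows. A minimal standalone sketch of that rule, with a plain dict of tensors standing in for IntermediateTensors; the name slice_intermediates and the 64x8 buffer shapes are hypothetical, not vLLM code:

import torch

# Standalone sketch of the slicing rule above; buffer names and shapes
# here are illustrative, not vLLM's actual layout.
def slice_intermediates(buffers: dict[str, torch.Tensor], num_tokens: int,
                        tp: int, enabled_sp: bool) -> dict[str, torch.Tensor]:
    # "residual" is scattered across TP ranks only when SP is enabled,
    # TP > 1, and num_tokens is already padded to a multiple of tp.
    is_residual_scattered = tp > 1 and enabled_sp and num_tokens % tp == 0
    return {
        k: v[:num_tokens // tp]
        if k == "residual" and is_residual_scattered else v[:num_tokens]
        for k, v in buffers.items()
    }

buffers = {
    "hidden_states": torch.zeros(64, 8),
    "residual": torch.zeros(64, 8),
}
out = slice_intermediates(buffers, num_tokens=16, tp=4, enabled_sp=True)
assert out["hidden_states"].shape[0] == 16  # full (padded) token count
assert out["residual"].shape[0] == 4        # 16 // 4: this rank's shard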
@@ -1131,15 +1165,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if get_pp_group().is_first_rank:
             intermediate_tensors = None
         else:
-            assert intermediate_tensors is not None
-            assert self.intermediate_tensors is not None
-            for k, v in intermediate_tensors.items():
-                self.intermediate_tensors[k][:num_input_tokens].copy_(
-                    v[:num_input_tokens], non_blocking=True)
-            intermediate_tensors = IntermediateTensors({
-                k: v[:num_input_tokens]
-                for k, v in self.intermediate_tensors.items()
-            })
+            intermediate_tensors = self.sync_and_slice_intermediate_tensors(
+                num_input_tokens, intermediate_tensors, True)
 
         # Run the decoder.
         # Use persistent buffers for CUDA graphs.
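Here the inline buffer-sync loop on non-first pipeline ranks is folded into the new helper with sync_self=True: tensors received from the previous pipeline stage are copied into this rank's persistent buffers (kept at stable addresses for CUDA graph replay), using the shorter copy length when the residual is scattered. A sketch of that sync path under the same assumptions as the previous example (plain dicts, hypothetical shapes):

import torch

# Persistent buffers sized for the maximum batch; CUDA graphs require
# stable addresses, so we copy into them rather than rebinding names.
persistent = {
    "hidden_states": torch.empty(64, 8),
    "residual": torch.empty(64, 8),
}

def sync_into_persistent(recv: dict[str, torch.Tensor], num_tokens: int,
                         tp: int, is_residual_scattered: bool) -> None:
    for k, v in recv.items():
        # A scattered residual only carries this rank's num_tokens // tp
        # rows; every other tensor carries the full num_tokens rows.
        copy_len = num_tokens // tp if (
            k == "residual" and is_residual_scattered) else num_tokens
        persistent[k][:copy_len].copy_(v[:copy_len], non_blocking=True)

recv = {"hidden_states": torch.ones(16, 8), "residual": torch.ones(4, 8)}
sync_into_persistent(recv, num_tokens=16, tp=4, is_residual_scattered=True)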
@@ -1658,10 +1685,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     batch_size=self.max_num_tokens,
                     dtype=self.model_config.dtype,
                     device=self.device))
-            intermediate_tensors = IntermediateTensors({
-                k: v[:num_tokens]
-                for k, v in self.intermediate_tensors.items()
-            })
+
+            intermediate_tensors = self.sync_and_slice_intermediate_tensors(
+                num_tokens, None, False)
 
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
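In this dummy-run hunk there is no upstream stage to receive from, so the helper is called with intermediate_tensors=None and sync_self=False: the copy loop is skipped and the persistent buffers are only sliced to the shapes the captured CUDA graph will see. A self-contained sketch of that path (buffer names and shapes again hypothetical):

import torch

# Dummy-run sketch: nothing is copied because there are no received
# tensors; the persistent buffers are just sliced to the capture shape.
persistent = {
    "hidden_states": torch.empty(64, 8),
    "residual": torch.empty(64, 8),
}
tp, num_tokens, enabled_sp = 4, 16, True
is_residual_scattered = tp > 1 and enabled_sp and num_tokens % tp == 0
dummy = {
    k: v[:num_tokens // tp]
    if k == "residual" and is_residual_scattered else v[:num_tokens]
    for k, v in persistent.items()
}
assert dummy["residual"].shape[0] == 4  # 16 // 4: scattered residual shard
assert dummy["hidden_states"].shape[0] == 16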