[Feature] support sequence parallelism using compilation pass (#16155)
Signed-off-by: cascade812 <cascade812@outlook.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
@@ -19,6 +19,12 @@ def tensor_model_parallel_all_gather(input_: torch.Tensor,
|
||||
return get_tp_group().all_gather(input_, dim)
|
||||
|
||||
|
||||
def tensor_model_parallel_reduce_scatter(input_: torch.Tensor,
|
||||
dim: int = -1) -> torch.Tensor:
|
||||
"""Reduce-Scatter the input tensor across model parallel group."""
|
||||
return get_tp_group().reduce_scatter(input_, dim)
|
||||
|
||||
|
||||
def tensor_model_parallel_gather(input_: torch.Tensor,
|
||||
dst: int = 0,
|
||||
dim: int = -1) -> Optional[torch.Tensor]:
|
||||
|
||||
Reference in New Issue
Block a user