[Perf] Optimize dcp allocate tensor (#33102)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -195,14 +195,10 @@ def _cp_lse_common(
|
|||||||
if ctx is None:
|
if ctx is None:
|
||||||
ctx = CPTritonContext()
|
ctx = CPTritonContext()
|
||||||
|
|
||||||
lses = torch.empty(
|
|
||||||
(cp_group.world_size,) + cp_attn_lse.shape,
|
|
||||||
dtype=cp_attn_lse.dtype,
|
|
||||||
device=cp_attn_lse.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
cp_attn_lse = cp_attn_lse.contiguous()
|
cp_attn_lse = cp_attn_lse.contiguous()
|
||||||
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
|
lses = cp_group.all_gather(cp_attn_lse, dim=0).reshape(
|
||||||
|
(cp_group.world_size,) + cp_attn_lse.shape
|
||||||
|
)
|
||||||
out, lse = correct_attn_out(
|
out, lse = correct_attn_out(
|
||||||
cp_attn_out,
|
cp_attn_out,
|
||||||
lses,
|
lses,
|
||||||
|
|||||||
Reference in New Issue
Block a user