[Perf] Optimize dcp allocate tensor (#33102)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -195,14 +195,10 @@ def _cp_lse_common(
|
||||
if ctx is None:
|
||||
ctx = CPTritonContext()
|
||||
|
||||
lses = torch.empty(
|
||||
(cp_group.world_size,) + cp_attn_lse.shape,
|
||||
dtype=cp_attn_lse.dtype,
|
||||
device=cp_attn_lse.device,
|
||||
)
|
||||
|
||||
cp_attn_lse = cp_attn_lse.contiguous()
|
||||
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
|
||||
lses = cp_group.all_gather(cp_attn_lse, dim=0).reshape(
|
||||
(cp_group.world_size,) + cp_attn_lse.shape
|
||||
)
|
||||
out, lse = correct_attn_out(
|
||||
cp_attn_out,
|
||||
lses,
|
||||
|
||||
Reference in New Issue
Block a user