[Perf] Optimize dcp allocate tensor (#33102)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-01-27 17:24:41 -05:00
committed by GitHub
parent f5d7049cc1
commit 3a6d5cbefd

View File

@@ -195,14 +195,10 @@ def _cp_lse_common(
if ctx is None: if ctx is None:
ctx = CPTritonContext() ctx = CPTritonContext()
lses = torch.empty(
(cp_group.world_size,) + cp_attn_lse.shape,
dtype=cp_attn_lse.dtype,
device=cp_attn_lse.device,
)
cp_attn_lse = cp_attn_lse.contiguous() cp_attn_lse = cp_attn_lse.contiguous()
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses) lses = cp_group.all_gather(cp_attn_lse, dim=0).reshape(
(cp_group.world_size,) + cp_attn_lse.shape
)
out, lse = correct_attn_out( out, lse = correct_attn_out(
cp_attn_out, cp_attn_out,
lses, lses,