diff --git a/vllm/v1/attention/ops/common.py b/vllm/v1/attention/ops/common.py index bd6bc864d..46c689ce0 100644 --- a/vllm/v1/attention/ops/common.py +++ b/vllm/v1/attention/ops/common.py @@ -195,14 +195,10 @@ def _cp_lse_common( if ctx is None: ctx = CPTritonContext() - lses = torch.empty( - (cp_group.world_size,) + cp_attn_lse.shape, - dtype=cp_attn_lse.dtype, - device=cp_attn_lse.device, - ) - cp_attn_lse = cp_attn_lse.contiguous() - lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses) + lses = cp_group.all_gather(cp_attn_lse, dim=0).reshape( + (cp_group.world_size,) + cp_attn_lse.shape + ) out, lse = correct_attn_out( cp_attn_out, lses,