[attention][DCP] use AttentionImpl.need_to_return_lse_for_decode (#24372)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -1592,10 +1592,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
|
||||
# recorect dcp attn_out with lse.
|
||||
if self.dcp_world_size > 1:
|
||||
assert lse is not None, (
|
||||
"For a mla backend want to enable"
|
||||
"DCP, it is mandatory that the corresponding decode attn"
|
||||
"kernel return the softmax lse.")
|
||||
attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
|
||||
|
||||
# v_up projection
|
||||
|
||||
@@ -133,6 +133,8 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
|
||||
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
|
||||
can_return_lse_for_decode: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
|
||||
Reference in New Issue
Block a user