[BugFix] FA2 MLA Accuracy Issue (#18807)

Signed-off-by: LucasWilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
Lucas Wilkinson
2025-05-28 04:59:39 -04:00
committed by GitHub
parent aa42561e40
commit ce75efeecb
3 changed files with 16 additions and 8 deletions

View File

@@ -653,10 +653,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if isinstance(attn_out, tuple):
attn_out, lse = attn_out[0], attn_out[1]
# unpad if necessary
if self._pad_v:
attn_out = attn_out[..., :v.shape[-1]]
# Remain consistent with old `flash_attn_varlen_func` where there
# is only one output tensor if `return_softmax_lse` is False.
if return_softmax_lse:
@@ -839,6 +835,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
suffix_lse=suffix_lse,
)
# unpad if necessary
if self._pad_v:
output = output[..., :v.shape[-1]]
return output.flatten(start_dim=-2)
@abstractmethod