[BugFix] FA2 MLA Accuracy Issue (#18807)
Signed-off-by: LucasWilkinson <lwilkinson@neuralmagic.com>
@@ -653,10 +653,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         if isinstance(attn_out, tuple):
             attn_out, lse = attn_out[0], attn_out[1]
 
-        # unpad if necessary
-        if self._pad_v:
-            attn_out = attn_out[..., :v.shape[-1]]
-
         # Remain consistent with old `flash_attn_varlen_func` where there
         # is only one output tensor if `return_softmax_lse` is False.
         if return_softmax_lse:
@@ -839,6 +835,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             suffix_lse=suffix_lse,
         )
 
+        # unpad if necessary
+        if self._pad_v:
+            output = output[..., :v.shape[-1]]
+
         return output.flatten(start_dim=-2)
 
     @abstractmethod
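For context, `merge_attn_states` combines the prefix (cached-context) and suffix (new-token) partial attention results using their log-sum-exp (LSE) values, which is why the second hunk passes `suffix_lse`. A minimal sketch of that merge, assuming `[num_tokens, num_heads, head_dim]` outputs and `[num_tokens, num_heads]` LSEs (vLLM's real kernel is CUDA and its layout may differ):

    import torch

    def merge_attn_states_sketch(prefix_out, prefix_lse, suffix_out, suffix_lse):
        # Each partial output is already softmax-normalized over its own key
        # range; its LSE is the log of that range's softmax denominator.
        max_lse = torch.maximum(prefix_lse, suffix_lse)  # for numerical stability
        p_weight = torch.exp(prefix_lse - max_lse)
        s_weight = torch.exp(suffix_lse - max_lse)
        denom = p_weight + s_weight
        p_scale = (p_weight / denom).unsqueeze(-1)       # broadcast over head_dim
        s_scale = (s_weight / denom).unsqueeze(-1)
        return prefix_out * p_scale + suffix_out * s_scale

Because this merge is an element-wise weighted sum along the head dimension, it runs unchanged on padded tensors; the padded channels can simply be sliced off afterwards.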
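Read together, the two hunks relocate the `self._pad_v` unpad: it no longer runs inside the varlen-attention wrapper that every chunk passes through (first hunk), and instead runs exactly once after `merge_attn_states` (second hunk). A schematic of the resulting order of operations; `attn_fn` and `merge_fn` are hypothetical stand-ins for the FA2 call and the merge kernel, and the shapes are assumptions:

    import torch.nn.functional as F

    def mla_prefill_sketch(q, k, v, attn_fn, merge_fn):
        # q, k: [num_tokens, num_heads, qk_head_dim]
        # v:    [num_tokens, num_heads, v_head_dim], v_head_dim < qk_head_dim in MLA
        v_head_dim = v.shape[-1]
        pad_v = v_head_dim < q.shape[-1]
        if pad_v:
            # FA2 wants uniform head dims, so zero-pad v up to qk_head_dim
            v = F.pad(v, [0, q.shape[-1] - v_head_dim])

        # Both partial results keep the padded head dim (schematic: the real
        # code attends to the cached context and the new tokens with
        # different k/v, this sketch reuses one pair for brevity)...
        prefix_out, prefix_lse = attn_fn(q, k, v, causal=False)
        suffix_out, suffix_lse = attn_fn(q, k, v, causal=True)

        # ...so the merge also operates on padded tensors throughout
        output = merge_fn(prefix_out, prefix_lse, suffix_out, suffix_lse)

        # unpad once, only after the merge (the change this commit makes)
        if pad_v:
            output = output[..., :v_head_dim]
        return output.flatten(start_dim=-2)  # [num_tokens, num_heads * v_head_dim]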