[ROCm][AITER] Fix wrong keyword argument passed to AITER flash_attn_varlen_func: use return_lse instead of return_softmax_lse (#31880)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
@@ -54,7 +54,7 @@ class AiterTritonMLAImpl(AiterMLAImpl):
|
|||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
softmax_scale=softmax_scale,
|
softmax_scale=softmax_scale,
|
||||||
return_softmax_lse=return_softmax_lse,
|
return_lse=return_softmax_lse,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
# Transpose the LSE if Triton MHA is used:
|
# Transpose the LSE if Triton MHA is used:
|
||||||
|
|||||||
@@ -236,7 +236,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
|||||||
k=k,
|
k=k,
|
||||||
v=v,
|
v=v,
|
||||||
softmax_scale=softmax_scale,
|
softmax_scale=softmax_scale,
|
||||||
return_softmax_lse=return_softmax_lse,
|
return_lse=return_softmax_lse,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user