[ROCm][AITER] fix wrong argument passed to AITER flash_attn_varlen_func (#31880)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
@@ -54,7 +54,7 @@ class AiterTritonMLAImpl(AiterMLAImpl):
|
||||
k,
|
||||
v,
|
||||
softmax_scale=softmax_scale,
|
||||
return_softmax_lse=return_softmax_lse,
|
||||
return_lse=return_softmax_lse,
|
||||
**kwargs,
|
||||
)
|
||||
# Transpose the LSE if Triton MHA is used:
|
||||
|
||||
@@ -236,7 +236,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
k=k,
|
||||
v=v,
|
||||
softmax_scale=softmax_scale,
|
||||
return_softmax_lse=return_softmax_lse,
|
||||
return_lse=return_softmax_lse,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user