[Attention] FA4 integration (#32974)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
@@ -2014,7 +2014,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         # RoCM and the latter has an additional parameter to control
         # FA2 vs FA3
         self.flash_attn_varlen_func = flash_attn_varlen_func
-        self.vllm_flash_attn_version = get_flash_attn_version()
+        self.vllm_flash_attn_version = get_flash_attn_version(
+            head_size=self.qk_head_dim
+        )
         if self.vllm_flash_attn_version is not None:
             self.flash_attn_varlen_func = functools.partial(
                 flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version
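Note on the change: the hunk makes FlashAttention version selection head-size aware, so get_flash_attn_version can pick a kernel generation (e.g. FA4) based on qk_head_dim rather than a global default. Below is a minimal, self-contained sketch of that pattern. The bodies of get_flash_attn_version and flash_attn_varlen_func here are invented stand-ins, not vLLM's real helpers, and the version thresholds are assumptions for illustration only.

    import functools

    # Assumed stand-in: pick a FlashAttention major version from the head
    # dim, or None when no pinned version applies. The real selection logic
    # in vLLM is more involved (platform, dtype, build flags, etc.).
    def get_flash_attn_version(head_size: int) -> int | None:
        if head_size > 256:
            return None  # unsupported here: leave the function unpinned
        return 4 if head_size % 64 == 0 else 3

    # Assumed stand-in for the varlen attention entry point; the real
    # kernel takes many more arguments (cu_seqlens, max_seqlen, ...).
    def flash_attn_varlen_func(q, k, v, *, fa_version: int | None = None):
        print(f"running varlen attention with fa_version={fa_version}")

    class AttnImpl:
        def __init__(self, qk_head_dim: int):
            self.flash_attn_varlen_func = flash_attn_varlen_func
            # Head-size-aware version selection, as in the diff above.
            self.vllm_flash_attn_version = get_flash_attn_version(
                head_size=qk_head_dim
            )
            if self.vllm_flash_attn_version is not None:
                # Pin the chosen version once via functools.partial so
                # call sites never have to pass fa_version themselves.
                self.flash_attn_varlen_func = functools.partial(
                    flash_attn_varlen_func,
                    fa_version=self.vllm_flash_attn_version,
                )

    impl = AttnImpl(qk_head_dim=192)
    impl.flash_attn_varlen_func(None, None, None)  # fa_version=4 is pinned

The functools.partial pattern keeps the version decision in __init__, so the hot path calls self.flash_attn_varlen_func without re-deciding per call.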