[Attention] MLA decode optimizations (#12528)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: simon-mo <simon.mo@hey.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
This commit is contained in:
Lucas Wilkinson
2025-01-31 02:49:37 -05:00
committed by GitHub
parent a1fc18c030
commit cabaf4eff3
31 changed files with 2266 additions and 32 deletions

View File

@@ -75,7 +75,8 @@ class RocmPlatform(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1) -> str:
kv_cache_dtype, block_size, use_v1,
use_mla) -> str:
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)
if selected_backend == _Backend.ROCM_FLASH: