[Attention] MLA decode optimizations (#12528)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: simon-mo <simon.mo@hey.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
This commit is contained in:
Lucas Wilkinson
2025-01-31 02:49:37 -05:00
committed by GitHub
parent a1fc18c030
commit cabaf4eff3
31 changed files with 2266 additions and 32 deletions

View File

@@ -27,7 +27,8 @@ class XPUPlatform(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool) -> str:
block_size: int, use_v1: bool,
use_mla: bool) -> str:
if selected_backend != _Backend.IPEX:
logger.info("Cannot use %s backend on XPU.", selected_backend)
logger.info("Using IPEX attention backend.")