[Kernel][CPU] CPU MLA (#14744)

Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
This commit is contained in:
Thien Tran
2025-03-25 17:34:59 +08:00
committed by GitHub
parent 4157f563b4
commit 4f044b1d67
15 changed files with 1010 additions and 17 deletions

View File

@@ -187,15 +187,28 @@ class ipex_ops:
gen_: torch.Generator,
logits_soft_cap: float,
) -> None:
ipex.llm.functional.varlen_attention(query.contiguous(),
key.contiguous(),
value.contiguous(), out,
seqlen_q.int(), seqlen_k.int(),
max_seqlen_q, max_seqlen_k,
pdropout, softmax_scale,
zero_tensors, is_causal,
return_softmax, gen_,
logits_soft_cap)
if ipex.__version__.endswith("cpu"):
if logits_soft_cap != 0.0:
raise ValueError("IPEX CPU does not support logits_soft_cap")
ipex.llm.functional.varlen_attention(query.contiguous(),
key.contiguous(),
value.contiguous(), out,
seqlen_q.int(),
seqlen_k.int(), max_seqlen_q,
max_seqlen_k, pdropout,
softmax_scale, zero_tensors,
is_causal, return_softmax,
gen_)
else: # XPU build
ipex.llm.functional.varlen_attention(query.contiguous(),
key.contiguous(),
value.contiguous(), out,
seqlen_q.int(),
seqlen_k.int(), max_seqlen_q,
max_seqlen_k, pdropout,
softmax_scale, zero_tensors,
is_causal, return_softmax,
gen_, logits_soft_cap)
@staticmethod
def reshape_and_cache(