[Misc][XPU] Upgrade to PyTorch 2.5 for xpu backend (#9823)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: yan ma <yan.ma@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Yan Ma
2024-11-07 09:29:03 +08:00
committed by GitHub
parent 4ab3256644
commit d3859f1891
4 changed files with 43 additions and 46 deletions

View File

@@ -74,20 +74,12 @@ class ipex_ops:
assert kv_cache_dtype == "auto"
num_heads = out.size(1)
num_queries_per_tokens = num_heads // num_kv_heads
head_mapping = torch.arange(
0,
num_kv_heads,
device=query.device,
dtype=torch.int32,
).view(num_kv_heads,
1).repeat_interleave(num_queries_per_tokens).flatten()
# todo: ipex will refactor namespace
torch.xpu.paged_attention_v1( # type: ignore
ipex.llm.modules.PagedAttention.single_query_kv_attention(
out,
query.contiguous(),
key_cache.view_as(value_cache),
value_cache,
head_mapping,
num_queries_per_tokens,
scale,
block_tables,
context_lens,
@@ -124,26 +116,15 @@ class ipex_ops:
assert kv_cache_dtype == "auto"
num_heads = out.size(1)
num_queries_per_tokens = num_heads // num_kv_heads
head_mapping = torch.arange(
0,
num_kv_heads,
dtype=torch.int32,
device=query.device,
).view(num_kv_heads,
1).repeat_interleave(num_queries_per_tokens).flatten()
# todo: ipex will refactor namespace
torch.xpu.paged_attention_v2( # type: ignore
ipex.llm.modules.PagedAttention.single_query_kv_attention(
out,
exp_sum,
max_logits,
tmp_out,
query.contiguous(),
key_cache.view_as(value_cache),
value_cache,
head_mapping,
num_queries_per_tokens,
scale,
block_tables,
context_lens,
scale,
block_size,
max_context_len,
alibi_slopes,
@@ -202,6 +183,7 @@ class ipex_ops:
is_causal: bool,
return_softmax: bool,
gen_: torch.Generator,
logits_soft_cap: float,
) -> None:
ipex.llm.functional.varlen_attention(query.contiguous(),
key.contiguous(),
@@ -210,7 +192,8 @@ class ipex_ops:
max_seqlen_q, max_seqlen_k,
pdropout, softmax_scale,
zero_tensors, is_causal,
return_softmax, gen_)
return_softmax, gen_,
logits_soft_cap)
@staticmethod
def reshape_and_cache(