[Misc][XPU] Upgrade to Pytorch 2.5 for xpu backend (#9823)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com> Signed-off-by: yan ma <yan.ma@intel.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
@@ -74,20 +74,12 @@ class ipex_ops:
|
||||
assert kv_cache_dtype == "auto"
|
||||
num_heads = out.size(1)
|
||||
num_queries_per_tokens = num_heads // num_kv_heads
|
||||
head_mapping = torch.arange(
|
||||
0,
|
||||
num_kv_heads,
|
||||
device=query.device,
|
||||
dtype=torch.int32,
|
||||
).view(num_kv_heads,
|
||||
1).repeat_interleave(num_queries_per_tokens).flatten()
|
||||
# todo: ipex will refactor namespace
|
||||
torch.xpu.paged_attention_v1( # type: ignore
|
||||
ipex.llm.modules.PagedAttention.single_query_kv_attention(
|
||||
out,
|
||||
query.contiguous(),
|
||||
key_cache.view_as(value_cache),
|
||||
value_cache,
|
||||
head_mapping,
|
||||
num_queries_per_tokens,
|
||||
scale,
|
||||
block_tables,
|
||||
context_lens,
|
||||
@@ -124,26 +116,15 @@ class ipex_ops:
|
||||
assert kv_cache_dtype == "auto"
|
||||
num_heads = out.size(1)
|
||||
num_queries_per_tokens = num_heads // num_kv_heads
|
||||
head_mapping = torch.arange(
|
||||
0,
|
||||
num_kv_heads,
|
||||
dtype=torch.int32,
|
||||
device=query.device,
|
||||
).view(num_kv_heads,
|
||||
1).repeat_interleave(num_queries_per_tokens).flatten()
|
||||
# todo: ipex will refactor namespace
|
||||
torch.xpu.paged_attention_v2( # type: ignore
|
||||
ipex.llm.modules.PagedAttention.single_query_kv_attention(
|
||||
out,
|
||||
exp_sum,
|
||||
max_logits,
|
||||
tmp_out,
|
||||
query.contiguous(),
|
||||
key_cache.view_as(value_cache),
|
||||
value_cache,
|
||||
head_mapping,
|
||||
num_queries_per_tokens,
|
||||
scale,
|
||||
block_tables,
|
||||
context_lens,
|
||||
scale,
|
||||
block_size,
|
||||
max_context_len,
|
||||
alibi_slopes,
|
||||
@@ -202,6 +183,7 @@ class ipex_ops:
|
||||
is_causal: bool,
|
||||
return_softmax: bool,
|
||||
gen_: torch.Generator,
|
||||
logits_soft_cap: float,
|
||||
) -> None:
|
||||
ipex.llm.functional.varlen_attention(query.contiguous(),
|
||||
key.contiguous(),
|
||||
@@ -210,7 +192,8 @@ class ipex_ops:
|
||||
max_seqlen_q, max_seqlen_k,
|
||||
pdropout, softmax_scale,
|
||||
zero_tensors, is_causal,
|
||||
return_softmax, gen_)
|
||||
return_softmax, gen_,
|
||||
logits_soft_cap)
|
||||
|
||||
@staticmethod
|
||||
def reshape_and_cache(
|
||||
|
||||
Reference in New Issue
Block a user