[XPU] Update latest IPEX 2.8 release (#27735)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Kunshang Ji
2025-10-30 11:17:13 +08:00
committed by GitHub
parent d7fb10c574
commit b5bae42f91
4 changed files with 14 additions and 20 deletions

View File

@@ -151,7 +151,9 @@ class ipex_ops:
 def rms_norm(
     input: torch.Tensor, weight: torch.Tensor, epsilon: float
 ) -> torch.Tensor:
-    return ipex.llm.functional.rms_norm(input, weight, epsilon)
+    out = torch.empty_like(input)
+    torch.ops.torch_ipex.rms_norm_vllm(out, input.contiguous(), weight, epsilon)
+    return out
@staticmethod
def fused_add_rms_norm(
@@ -160,10 +162,7 @@ class ipex_ops:
weight: torch.Tensor,
epsilon: float,
) -> None:
-    tmp = ipex.llm.functional.add_rms_norm(
-        residual, input, weight, None, epsilon, True
-    )
-    input.copy_(tmp)
+    torch.ops.torch_ipex.fused_add_rms_norm_vllm(input, residual, weight, epsilon)
@staticmethod
def varlen_attention(
@@ -296,16 +295,6 @@ class ipex_ops:
num_splits=0,
s_aux: torch.Tensor | None = None,
):
-    if cu_seqlens_k is None:
-        # cu_seqlens_k is not used in ipex kernel.
-        cu_seqlens_k = torch.cumsum(seqused_k, dim=0)
-        cu_seqlens_k = torch.cat(
-            [
-                torch.tensor([0], device=seqused_k.device, dtype=torch.int32),
-                cu_seqlens_k,
-            ]
-        ).to(torch.int32)
real_window_size: tuple[int, int]
if window_size is None:
real_window_size = (-1, -1)
@@ -318,7 +307,7 @@ class ipex_ops:
k,
v,
cu_seqlens_q,
-    cu_seqlens_k,
+    seqused_k,
max_seqlen_q,
max_seqlen_k,
softmax_scale,