[XPU] Update latest IPEX 2.8 release (#27735)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
@@ -151,7 +151,9 @@ class ipex_ops:
|
||||
def rms_norm(
|
||||
input: torch.Tensor, weight: torch.Tensor, epsilon: float
|
||||
) -> torch.Tensor:
|
||||
return ipex.llm.functional.rms_norm(input, weight, epsilon)
|
||||
out = torch.empty_like(input)
|
||||
torch.ops.torch_ipex.rms_norm_vllm(out, input.contiguous(), weight, epsilon)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def fused_add_rms_norm(
|
||||
@@ -160,10 +162,7 @@ class ipex_ops:
|
||||
weight: torch.Tensor,
|
||||
epsilon: float,
|
||||
) -> None:
|
||||
tmp = ipex.llm.functional.add_rms_norm(
|
||||
residual, input, weight, None, epsilon, True
|
||||
)
|
||||
input.copy_(tmp)
|
||||
torch.ops.torch_ipex.fused_add_rms_norm_vllm(input, residual, weight, epsilon)
|
||||
|
||||
@staticmethod
|
||||
def varlen_attention(
|
||||
@@ -296,16 +295,6 @@ class ipex_ops:
|
||||
num_splits=0,
|
||||
s_aux: torch.Tensor | None = None,
|
||||
):
|
||||
if cu_seqlens_k is None:
|
||||
# cu_seqlens_k is not used in ipex kernel.
|
||||
cu_seqlens_k = torch.cumsum(seqused_k, dim=0)
|
||||
cu_seqlens_k = torch.cat(
|
||||
[
|
||||
torch.tensor([0], device=seqused_k.device, dtype=torch.int32),
|
||||
cu_seqlens_k,
|
||||
]
|
||||
).to(torch.int32)
|
||||
|
||||
real_window_size: tuple[int, int]
|
||||
if window_size is None:
|
||||
real_window_size = (-1, -1)
|
||||
@@ -318,7 +307,7 @@ class ipex_ops:
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
seqused_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
softmax_scale,
|
||||
|
||||
Reference in New Issue
Block a user