[XPU] Add deepseek_scaling_rope fused kernel (#36612)

Signed-off-by: yitingw1 <yiting.wang@intel.com>
This commit is contained in:
Wang, Yiting
2026-03-16 12:35:08 +08:00
committed by GitHub
parent 0024f39a32
commit 68e1b711f1
2 changed files with 67 additions and 0 deletions

View File

@@ -8,6 +8,7 @@ from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__)
@@ -54,6 +55,37 @@ if hasattr(torch.ops._xpu_C, "int4_gemm_w4a16"):
return torch.empty((M, N), dtype=input.dtype, device=input.device)
def _xpu_ops_deepseek_scaling_rope_impl(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor | None,
    offsets: torch.Tensor | None,
    cos_sin_cache: torch.Tensor | None,
    rotary_dim: int,
    is_neox_style: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Dispatch to the fused XPU DeepSeek scaling RoPE kernel.

    Thin wrapper around ``torch.ops._xpu_C.deepseek_scaling_rope`` so the
    op can be registered as a vLLM custom op.
    """
    # The fused kernel rotates both tensors, so a key must be supplied even
    # though the registered signature allows None.
    assert key is not None
    rotated = torch.ops._xpu_C.deepseek_scaling_rope(
        positions,
        query,
        key,
        offsets,
        cos_sin_cache,
        rotary_dim,
        is_neox_style,
    )
    return rotated
def _xpu_ops_deepseek_scaling_rope_fake(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None,
offsets: torch.Tensor | None,
cos_sin_cache: torch.Tensor | None,
rotary_dim: int,
is_neox_style: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
return query, key
# Module-level guard so register_ops_once() performs the custom-op
# registration at most once per process.
_OPS_REGISTERED = False
class xpu_ops:
@staticmethod
def flash_attn_varlen_func(
@@ -402,3 +434,21 @@ class xpu_ops:
raw_topk_indices[: topk_indices.shape[0], : topk_indices.shape[1]] = (
topk_indices
)
@staticmethod
def register_ops_once() -> None:
    """Register the XPU custom ops, idempotently.

    Safe to call repeatedly; the module-level ``_OPS_REGISTERED`` flag
    ensures registration happens only on the first call.
    """
    global _OPS_REGISTERED
    if _OPS_REGISTERED:
        return
    # Register all the custom ops here.
    direct_register_custom_op(
        op_name="xpu_ops_deepseek_scaling_rope",
        op_func=_xpu_ops_deepseek_scaling_rope_impl,
        mutates_args=[],
        fake_impl=_xpu_ops_deepseek_scaling_rope_fake,
        dispatch_key=current_platform.dispatch_key,
    )
    _OPS_REGISTERED = True
# Eagerly register the custom ops when this module is imported so they are
# available before any caller dispatches through them.
xpu_ops.register_ops_once()

View File

@@ -152,6 +152,23 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
key = key_rot
return query, key
def forward_xpu(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor | None = None,
    offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Apply DeepSeek scaling RoPE via the fused XPU custom op."""
    # Presumably returns the cos/sin cache with its dtype matched to the
    # query's dtype before handing it to the kernel.
    cache = self._match_cos_sin_cache_dtype(query)
    return torch.ops.vllm.xpu_ops_deepseek_scaling_rope(
        positions,
        query,
        key,
        offsets,
        cache,
        self.rotary_dim,
        self.is_neox_style,
    )
def forward_hip(
self,
positions: torch.Tensor,