[XPU] Add deepseek_scaling_rope fused kernel (#36612)
Signed-off-by: yitingw1 <yiting.wang@intel.com>
This commit is contained in:
@@ -8,6 +8,7 @@ from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -54,6 +55,37 @@ if hasattr(torch.ops._xpu_C, "int4_gemm_w4a16"):
|
||||
return torch.empty((M, N), dtype=input.dtype, device=input.device)
|
||||
|
||||
|
||||
def _xpu_ops_deepseek_scaling_rope_impl(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor | None,
    offsets: torch.Tensor | None,
    cos_sin_cache: torch.Tensor | None,
    rotary_dim: int,
    is_neox_style: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Dispatch to the fused XPU deepseek-scaling RoPE kernel.

    Thin pass-through to ``torch.ops._xpu_C.deepseek_scaling_rope``.  The
    ``| None`` annotations exist to satisfy the custom-op schema; the
    kernel itself requires a real ``key`` tensor, hence the assert.
    """
    assert key is not None
    kernel_args = (
        positions,
        query,
        key,
        offsets,
        cos_sin_cache,
        rotary_dim,
        is_neox_style,
    )
    return torch.ops._xpu_C.deepseek_scaling_rope(*kernel_args)
|
||||
|
||||
|
||||
def _xpu_ops_deepseek_scaling_rope_fake(
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor | None,
|
||||
offsets: torch.Tensor | None,
|
||||
cos_sin_cache: torch.Tensor | None,
|
||||
rotary_dim: int,
|
||||
is_neox_style: bool,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return query, key
|
||||
|
||||
|
||||
# Global flag to ensure ops are registered only once: register_ops_once()
# flips this to True after calling direct_register_custom_op, so repeated
# calls (e.g. on re-import paths) do not attempt a duplicate registration.
_OPS_REGISTERED = False
|
||||
|
||||
|
||||
class xpu_ops:
|
||||
@staticmethod
|
||||
def flash_attn_varlen_func(
|
||||
@@ -402,3 +434,21 @@ class xpu_ops:
|
||||
raw_topk_indices[: topk_indices.shape[0], : topk_indices.shape[1]] = (
|
||||
topk_indices
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def register_ops_once() -> None:
|
||||
global _OPS_REGISTERED
|
||||
if not _OPS_REGISTERED:
|
||||
# register all the custom ops here
|
||||
direct_register_custom_op(
|
||||
op_name="xpu_ops_deepseek_scaling_rope",
|
||||
op_func=_xpu_ops_deepseek_scaling_rope_impl,
|
||||
mutates_args=[],
|
||||
fake_impl=_xpu_ops_deepseek_scaling_rope_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
_OPS_REGISTERED = True
|
||||
# Eagerly register the custom ops at import time so that
# torch.ops.vllm.xpu_ops_deepseek_scaling_rope is available to callers
# (e.g. DeepseekScalingRotaryEmbedding.forward_xpu) without extra setup.
xpu_ops.register_ops_once()
|
||||
|
||||
@@ -152,6 +152,23 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
|
||||
key = key_rot
|
||||
return query, key
|
||||
|
||||
def forward_xpu(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor | None = None,
|
||||
offsets: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
return torch.ops.vllm.xpu_ops_deepseek_scaling_rope(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
offsets,
|
||||
self._match_cos_sin_cache_dtype(query),
|
||||
self.rotary_dim,
|
||||
self.is_neox_style,
|
||||
)
|
||||
|
||||
def forward_hip(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user