[RFC][ROCm][AITER] Keep all AITER kernels in _aiter_ops class like _custom_ops and _ipex_ops (#24490)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
vllmellm
2025-11-10 17:20:53 +01:00
committed by GitHub
parent 40e2eeeb92
commit f080a83511
25 changed files with 1193 additions and 924 deletions

View File

@@ -14,6 +14,7 @@ import torch.nn.functional as F
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
@@ -55,8 +56,6 @@ from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
logger = init_logger(__name__)
@@ -1089,11 +1088,11 @@ def vllm_topk_softmax(
return topk_weights, topk_indices
def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
if is_rocm_aiter_moe_enabled():
from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax
return rocm_aiter_topk_softmax
def dispatch_topk_func(
use_rocm_aiter: bool = False,
) -> Callable[..., tuple[torch.Tensor, ...]]:
if use_rocm_aiter:
return rocm_aiter_ops.topk_softmax
return vllm_topk_softmax
@@ -1121,7 +1120,7 @@ def fused_topk(
M, topk, dtype=torch.int32, device=hidden_states.device
)
topk_func = dispatch_topk_func()
topk_func = dispatch_topk_func(use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled())
topk_weights, topk_ids = topk_func(
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
)

View File

@@ -13,6 +13,7 @@ import torch.nn.functional as F
from torch.nn.parameter import UninitializedParameter
import vllm.envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.parallel import ExpertPlacementStrategy
from vllm.distributed import (
@@ -41,8 +42,6 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
init_aiter_topK_meta_data,
is_rocm_aiter_fusion_shared_expert_enabled,
is_rocm_aiter_moe_enabled,
)
from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator
from vllm.model_executor.layers.quantization.base_config import (
@@ -92,13 +91,11 @@ else:
return topk_ids
eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_grouped_topk,
)
if is_rocm_aiter_moe_enabled():
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_grouped_topk as grouped_topk_aiter,
)
else:
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
if current_platform.is_tpu():
from .moe_pallas import fused_moe as fused_moe_pallas
else:
@@ -463,7 +460,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
if self.rocm_aiter_moe_enabled:
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
@@ -620,13 +618,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
# Padding the weight for better performance on ROCm
layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
# Lazy import to avoid importing triton.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
shuffle_weights,
)
if self.rocm_aiter_moe_enabled:
shuffled_w13, shuffled_w2 = shuffle_weights(
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
)
@@ -1002,6 +996,7 @@ def determine_expert_map(
global_num_experts: int,
expert_placement_strategy: ExpertPlacementStrategy = "linear",
num_fused_shared_experts: int = 0,
return_expert_mask: bool = False,
) -> tuple[int, torch.Tensor | None, torch.Tensor | None]:
"""
Calculates how many experts should be assigned to each rank for EP and
@@ -1064,7 +1059,7 @@ def determine_expert_map(
)
expert_mask = None
if is_rocm_aiter_moe_enabled():
if return_expert_mask:
expert_mask = torch.ones(
(global_num_experts + num_fused_shared_experts + 1,), dtype=torch.int32
)
@@ -1292,14 +1287,18 @@ class FusedMoE(CustomOp):
self.logical_replica_count: torch.Tensor | None = None
# ROCm aiter shared experts fusion
self.rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
self.aiter_fmoe_shared_expert_enabled = (
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
)
self.num_fused_shared_experts = (
n_shared_experts
if n_shared_experts is not None
and is_rocm_aiter_fusion_shared_expert_enabled()
if n_shared_experts is not None and self.aiter_fmoe_shared_expert_enabled
else 0
)
if (
not is_rocm_aiter_fusion_shared_expert_enabled()
not self.aiter_fmoe_shared_expert_enabled
and self.num_fused_shared_experts != 0
):
raise ValueError(
@@ -1346,6 +1345,7 @@ class FusedMoE(CustomOp):
global_num_experts=self.global_num_experts,
expert_placement_strategy=expert_placement_strategy,
num_fused_shared_experts=self.num_fused_shared_experts,
return_expert_mask=self.rocm_aiter_fmoe_enabled,
)
self.local_num_experts = local_num_experts
self.register_buffer("expert_map", expert_map)
@@ -1570,13 +1570,16 @@ class FusedMoE(CustomOp):
ep_rank=self.ep_rank,
global_num_experts=self.global_num_experts,
num_fused_shared_experts=self.num_fused_shared_experts,
return_expert_mask=self.rocm_aiter_fmoe_enabled,
)
self.local_num_experts = local_num_experts
self.register_buffer("expert_map", expert_map)
self.register_buffer("expert_mask", expert_mask)
self._init_aiter_shared_experts_topK_buffer(
vllm_config=get_current_vllm_config(), dp_size=get_dp_group().world_size
)
if self.aiter_fmoe_shared_expert_enabled:
self._init_aiter_shared_experts_topK_buffer(
vllm_config=get_current_vllm_config(),
dp_size=get_dp_group().world_size,
)
def _load_per_tensor_weight_scale(
self,
@@ -1753,20 +1756,19 @@ class FusedMoE(CustomOp):
def _init_aiter_shared_experts_topK_buffer(
self, vllm_config: VllmConfig, dp_size: int
):
if is_rocm_aiter_fusion_shared_expert_enabled():
if self.num_fused_shared_experts > 0:
init_aiter_topK_meta_data(
n_routed_experts=self.global_num_experts,
n_shared_experts=self.num_fused_shared_experts,
top_k=self.top_k,
tp_rank=self.ep_rank if self.use_ep else self.tp_rank,
tp_size=self.ep_size if self.use_ep else self.tp_size,
shared_experts_score=1.0,
max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens
* dp_size,
is_EP=self.use_ep,
)
self.local_num_experts += self.num_fused_shared_experts
if self.num_fused_shared_experts > 0:
init_aiter_topK_meta_data(
n_routed_experts=self.global_num_experts,
n_shared_experts=self.num_fused_shared_experts,
top_k=self.top_k,
tp_rank=self.ep_rank if self.use_ep else self.tp_rank,
tp_size=self.ep_size if self.use_ep else self.tp_size,
shared_experts_score=1.0,
max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens
* dp_size,
is_EP=self.use_ep,
)
self.local_num_experts += self.num_fused_shared_experts
@overload
def weight_loader(
@@ -2208,15 +2210,16 @@ class FusedMoE(CustomOp):
elif use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
if is_rocm_aiter_moe_enabled():
if not is_rocm_aiter_fusion_shared_expert_enabled():
if rocm_aiter_ops.is_fused_moe_enabled():
if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
assert num_fused_shared_experts == 0
grouped_topk_impl = partial(
grouped_topk_aiter,
rocm_aiter_grouped_topk,
num_fused_shared_experts=num_fused_shared_experts,
)
else:
grouped_topk_impl = grouped_topk
topk_weights, topk_ids = grouped_topk_impl(
hidden_states=hidden_states,
gating_output=router_logits,
@@ -2448,7 +2451,7 @@ class FusedMoE(CustomOp):
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map
if not is_rocm_aiter_moe_enabled()
if not self.rocm_aiter_fmoe_enabled
else self.expert_mask,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
@@ -2612,7 +2615,7 @@ class FusedMoE(CustomOp):
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map
if not is_rocm_aiter_moe_enabled()
if not self.rocm_aiter_fmoe_enabled
else self.expert_mask,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,

View File

@@ -1,17 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import IntEnum
from functools import cache, lru_cache
from functools import lru_cache
import torch
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEQuantConfig,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
class QuantMethod(IntEnum):
@@ -37,27 +35,6 @@ class ActivationMethod(IntEnum):
GELU = 1
@cache
def is_rocm_aiter_moe_enabled() -> bool:
return (
current_platform.is_rocm()
and envs.VLLM_ROCM_USE_AITER_MOE
and envs.VLLM_ROCM_USE_AITER
)
@cache
def use_mxfp4_aiter_moe() -> bool:
return current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
@cache
def is_rocm_aiter_fusion_shared_expert_enabled() -> bool:
return (
envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS and is_rocm_aiter_moe_enabled()
)
aiter_topK_meta_data = None
@@ -114,250 +91,6 @@ def init_aiter_topK_meta_data(
aiter_topK_meta_data = (total_topk_weights, total_topk_ids)
def rocm_aiter_asm_moe_tkw1_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: torch.Tensor | None = None,
fc2_scale: torch.Tensor | None = None,
fc1_smooth_scale: torch.Tensor | None = None,
fc2_smooth_scale: torch.Tensor | None = None,
a16: bool = False,
per_tensor_quant_scale: torch.Tensor | None = None,
expert_mask: torch.Tensor | None = None,
activation_method: int = ActivationMethod.SILU.value,
) -> torch.Tensor:
from aiter import ActivationType
from aiter.fused_moe_bf16_asm import asm_moe_tkw1
activation = ActivationType(activation_method)
return asm_moe_tkw1(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
fc1_scale=fc1_scale,
fc2_scale=fc2_scale,
fc1_smooth_scale=fc1_smooth_scale,
fc2_smooth_scale=fc2_smooth_scale,
a16=a16,
per_tensor_quant_scale=per_tensor_quant_scale,
expert_mask=expert_mask,
activation=activation,
)
def rocm_aiter_asm_moe_tkw1_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: torch.Tensor | None = None,
fc2_scale: torch.Tensor | None = None,
fc1_smooth_scale: torch.Tensor | None = None,
fc2_smooth_scale: torch.Tensor | None = None,
a16: bool = False,
per_tensor_quant_scale: torch.Tensor | None = None,
expert_mask: torch.Tensor | None = None,
activation_method: int = ActivationMethod.SILU.value,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
def rocm_aiter_topk_softmax_impl(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
) -> None:
from aiter import topk_softmax
topk_softmax(
topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
)
def rocm_aiter_topk_softmax_fake(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
) -> None:
pass
def rocm_aiter_biased_grouped_topk_impl(
gating_output: torch.Tensor,
correction_bias: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
routed_scaling_factor: float = 1.0, # mul to topk_weights
) -> None:
from aiter import biased_grouped_topk
biased_grouped_topk(
gating_output,
correction_bias,
topk_weights,
topk_ids,
num_expert_group,
topk_group,
need_renorm,
routed_scaling_factor,
)
def rocm_aiter_biased_grouped_topk_fake(
gating_output: torch.Tensor,
correction_bias: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
routed_scaling_factor: float = 1.0, # mul to topk_weights
) -> None:
pass
def rocm_aiter_grouped_topk_impl(
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0, # mul to topk_weights
) -> None:
from aiter import grouped_topk
grouped_topk(
gating_output,
topk_weights,
topk_ids,
num_expert_group,
topk_group,
need_renorm,
scoring_func,
routed_scaling_factor,
)
def rocm_aiter_grouped_topk_fake(
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0, # mul to topk_weights
) -> None:
pass
def rocm_aiter_fused_moe_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
expert_mask: torch.Tensor | None = None,
activation_method: int = ActivationMethod.SILU.value,
quant_method: int = QuantMethod.NO.value,
doweight_stage1: bool = False,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
) -> torch.Tensor:
from aiter import ActivationType, QuantType
from aiter.fused_moe import fused_moe
activation = ActivationType(activation_method)
quant_type = QuantType(quant_method)
return fused_moe(
hidden_states,
w1,
w2,
topk_weight,
topk_ids,
expert_mask,
activation,
quant_type,
doweight_stage1,
w1_scale,
w2_scale,
a1_scale,
a2_scale,
)
def rocm_aiter_fused_moe_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
expert_mask: torch.Tensor | None = None,
activation_method: int = ActivationMethod.SILU.value,
quant_method: int = QuantMethod.NO.value,
doweight_stage1: bool = False,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_asm_moe_tkw1",
op_func=rocm_aiter_asm_moe_tkw1_impl,
fake_impl=rocm_aiter_asm_moe_tkw1_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_fused_moe",
op_func=rocm_aiter_fused_moe_impl,
fake_impl=rocm_aiter_fused_moe_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_topk_softmax",
op_func=rocm_aiter_topk_softmax_impl,
mutates_args=["topk_weights", "topk_indices", "token_expert_indices"],
fake_impl=rocm_aiter_topk_softmax_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_biased_grouped_topk",
op_func=rocm_aiter_biased_grouped_topk_impl,
mutates_args=["topk_weights", "topk_ids"],
fake_impl=rocm_aiter_biased_grouped_topk_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_grouped_topk",
op_func=rocm_aiter_grouped_topk_impl,
mutates_args=["topk_weights", "topk_ids"],
fake_impl=rocm_aiter_grouped_topk_fake,
)
def rocm_aiter_grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
@@ -372,7 +105,10 @@ def rocm_aiter_grouped_topk(
) -> tuple[torch.Tensor, torch.Tensor]:
token = hidden_states.shape[0]
device = hidden_states.device
if is_rocm_aiter_fusion_shared_expert_enabled() and num_fused_shared_experts > 0:
if (
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
and num_fused_shared_experts > 0
):
assert aiter_topK_meta_data is not None, (
"AITER topK meta data is not initialized. "
"Please ensure that init_aiter_topK_meta_data "
@@ -397,7 +133,7 @@ def rocm_aiter_grouped_topk(
topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
if e_score_correction_bias is not None:
torch.ops.vllm.rocm_aiter_biased_grouped_topk(
rocm_aiter_ops.biased_grouped_topk(
gating_output,
e_score_correction_bias.to(gating_output.dtype),
topk_weights,
@@ -409,7 +145,7 @@ def rocm_aiter_grouped_topk(
)
else:
assert scoring_func == "softmax" or scoring_func == "sigmoid"
torch.ops.vllm.rocm_aiter_grouped_topk(
rocm_aiter_ops.grouped_topk(
gating_output,
topk_weights,
topk_ids,
@@ -420,7 +156,10 @@ def rocm_aiter_grouped_topk(
routed_scaling_factor=routed_scaling_factor,
)
if is_rocm_aiter_fusion_shared_expert_enabled() and num_fused_shared_experts > 0:
if (
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
and num_fused_shared_experts > 0
):
return total_topk_weights, total_topk_ids
return topk_weights, topk_ids
@@ -464,7 +203,7 @@ def rocm_aiter_fused_experts(
"Only support topk=1 when `apply_router_weight_on_input` is True"
)
return torch.ops.vllm.rocm_aiter_asm_moe_tkw1(
return rocm_aiter_ops.asm_moe_tkw1(
hidden_states,
w1,
w2,
@@ -482,7 +221,9 @@ def rocm_aiter_fused_experts(
else:
quant_method = QuantMethod.NO.value
# quark moe for mxfp4 w_dtype
if quant_config.use_mxfp4_w4a16:
quant_method = QuantMethod.BLOCK_1X32.value
# w8a8 block-scaled
if quant_config.block_shape is not None and quant_config.use_fp8_w8a8:
assert not apply_router_weight_on_input, (
@@ -507,7 +248,7 @@ def rocm_aiter_fused_experts(
"Only support topk=1 when `apply_router_weight_on_input` is True"
)
return torch.ops.vllm.rocm_aiter_fused_moe(
return rocm_aiter_ops.fused_moe(
hidden_states,
w1,
w2,
@@ -522,39 +263,3 @@ def rocm_aiter_fused_experts(
a2_scale=quant_config.a2_scale,
doweight_stage1=apply_router_weight_on_input,
)
def rocm_aiter_topk_softmax(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
) -> tuple[torch.Tensor, ...]:
torch.ops.vllm.rocm_aiter_topk_softmax(
topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
)
return topk_weights, topk_indices
def shuffle_weights(
*tensors: torch.Tensor, layout: tuple[int, int] = (16, 16)
) -> tuple[torch.Tensor, ...]:
"""
Applies shuffle_weight function from AITER to each
input tensor and returns them.
Rearranges (shuffles) the input tensor/s
into a specified block layout for optimized computation.
Args:
*tensors: Variable number of torch.Tensor objects.
layout: A pair of integers specifying the block sizes used to divide
the tensors during shuffling. Default is (16, 16).
Returns:
A Tuple of shuffled tensors.
"""
from aiter.ops.shuffle import shuffle_weight
return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)

View File

@@ -6,18 +6,13 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import vllm.envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import (
rms_norm_batch_invariant,
vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
def is_rocm_aiter_rmsnorm_enabled() -> bool:
return envs.VLLM_ROCM_USE_AITER_RMSNORM and envs.VLLM_ROCM_USE_AITER
def rms_norm(
@@ -58,80 +53,34 @@ def fused_add_rms_norm(
return x, residual
def rocm_aiter_rms_norm_impl(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
def poly_norm(
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
import aiter as rocm_aiter
from vllm import _custom_ops as ops
if x.dim() > 2:
x_original_shape = x.shape
x = x.reshape(-1, x_original_shape[-1])
x = rocm_aiter.rms_norm(x, weight, variance_epsilon)
return x.reshape(x_original_shape)
return rocm_aiter.rms_norm(x, weight, variance_epsilon)
def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
import aiter as rocm_aiter
residual_out = torch.empty_like(residual)
output = torch.empty_like(x)
rocm_aiter.rmsnorm2d_fwd_with_add(
output, # output
x, # input
residual, # residual input
residual_out, # residual output
out = torch.empty_like(x)
ops.poly_norm(
out,
x,
weight,
bias,
variance_epsilon,
)
return output, residual_out
return out
def rocm_aiter_rms_norm_fake(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
return torch.empty_like(x)
def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.empty_like(x), torch.empty_like(residual)
if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_rms_norm",
op_func=rocm_aiter_rms_norm_impl,
fake_impl=rocm_aiter_rms_norm_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
)
def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [
def dispatch_rocm_rmsnorm_func(
with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False
):
use_aiter = use_aiter and dtype in [
torch.float16,
torch.bfloat16,
]
if use_aiter and with_fused_add:
return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
return rocm_aiter_ops.rms_norm2d_with_add
if use_aiter:
return torch.ops.vllm.rocm_aiter_rms_norm
return rocm_aiter_ops.rms_norm
# fall back to CUDA implementation
if with_fused_add:
@@ -169,11 +118,14 @@ class RMSNorm(CustomOp):
self.weight = nn.Parameter(self.weight)
if current_platform.is_rocm():
aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
with_fused_add=False, dtype=weight_dtype
with_fused_add=False,
dtype=weight_dtype,
use_aiter=aiter_rmsnorm_enabled,
)
self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
with_fused_add=True, dtype=weight_dtype
with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
)
@staticmethod

View File

@@ -12,6 +12,7 @@ from compressed_tensors.quantization import ActivationOrdering, QuantizationStra
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
@@ -582,11 +583,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# Disable marlin for rocm
if current_platform.is_rocm():
self.use_marlin = False
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled,
)
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
# cutlass path
self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
@@ -829,12 +827,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# Property to determine if AITER is used
if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
shuffle_weights,
)
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights(
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
)

View File

@@ -7,12 +7,12 @@ import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from torch.nn import Parameter
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
check_aiter_fp8_linear_support,
create_fp8_input_scale,
create_fp8_scale_parameter,
create_fp8_weight_parameter,
@@ -61,7 +61,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
)
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
if self.weight_block_size is not None:
assert not self.is_static_input_scheme

View File

@@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
@@ -56,7 +57,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
check_aiter_fp8_linear_support,
create_fp8_input_scale,
create_fp8_scale_parameter,
create_fp8_weight_parameter,
@@ -369,7 +369,7 @@ class Fp8LinearMethod(LinearMethodBase):
if vllm_is_batch_invariant():
self.use_marlin = False
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
self.use_deep_gemm = is_deep_gemm_supported()
self.weight_block_size = self.quant_config.weight_block_size
@@ -869,12 +869,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def process_weights_after_loading(self, layer: Module) -> None:
# Lazy import to avoid importing triton too early.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled,
shuffle_weights,
)
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
# TODO (rob): refactor block quant into separate class.
if self.block_quant:
@@ -916,7 +912,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
if self.rocm_aiter_moe_enabled:
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights(
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
)
@@ -962,7 +958,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
if self.rocm_aiter_moe_enabled:
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights(
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight, layer.w2_weight
)
@@ -1042,7 +1038,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
start += shard_size
if self.rocm_aiter_moe_enabled:
shuffled_w13, shuffled_w2 = shuffle_weights(
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight, layer.w2_weight
)

View File

@@ -4,54 +4,14 @@
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from .cutlass import CutlassScaledMMLinearKernel
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
def rocm_aiter_gemm_w8a8_impl(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
from aiter import gemm_a8w8_CK
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
# a to be [M, K]
# b to be [N, K]
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)
def rocm_aiter_gemm_w8a8_fake(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
m = A.shape[0]
n = B.shape[0]
Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
return Y
if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_gemm_w8a8",
op_func=rocm_aiter_gemm_w8a8_impl,
fake_impl=rocm_aiter_gemm_w8a8_fake,
)
class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
@@ -75,7 +35,7 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
+ "installed on ROCm.",
)
# Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled
if not (envs.VLLM_ROCM_USE_AITER_LINEAR and envs.VLLM_ROCM_USE_AITER):
if not (rocm_aiter_ops.is_linear_enabled()):
return (
False,
"AiterScaledMMLinearKernel is disabled. "
@@ -157,6 +117,4 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
# a to be [M, K]
# b to be [N, K]
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return torch.ops.vllm.rocm_aiter_gemm_w8a8(
x_q, w_q.t(), x_s, w_s, bias, out_dtype
)
return rocm_aiter_ops.gemm_w8a8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)

View File

@@ -8,6 +8,7 @@ import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
@@ -21,10 +22,6 @@ from vllm.model_executor.layers.fused_moe.config import (
ocp_mx_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled,
use_mxfp4_aiter_moe,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_moe_fp8_layer_for_marlin,
)
@@ -122,7 +119,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
if current_platform.is_rocm():
self.use_marlin = False
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
def create_weights(
self,
@@ -309,12 +306,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
)
# Property to determine if AITER is used
if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
shuffle_weights,
)
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights(
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
)
@@ -470,13 +463,15 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
"not implemented. Please open an issue."
)
self.use_rocm_aiter_moe = rocm_aiter_ops.is_fused_moe_enabled()
self.emulate = not current_platform.supports_mx() or not (
use_mxfp4_aiter_moe() and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4"
self.use_rocm_aiter_moe and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4"
)
if self.emulate:
logger.warning_once(
f"The current mode (supports_mx={current_platform.supports_mx()}, "
f"use_mxfp4_aiter_moe={use_mxfp4_aiter_moe()}, "
f"use_mxfp4_aiter_moe={self.use_rocm_aiter_moe}, "
f"ocp_mx_scheme={self.ocp_mx_scheme}) "
"does not support native MXFP4/MXFP6 "
"computation. Simulated weight dequantization and activation "
@@ -656,28 +651,18 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
)
if not self.emulate:
from aiter import ActivationType, QuantType
from aiter.fused_moe import fused_moe
aiter_acts = {
ActivationType.No.name.lower(): ActivationType.No,
ActivationType.Silu.name.lower(): ActivationType.Silu,
ActivationType.Gelu.name.lower(): ActivationType.Gelu,
}
assert activation in aiter_acts, (
f"Aiter CK fp4 MoE doesn't support activation {activation}"
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts,
)
out = fused_moe(
out = rocm_aiter_fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
quant_type=QuantType.per_1x32,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
activation=aiter_acts[activation],
doweight_stage1=False,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
quant_config=self.moe_quant_config,
)
else:
from vllm.model_executor.layers.fused_moe import fused_experts

View File

@@ -31,6 +31,13 @@ from .quark_scheme import QuarkScheme
logger = init_logger(__name__)
# TODO: move registration of custom op to aiter_ops.py
# `from vllm._aiter_ops import rocm_aiter_ops`
# use `rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled()`
# for envs checks which does not require @cache anymore.
# triton kernel is torch compile compatible.
# does not require direct registeration.
# use `rocm_aiter_ops.triton_fp4_gemm_dynamic_qaunt`.
@cache
def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
return (

View File

@@ -12,6 +12,7 @@ import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -68,78 +69,6 @@ def cutlass_scaled_mm(
)
def rocm_aiter_gemm_w8a8_blockscale_impl(
input_2d: torch.Tensor,
weight: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
group_size: int,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
def is_aiter_triton_kernel_tuned(n, k):
return (n, k) in [
(1024, 8192),
(2112, 7168),
(3072, 1536),
(32768, 8192),
(4096, 7168),
(4608, 7168),
(512, 7168),
(7168, 2048),
(7168, 256),
(8192, 1024),
(8192, 32768),
]
n, k = weight.shape
if input_scale is not None:
q_input = input_2d
elif not current_platform.is_fp8_fnuz() and is_aiter_triton_kernel_tuned(n, k):
from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
# MI350 case uses triton kernel
q_input, input_scale = per_token_group_quant_fp8(
input_2d,
group_size,
column_major_scales=False,
use_ue8m0=False,
)
else:
# MI300 uses tuned AITER ASM/C++ kernel
import aiter as rocm_aiter
from aiter import gemm_a8w8_blockscale, get_hip_quant
aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
q_input, input_scale = aiter_per1x128_quant(
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8
)
return gemm_a8w8_blockscale(
q_input, weight, input_scale, weight_scale, dtype=output_dtype
)
def rocm_aiter_gemm_w8a8_blockscale_fake(
input_2d: torch.Tensor,
weight: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
group_size: int,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
m = input_2d.shape[0]
n = weight.shape[0]
return torch.empty(m, n, dtype=output_dtype, device=input_2d.device)
if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_gemm_w8a8_blockscale",
op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
)
# TODO we should be able to change the type of block_size to GroupShape
# after we resolve GroupShape compilation issue
# https://github.com/vllm-project/vllm/issues/25270
@@ -385,14 +314,40 @@ class W8A8BlockFp8LinearOp:
input_scale: torch.Tensor | None = None,
) -> torch.Tensor:
assert self.act_quant_group_shape == GroupShape(1, 128)
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
input_2d,
weight,
input_scale,
weight_scale,
self.act_quant_group_shape.col,
input_2d.dtype,
)
n, k = weight.shape
if input_scale is not None:
q_input = input_2d
# MI350 case uses triton kernel
if (
not current_platform.is_fp8_fnuz()
and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k)
):
q_input, input_scale = per_token_group_quant_fp8(
input_2d,
self.act_quant_group_shape.col,
column_major_scales=False,
use_ue8m0=False,
)
return rocm_aiter_ops.triton_gemm_a8w8_blockscale(
q_input,
weight,
input_scale,
weight_scale,
input_2d.dtype,
)
# MI300 uses tuned AITER ASM/C++ kernel
else:
q_input, input_scale = rocm_aiter_ops.per_1x128_fp8_quant(input_2d)
return rocm_aiter_ops.gemm_w8a8_blockscale(
q_input,
weight,
input_scale,
weight_scale,
input_2d.dtype,
)
def _run_triton(
self,
@@ -971,15 +926,6 @@ def requant_weight_ue8m0_inplace(
s_old.copy_(s_requant)
def check_aiter_fp8_linear_support() -> bool:
"""AITER is only supported on ROCm for MI3XX"""
return (
current_platform.is_rocm()
and envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_AITER_LINEAR
)
def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
"""Pad the weight tensor. This is an optimization on ROCm platform, which
can benefit from tensors located far enough from one another in memory"""

View File

@@ -472,7 +472,7 @@ class Fp8LinearOp:
# Example:
# When the number of token is 1, per-token scale is [[1]]
# When per-tensor scale is [1] or ().
per_tensor_weights = (weight_scale.numel() == 1) and weight_scale.dim() < 2
per_tensor_weights = weight_scale.numel() == 1
per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2
# TODO(luka) do this dispatch during init (after ScaledMM refactor)

View File

@@ -4,13 +4,10 @@
import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.custom_op import CustomOp
from .common import apply_rotary_emb_torch
from .rocm_aiter_rope_ops import (
is_rocm_triton_rotary_embedding_enabled,
rocm_aiter_rotary_emb,
)
@CustomOp.register("rotary_embedding")
@@ -48,8 +45,8 @@ class RotaryEmbeddingBase(CustomOp):
cache = cache.to(dtype)
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
self.is_rocm_triton_rotary_embedding_enabled = (
is_rocm_triton_rotary_embedding_enabled()
self.is_rocm_triton_rotary_embed_enabled = (
rocm_aiter_ops.is_triton_rotary_embed_enabled()
)
def _compute_inv_freq(self, base: float) -> torch.Tensor:
@@ -169,9 +166,9 @@ class RotaryEmbedding(RotaryEmbeddingBase):
query: torch.Tensor,
key: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
if self.is_rocm_triton_rotary_embedding_enabled:
if self.is_rocm_triton_rotary_embed_enabled:
self._match_cos_sin_cache_dtype(query)
rocm_aiter_rotary_emb(
rocm_aiter_ops.triton_rotary_embed(
positions,
query,
key,

View File

@@ -146,6 +146,15 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
key = key_rot
return query, key
def forward_hip(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
return self.forward_native(positions, query, key, offsets)
def forward_cuda(
self,
positions: torch.Tensor,

View File

@@ -1,94 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
def is_rocm_triton_rotary_embedding_enabled() -> bool:
return (
current_platform.is_rocm()
and envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_TRITON_ROPE
)
def rocm_aiter_rotary_emb_with_key_forward_triton_impl(
positions: torch.Tensor,
sin: torch.Tensor,
cos: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
rotate_style: int = 0,
is_nope_first: bool = False,
) -> None:
import aiter.ops.triton.rope as ops
ops.rope_cached_thd_positions_2c_fwd_inplace(
query,
key,
cos,
sin,
positions,
rotate_style,
reuse_freqs_front_part=True,
nope_first=is_nope_first,
)
def rocm_aiter_rotary_emb_with_key_forward_triton_fake(
positions: torch.Tensor,
sin: torch.Tensor,
cos: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
rotate_style: int = 0,
is_nope_first: bool = False,
) -> None:
pass
if is_rocm_triton_rotary_embedding_enabled():
direct_register_custom_op(
op_name="rocm_aiter_rotary_emb_with_key_forward_triton",
op_func=rocm_aiter_rotary_emb_with_key_forward_triton_impl,
mutates_args=["key", "query"],
fake_impl=rocm_aiter_rotary_emb_with_key_forward_triton_fake,
dispatch_key=current_platform.dispatch_key,
)
def rocm_aiter_rotary_emb(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
cos_sin_cache: torch.Tensor,
head_size: int,
rotary_dim: int,
is_neox_style: bool,
):
num_tokens = positions.numel()
cos, sin = cos_sin_cache.chunk(2, dim=-1)
query_shape = query.shape
key_shape = key.shape
rotate_style = 0 if is_neox_style else 1
query = query.view(num_tokens, -1, head_size)
key = key.view(num_tokens, -1, head_size)
query_ = query[..., :rotary_dim]
key_ = key[..., :rotary_dim]
positions = positions.view(*query.shape[:1])
torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_triton(
positions,
sin,
cos,
query_,
key_,
rotate_style,
False,
)
query = query.view(query_shape)
key = key.view(key_shape)