[RFC][ROCm][AITER] Keep all AITER kernels in _aiter_ops class like _custom_ops and _ipex_ops (#24490)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Author: vllmellm
Committed by: GitHub
Date: 2025-11-10 17:20:53 +01:00
Parent: 40e2eeeb92
Commit: f080a83511
25 changed files with 1193 additions and 924 deletions
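The refactor consolidates the scattered `is_rocm_aiter_*` helper functions behind a single `rocm_aiter_ops` entry point, mirroring `_custom_ops` and `_ipex_ops`. Below is a minimal sketch of that dispatcher pattern; the method names follow the diff, but the enablement logic and environment flags shown are simplified assumptions, not the real `vllm/_aiter_ops.py` implementation.

```python
# Illustrative sketch only: shows the dispatcher-class pattern this commit
# adopts. Method names match the diff below; the env handling is simplified.
from functools import cache

import vllm.envs as envs
from vllm.platforms import current_platform


class rocm_aiter_ops:
    """Single entry point for ROCm AITER kernels and their enablement checks."""

    @staticmethod
    @cache
    def is_enabled() -> bool:
        # AITER only applies on ROCm and must be opted into via env flag.
        return current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER

    @staticmethod
    @cache
    def is_fused_moe_enabled() -> bool:
        # Fused-MoE kernels have their own sub-flag on top of the global one.
        return rocm_aiter_ops.is_enabled() and envs.VLLM_ROCM_USE_AITER_MOE

    @staticmethod
    @cache
    def is_fusion_moe_shared_experts_enabled() -> bool:
        # Assumed here to build directly on the fused-MoE path.
        return rocm_aiter_ops.is_fused_moe_enabled()
```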


@@ -33,6 +33,7 @@ import torch
 from torch import nn
 from transformers import DeepseekV2Config, DeepseekV3Config
+from vllm._aiter_ops import rocm_aiter_ops
 from vllm.attention import Attention
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
@@ -50,10 +51,6 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.layers.fused_moe import SharedFusedMoE
-from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-    is_rocm_aiter_fusion_shared_expert_enabled,
-    is_rocm_aiter_moe_enabled,
-)
 from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
@@ -294,10 +291,8 @@ class DeepseekV2MoE(nn.Module):
             self.physical_expert_start + self.n_local_physical_experts
         )
 
-        if (
-            config.n_shared_experts is None
-            or is_rocm_aiter_fusion_shared_expert_enabled()
-        ):
+        self.is_rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
+        if config.n_shared_experts is None or self.is_rocm_aiter_moe_enabled:
             self.shared_experts = None
         else:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
@@ -330,14 +325,14 @@ class DeepseekV2MoE(nn.Module):
             # we do scaling outside, set factor to 1.0 to avoid double mul
             # aiter applies routed_scaling_factor internally
             routed_scaling_factor=1.0
-            if not is_rocm_aiter_moe_enabled()
+            if not self.is_rocm_aiter_moe_enabled
             else self.routed_scaling_factor,
             e_score_correction_bias=self.gate.e_score_correction_bias,
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.n_redundant_experts,
             is_sequence_parallel=self.is_sequence_parallel,
             n_shared_experts=config.n_shared_experts
-            if is_rocm_aiter_fusion_shared_expert_enabled()
+            if rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
             else None,
         )
@@ -371,7 +366,7 @@ class DeepseekV2MoE(nn.Module):
         # Fix FP16 overflow
         # See DeepseekV2DecoderLayer for more details.
         if hidden_states.dtype != torch.float16:
-            if not is_rocm_aiter_moe_enabled():
+            if not self.is_rocm_aiter_moe_enabled:
                 final_hidden_states *= self.routed_scaling_factor
         elif self.shared_experts is not None:
             assert shared_output is not None
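The hunks above replace repeated calls to the module-level `is_rocm_aiter_moe_enabled()` helper with a flag cached on the module at construction time. A simplified, standalone sketch of that pattern (the real `DeepseekV2MoE` does far more than this):

```python
# Simplified illustration of caching the AITER check in __init__ and reading
# it in forward(), as the diff above does for DeepseekV2MoE.
import torch
from torch import nn

from vllm._aiter_ops import rocm_aiter_ops


class MoEScalingExample(nn.Module):
    def __init__(self, routed_scaling_factor: float) -> None:
        super().__init__()
        self.routed_scaling_factor = routed_scaling_factor
        # Resolve the enablement check once; forward() reads a plain attribute.
        self.is_rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()

    def forward(self, final_hidden_states: torch.Tensor) -> torch.Tensor:
        # The AITER fused-MoE path applies routed_scaling_factor internally,
        # so only scale here when that path is not in use.
        if not self.is_rocm_aiter_moe_enabled:
            final_hidden_states = final_hidden_states * self.routed_scaling_factor
        return final_hidden_states
```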
@@ -1428,6 +1423,9 @@ class DeepseekV2ForCausalLM(
         )
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        rocm_aiter_moe_shared_expert_enabled = (
+            rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
+        )
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
@@ -1456,7 +1454,7 @@ class DeepseekV2ForCausalLM(
             num_experts=self.config.n_routed_experts
             + (
                 self.config.n_shared_experts
-                if is_rocm_aiter_fusion_shared_expert_enabled()
+                if rocm_aiter_moe_shared_expert_enabled
                 else 0
             ),
             num_redundant_experts=self.num_redundant_experts,
@@ -1472,9 +1470,8 @@ class DeepseekV2ForCausalLM(
             if spec_layer is not None:
                 continue  # skip spec decode layers for main model
 
-            is_fuse_shared_experts_layer = (
-                is_rocm_aiter_fusion_shared_expert_enabled()
-                and ("mlp.shared_experts" in name)
+            is_fuse_shared_experts_layer = rocm_aiter_moe_shared_expert_enabled and (
+                "mlp.shared_experts" in name
             )
 
             for param_name, weight_name, shard_id in stacked_params_mapping:
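In `load_weights`, the enablement check is now hoisted out of the per-weight loop and reused for every weight name. A small illustrative sketch of that hoisting (the function name and return shape below are hypothetical, not part of the commit):

```python
# Hoisting sketch: one enablement check per load, not one per tensor.
from collections.abc import Iterable

import torch

from vllm._aiter_ops import rocm_aiter_ops


def tag_shared_expert_weights(
    weights: Iterable[tuple[str, torch.Tensor]],
) -> list[tuple[str, bool]]:
    # Evaluated once before the loop, mirroring the diff above.
    shared_expert_fusion = rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
    tagged = []
    for name, _ in weights:
        is_fuse_shared_experts_layer = shared_expert_fusion and (
            "mlp.shared_experts" in name
        )
        tagged.append((name, is_fuse_shared_experts_layer))
    return tagged
```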