[RFC][ROCm][AITER] Keep all AITER kernels in _aiter_ops class like _custom_ops and _ipex_ops (#24490)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
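This refactor replaces the free helper functions imported from rocm_aiter_fused_moe (e.g. is_rocm_aiter_moe_enabled()) with methods on a single rocm_aiter_ops dispatcher class, mirroring how _custom_ops and _ipex_ops already group their kernels. Below is a minimal sketch of the pattern, assuming the real vllm/_aiter_ops.py gates on ROCm-specific environment flags; the method names come from this diff, but the flag names and bodies are illustrative assumptions, not the actual implementation.

    # Illustrative sketch only -- not the real vllm/_aiter_ops.py.
    # Method names are taken from this diff; the environment-variable names
    # and gating logic are assumptions made for illustration.
    import os


    class rocm_aiter_ops:
        """Groups AITER kernel dispatch behind one class, like _custom_ops/_ipex_ops."""

        @staticmethod
        def is_enabled() -> bool:
            # Assumed master switch for AITER kernels on ROCm.
            return os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1"

        @staticmethod
        def is_fused_moe_enabled() -> bool:
            # Fused-MoE dispatch builds on the master switch.
            return (
                rocm_aiter_ops.is_enabled()
                and os.environ.get("VLLM_ROCM_USE_AITER_MOE", "1") == "1"
            )

        @staticmethod
        def is_fusion_moe_shared_experts_enabled() -> bool:
            # Shared-expert fusion requires the fused-MoE path.
            return rocm_aiter_ops.is_fused_moe_enabled()

Call sites then query rocm_aiter_ops.is_fused_moe_enabled() instead of importing per-module helpers, which is exactly the substitution the hunks below perform.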
@@ -33,6 +33,7 @@ import torch
 from torch import nn
 from transformers import DeepseekV2Config, DeepseekV3Config
 
+from vllm._aiter_ops import rocm_aiter_ops
 from vllm.attention import Attention
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
@@ -50,10 +51,6 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.layers.fused_moe import SharedFusedMoE
-from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-    is_rocm_aiter_fusion_shared_expert_enabled,
-    is_rocm_aiter_moe_enabled,
-)
 from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
@@ -294,10 +291,8 @@ class DeepseekV2MoE(nn.Module):
             self.physical_expert_start + self.n_local_physical_experts
         )
 
-        if (
-            config.n_shared_experts is None
-            or is_rocm_aiter_fusion_shared_expert_enabled()
-        ):
+        self.is_rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
+        if config.n_shared_experts is None or self.is_rocm_aiter_moe_enabled:
             self.shared_experts = None
         else:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
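Note the pattern in the hunk above: the dispatch decision is resolved once in __init__ and stored on the module, so the forward path (the -371 hunk below) branches on a plain attribute instead of re-evaluating the helper on every call; the load_weights hunks apply the same hoisting with a local variable. A runnable sketch of the idea, using stand-in names rather than the real vLLM classes:

    # Runnable sketch of "resolve the backend once, branch on an attribute".
    # _FakeAiterOps and MoEStub are stand-ins, not vLLM code.
    import torch
    from torch import nn


    class _FakeAiterOps:
        """Stand-in for rocm_aiter_ops in this sketch."""

        @staticmethod
        def is_fused_moe_enabled() -> bool:
            return False  # pretend AITER is unavailable


    class MoEStub(nn.Module):
        def __init__(self, routed_scaling_factor: float = 2.5):
            super().__init__()
            # Resolve the backend decision once, at construction time ...
            self.is_rocm_aiter_moe_enabled = _FakeAiterOps.is_fused_moe_enabled()
            self.routed_scaling_factor = routed_scaling_factor

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            # ... so the hot path only tests a cheap boolean attribute.
            if not self.is_rocm_aiter_moe_enabled:
                hidden_states = hidden_states * self.routed_scaling_factor
            return hidden_states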
@@ -330,14 +325,14 @@ class DeepseekV2MoE(nn.Module):
             # we do scaling outside, set factor to 1.0 to avoid double mul
             # aiter applies routed_scaling_factor internally
             routed_scaling_factor=1.0
-            if not is_rocm_aiter_moe_enabled()
+            if not self.is_rocm_aiter_moe_enabled
             else self.routed_scaling_factor,
             e_score_correction_bias=self.gate.e_score_correction_bias,
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.n_redundant_experts,
             is_sequence_parallel=self.is_sequence_parallel,
             n_shared_experts=config.n_shared_experts
-            if is_rocm_aiter_fusion_shared_expert_enabled()
+            if rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
             else None,
         )
 
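The routed_scaling_factor=1.0 branch above exists because, per the comments in the hunk, the AITER fused-MoE kernel applies the scaling factor internally, while the default path leaves it to the module (the -371 hunk below). A toy numeric illustration of the double multiplication this avoids (plain Python, not vLLM code):

    # Toy illustration of the double-scaling hazard described above.
    routed_scaling_factor = 2.5
    expert_output = 1.0

    # AITER path: the kernel multiplies by the factor internally ...
    aiter_result = expert_output * routed_scaling_factor
    # ... so the wrapper must pass 1.0; forwarding the real factor again
    # would yield expert_output * routed_scaling_factor ** 2.

    # Default path: the kernel output is unscaled and DeepseekV2MoE.forward
    # multiplies outside (see the next hunk).
    default_result = expert_output
    default_result *= routed_scaling_factor

    assert aiter_result == default_result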
@@ -371,7 +366,7 @@ class DeepseekV2MoE(nn.Module):
         # Fix FP16 overflow
         # See DeepseekV2DecoderLayer for more details.
         if hidden_states.dtype != torch.float16:
-            if not is_rocm_aiter_moe_enabled():
+            if not self.is_rocm_aiter_moe_enabled:
                 final_hidden_states *= self.routed_scaling_factor
         elif self.shared_experts is not None:
             assert shared_output is not None
@@ -1428,6 +1423,9 @@ class DeepseekV2ForCausalLM(
         )
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        rocm_aiter_moe_shared_expert_enabled = (
+            rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
+        )
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
@@ -1456,7 +1454,7 @@ class DeepseekV2ForCausalLM(
             num_experts=self.config.n_routed_experts
             + (
                 self.config.n_shared_experts
-                if is_rocm_aiter_fusion_shared_expert_enabled()
+                if rocm_aiter_moe_shared_expert_enabled
                 else 0
             ),
             num_redundant_experts=self.num_redundant_experts,
@@ -1472,9 +1470,8 @@ class DeepseekV2ForCausalLM(
             if spec_layer is not None:
                 continue  # skip spec decode layers for main model
 
-            is_fuse_shared_experts_layer = (
-                is_rocm_aiter_fusion_shared_expert_enabled()
-                and ("mlp.shared_experts" in name)
+            is_fuse_shared_experts_layer = rocm_aiter_moe_shared_expert_enabled and (
+                "mlp.shared_experts" in name
             )
 
             for param_name, weight_name, shard_id in stacked_params_mapping: