[MoE Refactor][7/N] AITER MK (#31102)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
Robert Shaw
2025-12-22 18:42:58 -05:00
committed by GitHub
parent 6d518ffbaa
commit b57b967386
4 changed files with 144 additions and 66 deletions

View File

@@ -117,6 +117,7 @@ class Fp8MoeBackend(Enum):
DEEPGEMM = 3
MARLIN = 4
TRITON = 5
AITER = 6
def get_fp8_moe_backend(
@@ -189,6 +190,10 @@ def get_fp8_moe_backend(
logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local")
return Fp8MoeBackend.DEEPGEMM
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE:
logger.info_once("Using ROCm AITER backend for FP8 MoE", scope="local")
return Fp8MoeBackend.AITER
# default to Triton
logger.info_once("Using Triton backend for FP8 MoE")
return Fp8MoeBackend.TRITON
@@ -888,16 +893,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w13_input_scale = None
layer.w2_input_scale = None
self.rocm_aiter_moe_enabled = False
def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return
# Lazy import to avoid importing triton too early.
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
# TODO (rob): refactor block quant into separate class.
if self.block_quant:
assert self.quant_config.activation_scheme == "dynamic"
@@ -932,7 +931,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv)
replace_parameter(layer, "w2_weight", w2_weight)
replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv)
if self.rocm_aiter_moe_enabled:
if self.fp8_backend == Fp8MoeBackend.AITER:
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
@@ -1026,7 +1025,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
start += shard_size
if self.rocm_aiter_moe_enabled:
if self.fp8_backend == Fp8MoeBackend.AITER:
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight, layer.w2_weight
)
@@ -1072,6 +1071,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.moe_quant_config = config
self.kernel = mk.FusedMoEModularKernel(
# TODO(rob): we can use the generic MoEPrepareAndFinalizeNoEP
# with the changes to defer input quantization
FlashInferAllGatherMoEPrepareAndFinalize(
use_dp=(self.moe.dp_size > 1),
use_deepseek_fp8_block_scale=self.block_quant,
@@ -1093,6 +1094,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
Fp8MoeBackend.DEEPGEMM,
Fp8MoeBackend.TRITON,
Fp8MoeBackend.MARLIN,
Fp8MoeBackend.AITER,
]:
from vllm.model_executor.layers.fused_moe import (
TritonOrDeepGemmExperts,
@@ -1103,24 +1105,33 @@ class Fp8MoEMethod(FusedMoEMethodBase):
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
AiterExperts,
)
config = self.get_fused_moe_quant_config(layer)
assert config is not None
self.moe_quant_config = config
use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
moe_kernel = (
MarlinExperts(quant_config=self.moe_quant_config)
if use_marlin
else TritonOrDeepGemmExperts(
quant_config=self.moe_quant_config,
allow_deep_gemm=allow_deep_gemm,
)
)
self.kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(), moe_kernel
)
if self.fp8_backend == Fp8MoeBackend.AITER:
self.kernel = mk.FusedMoEModularKernel(
# TODO: make defer_input_quant an attr of the AiterExperts
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
AiterExperts(quant_config=self.moe_quant_config),
)
elif self.fp8_backend == Fp8MoeBackend.MARLIN:
self.kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
MarlinExperts(quant_config=self.moe_quant_config),
)
else:
self.kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
TritonOrDeepGemmExperts(
quant_config=self.moe_quant_config,
allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
),
)
self.use_inplace = True
def maybe_make_prepare_finalize(
@@ -1128,7 +1139,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
if (
self.rocm_aiter_moe_enabled
self.fp8_backend == Fp8MoeBackend.AITER
or self.fp8_backend == Fp8MoeBackend.MARLIN
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
@@ -1161,11 +1172,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
TritonOrDeepGemmExperts,
)
assert (
self.fp8_backend != Fp8MoeBackend.MARLIN
) and not self.rocm_aiter_moe_enabled, (
"Marlin and ROCm AITER are not supported with all2all yet."
)
if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]:
raise NotImplementedError(
"Marlin and ROCm AITER are not supported with all2all yet."
)
assert self.moe_quant_config is not None
@@ -1313,37 +1323,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
hidden_states=x,
router_logits=router_logits,
)
if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_fused_experts,
)
# TODO(rob): convert this to MK.
result = rocm_aiter_fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
)
else:
result = self.kernel(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
inplace=self.use_inplace,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
result = self.kernel(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
inplace=self.use_inplace,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
return result
@@ -1456,15 +1447,10 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
layer.w13_input_scale = None
layer.w2_input_scale = None
self.rocm_aiter_moe_enabled = False
def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return
# Lazy import to avoid importing triton too early.
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
# If checkpoint is fp16, quantize in place.
fp8_dtype = current_platform.fp8_dtype()
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
@@ -1481,7 +1467,7 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
replace_parameter(layer, "w2_weight", w2_weight)
# Reshuffle weights for AITER if needed.
if self.rocm_aiter_moe_enabled:
if self.fp8_backend == Fp8MoeBackend.AITER:
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight, layer.w2_weight
)
@@ -1489,7 +1475,7 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
replace_parameter(layer, "w2_weight", shuffled_w2)
# Rushuffle weights for MARLIN if needed.
if self.fp8_backend == Fp8MoeBackend.MARLIN:
elif self.fp8_backend == Fp8MoeBackend.MARLIN:
prepare_moe_fp8_layer_for_marlin(
layer, False, input_dtype=self.marlin_input_dtype
)