[FEAT] [ROCm]: Add AITER CK 2 Stages MoE support (#17110)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Author: TJian
Date: 2025-05-14 18:03:11 +08:00 (committed by GitHub)
Parent: 38fe728d60
Commit: 612c2edb4f
7 changed files with 201 additions and 112 deletions


@@ -125,6 +125,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         # Disable marlin for rocm
         if current_platform.is_rocm():
             self.use_marlin = False
 
+        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+            is_rocm_aiter_moe_enabled)
+        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
+
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
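
The commit's recurring change is visible here: is_rocm_aiter_moe_enabled() is now evaluated once in the constructor and cached as self.rocm_aiter_moe_enabled, instead of being re-imported and re-called at every use site below. As a rough sketch of what such a gate typically looks like (the flag names are assumptions for illustration; the real helper lives in rocm_aiter_fused_moe and also checks the platform):

import os

def is_rocm_aiter_moe_enabled() -> bool:
    # Illustrative sketch, not the vLLM implementation: gate the AITER
    # MoE path behind environment flags. Flag names are assumed.
    return (os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1"
            and os.environ.get("VLLM_ROCM_USE_AITER_MOE", "1") == "1")
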
@@ -276,24 +280,22 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                     requires_grad=False)
 
-        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-            is_rocm_aiter_moe_enabled)
-
-        # Property to determine if AITER is used
-        if is_rocm_aiter_moe_enabled():
+        if self.rocm_aiter_moe_enabled:
             from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa E501
                 rocm_aiter_fused_experts, shuffle_weights)
 
             # reshaping weights is required for aiter moe kernel.
-            shuffled_w13, shuffled_w2 = shuffle_weights(
-                layer.w13_weight.data, layer.w2_weight.data)
+            shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight.data,
+                                                        layer.w2_weight.data,
+                                                        layout=(16, 16))
 
             layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                   requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                  requires_grad=False)
 
-            self.fused_experts_func = rocm_aiter_fused_experts
+            self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
         else:
             from vllm.model_executor.layers.fused_moe import fused_experts
             self.fused_experts_func = fused_experts
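
shuffle_weights now takes an explicit layout tile shape instead of assuming one. The AITER CK kernels read weights in a blocked layout, and the exact format is kernel-specific; the sketch below only illustrates what re-tiling an (E, N, K) expert weight into contiguous 16x16 tiles means, not the actual AITER transform:

import torch

def block_shuffle(w: torch.Tensor, layout=(16, 16)) -> torch.Tensor:
    # Illustrative only: reorder elements so each layout[0] x layout[1]
    # tile of every expert's weight matrix is contiguous in memory.
    bn, bk = layout
    e, n, k = w.shape
    assert n % bn == 0 and k % bk == 0, "dims must divide the tile shape"
    return (w.view(e, n // bn, bn, k // bk, bk)
            .permute(0, 1, 3, 2, 4)
            .reshape(e, n, k)
            .contiguous())

w13 = torch.randn(8, 256, 128)                  # 8 experts
shuffled = block_shuffle(w13, layout=(16, 16))  # same shape, new element order
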
@@ -335,6 +337,22 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias)
 
+        if self.rocm_aiter_moe_enabled:
+            return self.rocm_aiter_fused_experts_func(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                activation=activation,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                use_fp8_w8a8=True,
+                per_channel_quant=self.weight_quant.strategy ==
+                QuantizationStrategy.CHANNEL,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale)
+
         if self.use_marlin:
             assert activation == "silu", (
                 f"{activation} not supported for Marlin MoE.")


@@ -591,6 +591,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
             expand_weights, is_rocm_aiter_moe_enabled, shuffle_weights)
+        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
+
         # TODO (rob): refactor block quant into separate class.
         if self.block_quant:
             assert self.quant_config.activation_scheme == "dynamic"
@@ -616,10 +618,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight = Parameter(w2_weight, requires_grad=False)
             layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
                                                   requires_grad=False)
-            if is_rocm_aiter_moe_enabled():
+            if self.rocm_aiter_moe_enabled:
                 # reshaping weights is required for aiter moe kernel.
-                shuffled_w13, shuffled_w2 = shuffle_weights(
-                    layer.w13_weight.data, layer.w2_weight.data)
+                shuffled_w13, shuffled_w2 = shuffle_weights(
+                    layer.w13_weight.data,
+                    layer.w2_weight.data,
+                    layout=(16, 16))
 
                 layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                       requires_grad=False)
@@ -663,7 +667,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                                                  requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                  requires_grad=False)
-            if is_rocm_aiter_moe_enabled():
+            if self.rocm_aiter_moe_enabled:
                 # reshaping weights is required for aiter moe kernel.
                 w13_scales, w2_scales = expand_weights(
                     layer.w13_weight_scale.data,
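
expand_weights prepares the flat weight scales before shuffling; its exact semantics are internal to the AITER helper module. As an assumed illustration only, expanding a per-expert scale to one value per row might look like:

import torch

def expand_scale(scale: torch.Tensor, dim_size: int) -> torch.Tensor:
    # Assumed illustration, not the real expand_weights: broadcast an
    # (E,) per-expert scale to (E, dim_size) row-wise scales.
    return scale.view(-1, 1).expand(-1, dim_size).contiguous()

w13_scale = torch.rand(8)                   # one scale per expert
row_scales = expand_scale(w13_scale, 1024)  # (8, 1024)
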
@@ -676,8 +680,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 layer.w2_weight_scale = torch.nn.Parameter(
                     w2_scales.contiguous(), requires_grad=False)
 
-                shuffled_w13, shuffled_w2 = shuffle_weights(
-                    layer.w13_weight, layer.w2_weight)
+                shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight,
+                                                            layer.w2_weight,
+                                                            layout=(16, 16))
 
                 layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                       requires_grad=False)
@@ -748,7 +753,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                         dq_weight, max_w13_scales[expert_id])
                     start += shard_size
 
-            if is_rocm_aiter_moe_enabled():
+            if self.rocm_aiter_moe_enabled:
                 # reshaping weights is required for aiter moe kernel.
                 expansion_dims = [
                     layer.w13_weight.shape[1], layer.w2_weight.shape[1]
@@ -760,8 +765,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 layer.w2_weight_scale = torch.nn.Parameter(
                     w2_scales.contiguous(), requires_grad=False)
 
-                shuffled_w13, shuffled_w2 = shuffle_weights(
-                    layer.w13_weight, layer.w2_weight)
+                shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight,
+                                                            layer.w2_weight,
+                                                            layout=(32, 32))
 
                 layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                       requires_grad=False)
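
Note the tile shape difference: this requantization path shuffles with layout=(32, 32), while the paths above use (16, 16), presumably because each path dispatches to a different AITER kernel variant with its own expected layout; the two constants are not interchangeable.
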
@@ -796,6 +802,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         activation: str = "silu",
     ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe import fused_experts
+        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+            rocm_aiter_fused_experts)
 
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
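
The new import is kept local to apply, mirroring the existing lazy fused_experts import: the ROCm-only module is only loaded when the method actually runs, so importing this file on non-ROCm builds does not require aiter to be installed.
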
@@ -810,6 +818,24 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             e_score_correction_bias=e_score_correction_bias,
         )
 
+        if self.rocm_aiter_moe_enabled:
+            return rocm_aiter_fused_experts(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                activation=activation,
+                use_fp8_w8a8=True,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                w1_scale=(layer.w13_weight_scale_inv
+                          if self.block_quant else layer.w13_weight_scale),
+                w2_scale=(layer.w2_weight_scale_inv
+                          if self.block_quant else layer.w2_weight_scale),
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+                block_shape=self.quant_config.weight_block_size)
+
         if self.use_marlin:
             assert activation == "silu", (
                 f"{activation} not supported for Marlin MoE.")