[FEAT] [ROCm]: Add AITER CK 2 Stages MoE support (#17110)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
@@ -125,6 +125,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
# Disable marlin for rocm
|
||||
if current_platform.is_rocm():
|
||||
self.use_marlin = False
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_moe_enabled)
|
||||
|
||||
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
@@ -276,24 +280,22 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
|
||||
requires_grad=False)
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_moe_enabled)
|
||||
|
||||
# Property to determine if AITER is used
|
||||
if is_rocm_aiter_moe_enabled():
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
|
||||
rocm_aiter_fused_experts, shuffle_weights)
|
||||
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data)
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight.data,
|
||||
layer.w2_weight.data,
|
||||
layout=(16, 16))
|
||||
|
||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13,
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
|
||||
requires_grad=False)
|
||||
|
||||
self.fused_experts_func = rocm_aiter_fused_experts
|
||||
self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
self.fused_experts_func = fused_experts
|
||||
@@ -335,6 +337,22 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
return self.rocm_aiter_fused_experts_func(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_fp8_w8a8=True,
|
||||
per_channel_quant=self.weight_quant.strategy ==
|
||||
QuantizationStrategy.CHANNEL,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale)
|
||||
if self.use_marlin:
|
||||
assert activation == "silu", (
|
||||
f"{activation} not supported for Marlin MoE.")
|
||||
|
||||
@@ -591,6 +591,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
expand_weights, is_rocm_aiter_moe_enabled, shuffle_weights)
|
||||
|
||||
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
|
||||
|
||||
# TODO (rob): refactor block quant into separate class.
|
||||
if self.block_quant:
|
||||
assert self.quant_config.activation_scheme == "dynamic"
|
||||
@@ -616,10 +618,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_weight = Parameter(w2_weight, requires_grad=False)
|
||||
layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
|
||||
requires_grad=False)
|
||||
if is_rocm_aiter_moe_enabled():
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data)
|
||||
layer.w13_weight.data,
|
||||
layer.w2_weight.data,
|
||||
layout=(16, 16))
|
||||
|
||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13,
|
||||
requires_grad=False)
|
||||
@@ -663,7 +667,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight,
|
||||
requires_grad=False)
|
||||
if is_rocm_aiter_moe_enabled():
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
w13_scales, w2_scales = expand_weights(
|
||||
layer.w13_weight_scale.data,
|
||||
@@ -676,8 +680,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_weight_scale = torch.nn.Parameter(
|
||||
w2_scales.contiguous(), requires_grad=False)
|
||||
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
layer.w13_weight, layer.w2_weight)
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
layout=(16, 16))
|
||||
|
||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13,
|
||||
requires_grad=False)
|
||||
@@ -748,7 +753,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
dq_weight, max_w13_scales[expert_id])
|
||||
start += shard_size
|
||||
|
||||
if is_rocm_aiter_moe_enabled():
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
expansion_dims = [
|
||||
layer.w13_weight.shape[1], layer.w2_weight.shape[1]
|
||||
@@ -760,8 +765,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_weight_scale = torch.nn.Parameter(
|
||||
w2_scales.contiguous(), requires_grad=False)
|
||||
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
layer.w13_weight, layer.w2_weight)
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
layout=(32, 32))
|
||||
|
||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13,
|
||||
requires_grad=False)
|
||||
@@ -796,6 +802,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
rocm_aiter_fused_experts)
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
@@ -810,6 +818,24 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
return rocm_aiter_fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
use_fp8_w8a8=True,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
w1_scale=(layer.w13_weight_scale_inv
|
||||
if self.block_quant else layer.w13_weight_scale),
|
||||
w2_scale=(layer.w2_weight_scale_inv
|
||||
if self.block_quant else layer.w2_weight_scale),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
block_shape=self.quant_config.weight_block_size)
|
||||
|
||||
if self.use_marlin:
|
||||
assert activation == "silu", (
|
||||
f"{activation} not supported for Marlin MoE.")
|
||||
|
||||
Reference in New Issue
Block a user