[ROCm] Add aiter tkw1 kernel for Llama4 fp8 (#16727)

Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
Commit 5b794cae8d (parent 0e4254492f), committed by GitHub
Author: kliuae
Date: 2025-04-22 11:42:34 +08:00
6 changed files with 134 additions and 48 deletions


@@ -250,6 +250,28 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                     requires_grad=False)
 
+        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+            is_rocm_aiter_moe_enabled)
+
+        # Property to determine if AITER is used
+        if is_rocm_aiter_moe_enabled():
+            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa E501
+                rocm_aiter_fused_experts, shuffle_weights)
+
+            # reshaping weights is required for aiter moe kernel.
+            shuffled_w13, shuffled_w2 = shuffle_weights(
+                layer.w13_weight.data, layer.w2_weight.data)
+
+            layer.w13_weight = torch.nn.Parameter(shuffled_w13,
+                                                  requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(shuffled_w2,
+                                                 requires_grad=False)
+
+            self.fused_experts_func = rocm_aiter_fused_experts
+        else:
+            from vllm.model_executor.layers.fused_moe import fused_experts
+            self.fused_experts_func = fused_experts
+
     def apply(
         self,
         layer: torch.nn.Module,
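
The net effect of this hunk is one-time kernel selection: the AITER-vs-Triton
decision (and the weight shuffling AITER needs) happens once, after weights
are loaded, and the chosen function is stored on the method object. Below is
a minimal, self-contained sketch of that pattern; every underscore-prefixed
name is a hypothetical stand-in, not a vLLM API.

import torch


def _triton_fused_experts(x: torch.Tensor, w13: torch.Tensor,
                          w2: torch.Tensor) -> torch.Tensor:
    # Stand-in for the default fused_experts (Triton) path.
    return x


def _aiter_fused_experts(x: torch.Tensor, w13: torch.Tensor,
                         w2: torch.Tensor) -> torch.Tensor:
    # Stand-in for rocm_aiter_fused_experts (AITER path).
    return x


def _is_aiter_enabled() -> bool:
    # Stand-in for is_rocm_aiter_moe_enabled().
    return False


class _MoEMethodSketch:

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Pick the kernel once, at weight-load time, instead of branching
        # (and lazily importing) inside every forward call.
        if _is_aiter_enabled():
            # The real code also shuffles w13/w2 here, since the AITER
            # kernel expects a shuffled weight layout.
            self.fused_experts_func = _aiter_fused_experts
        else:
            self.fused_experts_func = _triton_fused_experts

    def apply(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        # The hot path reduces to a single indirect call.
        return self.fused_experts_func(x, layer.w13_weight, layer.w2_weight)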
@@ -268,7 +290,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
     ) -> torch.Tensor:
-        from vllm.model_executor.layers.fused_moe import fused_experts
 
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -282,7 +303,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias)
 
-        return fused_experts(
+        return self.fused_experts_func(
             x,
             layer.w13_weight,
             layer.w2_weight,
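
This is the payoff of binding self.fused_experts_func in
process_weights_after_loading: apply() drops its per-call import of
fused_experts, and the routing hot path becomes select_experts followed by a
single indirect call, identical in shape for the AITER and Triton backends.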


@@ -575,8 +575,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
     def process_weights_after_loading(self, layer: Module) -> None:
         # Lazy import to avoid importing triton too early.
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-            expand_weights, is_rocm_aiter_block_scaled_moe_enabled,
-            is_rocm_aiter_moe_enabled, shuffle_weights)
+            expand_weights, is_rocm_aiter_moe_enabled, shuffle_weights)
 
         # TODO (rob): refactor block quant into separate class.
         if self.block_quant:
@@ -603,7 +602,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight = Parameter(w2_weight, requires_grad=False)
             layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
                                                   requires_grad=False)
-            if is_rocm_aiter_block_scaled_moe_enabled():
+            if is_rocm_aiter_moe_enabled():
                 # reshaping weights is required for aiter moe kernel.
                 shuffled_w13, shuffled_w2 = shuffle_weights(
                     layer.w13_weight.data, layer.w2_weight.data)
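
In this second file, the separate is_rocm_aiter_block_scaled_moe_enabled gate
is folded away, so the FP8 block-quantized path in Fp8MoEMethod now keys off
the same is_rocm_aiter_moe_enabled switch as the per-tensor path above. As a
hedged sketch of what such a consolidated gate can look like (the flag names
below are assumptions for illustration; the real check lives in
rocm_aiter_fused_moe and also requires a ROCm platform):

import os


def is_rocm_aiter_moe_enabled_sketch() -> bool:
    # Hypothetical environment flags modeled on vLLM's VLLM_ROCM_USE_AITER*
    # settings; the actual implementation reads vLLM's envs module and
    # additionally checks that the current platform is ROCm.
    return (os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1"
            and os.environ.get("VLLM_ROCM_USE_AITER_MOE", "1") == "1")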