[ROCm] Add aiter tkw1 kernel for Llama4 fp8 (#16727)
Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
@@ -250,6 +250,28 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
|
||||
requires_grad=False)
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_moe_enabled)
|
||||
|
||||
# Property to determine if AITER is used
|
||||
if is_rocm_aiter_moe_enabled():
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
|
||||
rocm_aiter_fused_experts, shuffle_weights)
|
||||
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data)
|
||||
|
||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13,
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
|
||||
requires_grad=False)
|
||||
|
||||
self.fused_experts_func = rocm_aiter_fused_experts
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
self.fused_experts_func = fused_experts
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -268,7 +290,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
@@ -282,7 +303,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
return fused_experts(
|
||||
return self.fused_experts_func(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
|
||||
@@ -575,8 +575,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
# Lazy import to avoid importing triton too early.
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
expand_weights, is_rocm_aiter_block_scaled_moe_enabled,
|
||||
is_rocm_aiter_moe_enabled, shuffle_weights)
|
||||
expand_weights, is_rocm_aiter_moe_enabled, shuffle_weights)
|
||||
|
||||
# TODO (rob): refactor block quant into separate class.
|
||||
if self.block_quant:
|
||||
@@ -603,7 +602,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_weight = Parameter(w2_weight, requires_grad=False)
|
||||
layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
|
||||
requires_grad=False)
|
||||
if is_rocm_aiter_block_scaled_moe_enabled():
|
||||
if is_rocm_aiter_moe_enabled():
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data)
|
||||
|
||||
Reference in New Issue
Block a user