[FEAT][ROCm] Integrate Fused MoE Kernels from AITER (#14967)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
@@ -17,6 +17,10 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled,
|
||||
rocm_aiter_fused_experts,
|
||||
rocm_aiter_topk_softmax)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@@ -1035,6 +1039,28 @@ def try_get_optimal_moe_config(
|
||||
return config
|
||||
|
||||
|
||||
def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
                      token_expert_indices: torch.Tensor,
                      gating_output: torch.Tensor,
                      renormalize: bool) -> tuple[torch.Tensor, ...]:
    """Top-k softmax over the gating output via vLLM's native kernel.

    ``ops.topk_softmax`` fills ``topk_weights``, ``topk_indices`` and
    ``token_expert_indices`` in place from ``gating_output``.  When
    ``renormalize`` is set, the weights are rescaled so each row sums to
    one (a new tensor is produced; the buffer is not mutated again).

    Returns:
        The (possibly renormalized) top-k weights and the expert indices.
    """
    ops.topk_softmax(topk_weights, topk_indices, token_expert_indices,
                     gating_output)

    if renormalize:
        row_totals = topk_weights.sum(dim=-1, keepdim=True)
        topk_weights = topk_weights / row_totals

    return topk_weights, topk_indices
|
||||
|
||||
|
||||
def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
    """Select the top-k softmax implementation for this platform.

    Prefers the ROCm AITER kernel when AITER MoE support is enabled;
    otherwise falls back to vLLM's native ``vllm_topk_softmax``.
    """
    use_aiter = is_rocm_aiter_moe_enabled()
    return rocm_aiter_topk_softmax if use_aiter else vllm_topk_softmax
|
||||
|
||||
|
||||
def fused_topk(
|
||||
hidden_states: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
@@ -1059,17 +1085,14 @@ def fused_topk(
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
|
||||
ops.topk_softmax(
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
token_expert_indicies,
|
||||
gating_output.float(), # TODO(woosuk): Optimize this.
|
||||
)
|
||||
gating_output_float = gating_output.float() # TODO(woosuk): Optimize this.
|
||||
|
||||
topk_func = dispatch_topk_func()
|
||||
topk_weights, topk_ids = topk_func(topk_weights, topk_ids,
|
||||
token_expert_indicies,
|
||||
gating_output_float, renormalize)
|
||||
|
||||
del token_expert_indicies # Not used. Will be used in the future.
|
||||
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
@@ -1259,6 +1282,24 @@ direct_register_custom_op(
|
||||
)
|
||||
|
||||
|
||||
def torch_vllm_inplace_fused_experts(**kwargs) -> torch.Tensor:
    """Invoke the in-place fused-experts custom op.

    The op writes its result into ``kwargs['hidden_states']``; that tensor
    is returned so callers get a uniform value-returning interface matching
    the out-of-place variant.
    """
    torch.ops.vllm.inplace_fused_experts(**kwargs)
    return kwargs['hidden_states']
|
||||
|
||||
|
||||
def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
    """Invoke the out-of-place fused-experts custom op and return its output."""
    result = torch.ops.vllm.outplace_fused_experts(**kwargs)
    return result
|
||||
|
||||
|
||||
def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
    """Pick the fused-experts implementation to run.

    The ROCm AITER kernel takes precedence whenever it is enabled;
    otherwise the choice between the in-place and out-of-place torch
    custom ops is governed by ``inplace``.
    """
    if is_rocm_aiter_moe_enabled():
        return rocm_aiter_fused_experts
    return (torch_vllm_inplace_fused_experts
            if inplace else torch_vllm_outplace_fused_experts)
|
||||
|
||||
|
||||
def fused_experts(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
@@ -1278,20 +1319,25 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None) -> torch.Tensor:
|
||||
|
||||
if inplace:
|
||||
torch.ops.vllm.inplace_fused_experts(
|
||||
hidden_states, w1, w2, topk_weights, topk_ids, activation,
|
||||
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
|
||||
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
|
||||
block_shape)
|
||||
return hidden_states
|
||||
else:
|
||||
return torch.ops.vllm.outplace_fused_experts(
|
||||
hidden_states, w1, w2, topk_weights, topk_ids, activation,
|
||||
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
|
||||
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
|
||||
block_shape)
|
||||
return dispatch_fused_experts_func(inplace)(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape)
|
||||
|
||||
|
||||
def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user