[Kernels] MoE refactor (#19636)
Signed-off-by: Bill Nell <bnell@redhat.com> Signed-off-by: ElizaWszola <ewszola@redhat.com> Co-authored-by: ElizaWszola <ewszola@redhat.com>
This commit is contained in:
@@ -7,6 +7,7 @@ import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP)
|
||||
from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache
|
||||
@@ -202,26 +203,47 @@ def run_cutlass_moe_fp8(
|
||||
|
||||
|
||||
# TODO (bnell): split class batched vs. non-batched?
|
||||
# maybe remove need for passing aq to workspace_shapes
|
||||
class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_experts_per_worker: int,
|
||||
out_dtype: torch.dtype,
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
out_dtype: Optional[torch.dtype],
|
||||
per_act_token_quant: bool,
|
||||
per_out_ch_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
use_batched_format: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
super().__init__(
|
||||
FusedMoEQuantConfig(
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
per_out_ch_quant=per_out_ch_quant,
|
||||
block_shape=block_shape,
|
||||
))
|
||||
assert max_experts_per_worker > 0
|
||||
self.max_experts_per_worker = max_experts_per_worker
|
||||
self.out_dtype = out_dtype
|
||||
self.per_act_token = per_act_token
|
||||
self.per_out_ch = per_out_ch
|
||||
self.use_batched_format = use_batched_format
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
self
|
||||
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||
if self.use_batched_format:
|
||||
return (mk.FusedMoEActivationFormat.BatchedExperts,
|
||||
mk.FusedMoEActivationFormat.BatchedExperts)
|
||||
else:
|
||||
return (mk.FusedMoEActivationFormat.Standard,
|
||||
mk.FusedMoEActivationFormat.Standard)
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return not self.use_batched_format
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return not self.use_batched_format
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
@@ -245,7 +267,8 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
workspace1 = (M * topk, max(2 * N, K))
|
||||
workspace2 = (M * topk, N)
|
||||
output = (M * topk, K)
|
||||
return (workspace1, workspace2, output, self.out_dtype)
|
||||
return (workspace1, workspace2, output,
|
||||
self.out_dtype if self.out_dtype is not None else a.dtype)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@@ -270,13 +293,14 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
|
||||
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
|
||||
activation_callable = lambda i, o: self.activation(activation, i, o)
|
||||
run_cutlass_moe_fp8(output, hidden_states, w1, w2, topk_ids,
|
||||
activation_callable, global_num_experts,
|
||||
expert_map, w1_scale, w2_scale, a1q_scale,
|
||||
a2_scale, workspace13, workspace2,
|
||||
expert_num_tokens, self.out_dtype,
|
||||
self.per_act_token, self.per_out_ch,
|
||||
self.use_batched_format)
|
||||
in_dtype = hidden_states.dtype
|
||||
run_cutlass_moe_fp8(
|
||||
output, hidden_states, w1, w2, topk_ids, activation_callable,
|
||||
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
|
||||
a2_scale, workspace13, workspace2, expert_num_tokens,
|
||||
self.out_dtype if self.out_dtype is not None else in_dtype,
|
||||
self.per_act_token_quant, self.per_out_ch_quant,
|
||||
self.use_batched_format)
|
||||
|
||||
|
||||
def cutlass_moe_fp8(
|
||||
@@ -287,6 +311,7 @@ def cutlass_moe_fp8(
|
||||
topk_ids: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
per_act_token: bool,
|
||||
activation: str = "silu",
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
@@ -330,22 +355,18 @@ def cutlass_moe_fp8(
|
||||
Returns:
|
||||
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
|
||||
"""
|
||||
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
|
||||
a2_scale.numel() != 1 if a2_scale is not None else False)
|
||||
per_out_ch = w1_scale.numel() != w1_q.size(0)
|
||||
|
||||
out_dtype = a.dtype
|
||||
num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(
|
||||
0)
|
||||
|
||||
fn = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_channel_quant=per_act_token,
|
||||
),
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
CutlassExpertsFp8(
|
||||
max_experts_per_worker=global_num_experts,
|
||||
out_dtype=out_dtype,
|
||||
per_act_token=per_act_token,
|
||||
per_out_ch=per_out_ch,
|
||||
max_experts_per_worker=num_experts,
|
||||
out_dtype=a.dtype,
|
||||
per_act_token_quant=per_act_token,
|
||||
per_out_ch_quant=per_out_ch,
|
||||
use_batched_format=False,
|
||||
),
|
||||
)
|
||||
@@ -358,7 +379,7 @@ def cutlass_moe_fp8(
|
||||
topk_ids,
|
||||
False,
|
||||
activation,
|
||||
global_num_experts if global_num_experts != -1 else w1_q.size(0),
|
||||
num_experts,
|
||||
expert_map,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
|
||||
Reference in New Issue
Block a user