[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)
Signed-off-by: Bill Nell <bnell@redhat.com>
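A minimal sketch of the calling convention this commit moves toward, using only names that appear in the diff below (the flag values and block shape are illustrative, not taken from any particular model):

# Hedged sketch: the quantization method builds the config once and the
# kernels consume it, instead of every call site passing a dozen use_* flags.
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig

quant_config = FusedMoEQuantConfig.make(
    use_fp8_w8a8=True,          # fp8 weights and activations
    per_act_token_quant=False,  # per-tensor activation scales
    block_shape=[128, 128],     # illustrative block-quant shape
)
# New-style call (see the fused_experts signature change below); tensors are
# placeholders here:
# out = fused_experts(hidden_states, w1, w2, topk_weights, topk_ids,
#                     quant_config=quant_config)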
@@ -1,13 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused MoE kernel."""
"""Fused MoE Triton kernels."""
import functools
import json
import os
# torch.compile needs typing.List. It will fail torch.library.infer_schema
# otherwise
from typing import List  # noqa: UP035
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Union

import torch
import torch.nn.functional as F
@@ -18,7 +18,7 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger
# yapf: disable
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig, get_config_quant_dtype)
    FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig, _get_config_dtype_str)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
    _valid_cutlass_block_scaled_grouped_gemm,
    run_cutlass_block_scaled_fused_experts)
@@ -32,11 +32,7 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
    TopKWeightAndReduceNoOP)
from vllm.model_executor.layers.fused_moe.utils import (
    _resize_cache, moe_kernel_quantize_input)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    calculate_tile_tokens_dim)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)
    _resize_cache, activation_without_mul, moe_kernel_quantize_input)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
    dequant_mxfp4)
from vllm.platforms import current_platform
@@ -1049,87 +1045,66 @@ def fused_grouped_topk(
    return topk_values.to(torch.float32), topk_indices.to(torch.int32)


def get_config_dtype_str(
        dtype: torch.dtype,
        use_int4_w4a16: Optional[bool] = False,
        use_int8_w8a16: Optional[bool] = False,
        use_fp8_w8a8: Optional[bool] = False,
        use_mxfp4_w4a4: Optional[bool] = False) -> Optional[str]:
    if use_fp8_w8a8:
        return "fp8_w8a8"
    elif use_int8_w8a16:
        return "int8_w8a16"
    elif use_int4_w4a16:
        return "int4_w4a16"
    elif use_mxfp4_w4a4:
        return "mxfp4_w4a4"
    elif dtype == torch.float:
        # avoiding cases where kernel fails when float32 MoE
        # use fp16/bfloat16 configs
        return "float32"
    return None


def inplace_fused_experts(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str = "silu",
        is_act_and_mul: bool = True,
        apply_router_weight_on_input: bool = False,
        use_fp8_w8a8: bool = False,
        use_int8_w8a8: bool = False,
        use_int8_w8a16: bool = False,
        use_int4_w4a16: bool = False,
        use_mxfp4_w4a4: bool = False,
        per_channel_quant: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        w1_scale: Optional[torch.Tensor] = None,
        w2_scale: Optional[torch.Tensor] = None,
        w1_zp: Optional[torch.Tensor] = None,
        w2_zp: Optional[torch.Tensor] = None,
        a1_scale: Optional[torch.Tensor] = None,
        a2_scale: Optional[torch.Tensor] = None,
        block_shape: Optional[List[int]] = None,  #noqa: UP006
        w1_bias: Optional[torch.Tensor] = None,
        w2_bias: Optional[torch.Tensor] = None) -> None:
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    use_mxfp4_w4a4: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: Optional[torch.Tensor] = None,
    w2_zp: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,  #noqa: UP006
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> None:
    fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
        activation, is_act_and_mul,
        apply_router_weight_on_input, use_fp8_w8a8,
        activation, apply_router_weight_on_input, use_fp8_w8a8,
        use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
        use_mxfp4_w4a4, per_channel_quant, global_num_experts,
        expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
        a2_scale, block_shape, w1_bias, w2_bias)


def inplace_fused_experts_fake(hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str = "silu",
        is_act_and_mul: bool = True,
        apply_router_weight_on_input: bool = False,
        use_fp8_w8a8: bool = False,
        use_int8_w8a8: bool = False,
        use_int8_w8a16: bool = False,
        use_int4_w4a16: bool = False,
        use_mxfp4_w4a4: bool = False,
        per_channel_quant: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        w1_scale: Optional[torch.Tensor] = None,
        w2_scale: Optional[torch.Tensor] = None,
        w1_zp: Optional[torch.Tensor] = None,
        w2_zp: Optional[torch.Tensor] = None,
        a1_scale: Optional[torch.Tensor] = None,
        a2_scale: Optional[torch.Tensor] = None,
        block_shape: Optional[list[int]] = None,
        w1_bias: Optional[torch.Tensor] = None,
        w2_bias: Optional[torch.Tensor] = None) -> None:
def inplace_fused_experts_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    use_mxfp4_w4a4: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: Optional[torch.Tensor] = None,
    w2_zp: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,  #noqa: UP006
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> None:
    pass


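The `*_fake` variants give torch.library a shape-only implementation so the op schema can be inferred and torch.compile can trace through the registered ops. The `dispatch_fused_experts_func` touched in a later hunk presumably selects between the two registered variants; a hedged sketch, assuming the default "vllm" op namespace used by direct_register_custom_op:

# Hedged sketch only: op namespace and dispatch body are assumptions, not
# taken verbatim from this diff.
def dispatch_fused_experts_sketch(inplace: bool):
    if inplace:
        # in-place variant mutates hidden_states and returns None
        return torch.ops.vllm.inplace_fused_experts
    # out-of-place variant returns a fresh tensor
    return torch.ops.vllm.outplace_fused_experts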
@@ -1143,175 +1118,6 @@ direct_register_custom_op(
)


def flashinfer_fused_moe_blockscale_fp8(
    routing_logits: torch.Tensor,
    routing_bias: torch.Tensor,
    x: torch.Tensor,
    w13_weight: torch.Tensor,
    w13_weight_scale_inv: torch.Tensor,
    w2_weight: torch.Tensor,
    w2_weight_scale_inv: torch.Tensor,
    global_num_experts: int,
    top_k: int,
    num_expert_group: int,
    topk_group: int,
    intermediate_size: int,
    expert_offset: int,
    local_num_experts: int,
    block_shape: List[int],  #noqa: UP006
    routed_scaling: float = 1.0) -> torch.Tensor:
    from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
    assert top_k <= global_num_experts
    assert top_k <= 8
    assert topk_group <= 4
    assert global_num_experts > num_expert_group
    assert global_num_experts % num_expert_group == 0
    assert global_num_experts % 4 == 0
    assert top_k < (topk_group * global_num_experts / num_expert_group)
    assert block_shape == [128, 128]

    a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
    # NOTE: scales of hidden states have to be transposed!
    a_sf_t = a_sf.t().contiguous()
    return flashinfer_trtllm_fp8_block_scale_moe(
        routing_logits=routing_logits,
        routing_bias=routing_bias,
        hidden_states=a_q,
        hidden_states_scale=a_sf_t,
        gemm1_weights=w13_weight,
        gemm1_weights_scale=w13_weight_scale_inv,
        gemm2_weights=w2_weight,
        gemm2_weights_scale=w2_weight_scale_inv,
        num_experts=global_num_experts,
        top_k=top_k,
        n_group=num_expert_group,
        topk_group=topk_group,
        intermediate_size=intermediate_size,
        local_expert_offset=expert_offset,
        local_num_experts=local_num_experts,
        routed_scaling_factor=routed_scaling,
        tile_tokens_dim=calculate_tile_tokens_dim(x.shape[0], top_k,
            global_num_experts),
        routing_method_type=2,  # DeepSeek-styled routing method
        use_shuffled_weight=False,
    )


def flashinfer_fused_moe_blockscale_fp8_fake(
    routing_logits: torch.Tensor,
    routing_bias: torch.Tensor,
    x: torch.Tensor,
    w13_weight: torch.Tensor,
    w13_weight_scale_inv: torch.Tensor,
    w2_weight: torch.Tensor,
    w2_weight_scale_inv: torch.Tensor,
    global_num_experts: int,
    top_k: int,
    num_expert_group: int,
    topk_group: int,
    intermediate_size: int,
    expert_offset: int,
    local_num_experts: int,
    block_shape: list[int],
    routed_scaling: float = 1.0) -> torch.Tensor:
    return torch.empty_like(x)


direct_register_custom_op(
    op_name="flashinfer_fused_moe_blockscale_fp8",
    op_func=flashinfer_fused_moe_blockscale_fp8,
    mutates_args=[],
    fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
    tags=(torch.Tag.needs_fixed_stride_order, ),
)


def flashinfer_fused_moe_per_tensor_scale_fp8(
    routing_logits: torch.Tensor,
    routing_bias: Optional[torch.Tensor],
    hidden_states: torch.Tensor,
    input_scale: torch.Tensor,
    gemm1_weights: torch.Tensor,
    gemm2_weights: torch.Tensor,
    output1_scales_scalar: torch.Tensor,
    output1_scales_gate_scalar: torch.Tensor,
    output2_scales_scalar: torch.Tensor,
    num_experts: int,
    top_k: int,
    num_expert_group: Optional[int],
    topk_group: Optional[int],
    intermediate_size: int,
    local_expert_offset: int,
    local_num_experts: int,
    use_routing_scales_on_input: bool,
    routing_method_type: int,
    routed_scaling_factor: float = 1.0) -> torch.Tensor:
    num_expert_group = num_expert_group if num_expert_group is not None else 0
    topk_group = topk_group if topk_group is not None else 0

    quant_hidden_states, _ = moe_kernel_quantize_input(
        hidden_states,
        input_scale,
        quant_dtype=torch.float8_e4m3fn,
        per_act_token_quant=False)

    from vllm.utils.flashinfer import (
        flashinfer_trtllm_fp8_per_tensor_scale_moe)
    return flashinfer_trtllm_fp8_per_tensor_scale_moe(
        routing_logits=routing_logits,
        routing_bias=routing_bias,
        hidden_states=quant_hidden_states,
        gemm1_weights=gemm1_weights,
        output1_scales_scalar=output1_scales_scalar,
        output1_scales_gate_scalar=output1_scales_gate_scalar,
        gemm2_weights=gemm2_weights,
        output2_scales_scalar=output2_scales_scalar,
        num_experts=num_experts,
        top_k=top_k,
        n_group=num_expert_group,
        topk_group=topk_group,
        intermediate_size=intermediate_size,
        local_expert_offset=local_expert_offset,
        local_num_experts=local_num_experts,
        routed_scaling_factor=routed_scaling_factor,
        use_routing_scales_on_input=use_routing_scales_on_input,
        tile_tokens_dim=calculate_tile_tokens_dim(hidden_states.shape[0],
            top_k, num_experts),
        routing_method_type=routing_method_type)


def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
    routing_logits: torch.Tensor,
    routing_bias: Optional[torch.Tensor],
    hidden_states: torch.Tensor,
    input_scale: torch.Tensor,
    gemm1_weights: torch.Tensor,
    gemm2_weights: torch.Tensor,
    output1_scales_scalar: torch.Tensor,
    output1_scales_gate_scalar: torch.Tensor,
    output2_scales_scalar: torch.Tensor,
    num_experts: int,
    top_k: int,
    num_expert_group: Optional[int],
    topk_group: Optional[int],
    intermediate_size: int,
    local_expert_offset: int,
    local_num_experts: int,
    use_routing_scales_on_input: bool,
    routing_method_type: int,
    routed_scaling_factor: float = 1.0) -> torch.Tensor:
    pass


direct_register_custom_op(
    op_name="flashinfer_fused_moe_per_tensor_scale_fp8",
    op_func=flashinfer_fused_moe_per_tensor_scale_fp8,
    mutates_args=["hidden_states"],
    fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake,
    tags=(torch.Tag.needs_fixed_stride_order, ),
)


def outplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
@@ -1319,7 +1125,6 @@ def outplace_fused_experts(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    is_act_and_mul: bool = True,
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
@@ -1341,37 +1146,37 @@ def outplace_fused_experts(
) -> torch.Tensor:
    return fused_experts_impl(
        hidden_states, w1, w2, topk_weights, topk_ids, False, activation,
        is_act_and_mul, apply_router_weight_on_input, use_fp8_w8a8,
        use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4,
        per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale,
        w1_zp, w2_zp, a1_scale, a2_scale, block_shape, w1_bias, w2_bias)
        apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8,
        use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4, per_channel_quant,
        global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp,
        a1_scale, a2_scale, block_shape, w1_bias, w2_bias)


def outplace_fused_experts_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    is_act_and_mul: bool = True,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    use_mxfp4_w4a4: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: Optional[torch.Tensor] = None,
    w2_zp: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    use_mxfp4_w4a4: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: Optional[torch.Tensor] = None,
    w2_zp: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    return torch.empty_like(hidden_states)


@@ -1403,45 +1208,36 @@ def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:

# TODO (bnell): replace this with modular op. Can get rid of inplace/outplace
# torch ops.
def fused_experts(hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        inplace: bool = False,
        activation: str = "silu",
        is_act_and_mul: bool = True,
        apply_router_weight_on_input: bool = False,
        use_fp8_w8a8: bool = False,
        use_int8_w8a8: bool = False,
        use_int8_w8a16: bool = False,
        use_int4_w4a16: bool = False,
        use_mxfp4_w4a4: bool = False,
        per_channel_quant: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        w1_scale: Optional[torch.Tensor] = None,
        w2_scale: Optional[torch.Tensor] = None,
        w1_zp: Optional[torch.Tensor] = None,
        w2_zp: Optional[torch.Tensor] = None,
        a1_scale: Optional[torch.Tensor] = None,
        a2_scale: Optional[torch.Tensor] = None,
        block_shape: Optional[list[int]] = None,
        allow_deep_gemm: bool = False,
        allow_cutlass_block_scaled_grouped_gemm: bool = False,
        w1_bias: Optional[torch.Tensor] = None,
        w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    quant_config: Optional[FusedMoEQuantConfig] = None,
    allow_deep_gemm: bool = False,
    allow_cutlass_block_scaled_grouped_gemm: bool = False,
) -> torch.Tensor:

    if quant_config is None:
        quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
    use_fp8_w8a8 = quant_config.use_fp8_w8a8

    # For now, disable DeepGemm for small N (<= 512) until better
    # permute/unpermute ops are available.
    # However, on B200, we use DeepGemm for all cases because they only support
    # E8M0 scale, which means we requantize the weight and input to the specific
    # scale. Fallen back to cutlass or triton for some cases would cause
    # accuracy issue.
    if (allow_deep_gemm and use_fp8_w8a8 and
    if (allow_deep_gemm and quant_config.use_fp8_w8a8 and
            (is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2))):
        assert quant_config is not None
        assert apply_router_weight_on_input is False
        assert is_act_and_mul, (
            "DeepGemm only supports is_act_and_mul=True for now.")
        return deep_gemm_moe_fp8(
            hidden_states=hidden_states,
            w1=w1,
@@ -1452,22 +1248,23 @@ def fused_experts(hidden_states: torch.Tensor,
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            w1_scale=quant_config.w1_scale,
            w2_scale=quant_config.w2_scale,
            a1_scale=quant_config.a1_scale,
            a2_scale=quant_config.a2_scale,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )
    elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8
            and _valid_cutlass_block_scaled_grouped_gemm(
                w1, w2, inplace, activation, apply_router_weight_on_input,
                expert_map)):
        assert quant_config is not None
        return run_cutlass_block_scaled_fused_experts(
            a=hidden_states,
            w1=w1,
            w2=w2,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            w1_scale=quant_config.w1_scale,
            w2_scale=quant_config.w2_scale,
            topk_weights=topk_weights,
            topk_ids=topk_ids)
    else:
@@ -1478,26 +1275,49 @@ def fused_experts(hidden_states: torch.Tensor,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=activation,
            is_act_and_mul=is_act_and_mul,
            apply_router_weight_on_input=apply_router_weight_on_input,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            use_mxfp4_w4a4=use_mxfp4_w4a4,
            per_channel_quant=per_channel_quant,
            use_fp8_w8a8=quant_config.use_fp8_w8a8,
            use_int8_w8a8=quant_config.use_int8_w8a8,
            use_int8_w8a16=quant_config.use_int8_w8a16,
            use_int4_w4a16=quant_config.use_int4_w4a16,
            use_mxfp4_w4a4=quant_config.use_mxfp4_w4a4,
            per_channel_quant=quant_config.per_act_token_quant,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            w1_zp=w1_zp,
            w2_zp=w2_zp,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            block_shape=block_shape,
            w1_bias=w1_bias,
            w2_bias=w2_bias,
        )
            w1_scale=quant_config.w1_scale,
            w2_scale=quant_config.w2_scale,
            w1_zp=quant_config.w1_zp,
            w2_zp=quant_config.w2_zp,
            a1_scale=quant_config.a1_scale,
            a2_scale=quant_config.a2_scale,
            block_shape=quant_config.block_shape,
            w1_bias=quant_config.w1_bias,
            w2_bias=quant_config.w2_bias)


SILU_NO_MUL: str = activation_without_mul("silu")
GELU_NO_MUL: str = activation_without_mul("gelu")


def _get_config_quant_dtype(
    use_fp8_w8a8: bool,
    use_int8_w8a8: bool,
    use_mxfp4_w4a4: bool,
) -> Union[None, torch.dtype, str]:
    """
    Get the quantization type based on the quantization strategy flags.
    We don't have a quant_config at this point so we need to work backwards.
    A return type of None means no quantization is required because the
    input is unquantized or has been quantized prior to calling
    fused_experts_impl.
    """
    if use_fp8_w8a8:
        return torch.float8_e4m3fn
    elif use_int8_w8a8:
        return torch.int8
    elif use_mxfp4_w4a4:
        return "mxfp4"
    return None


def fused_experts_impl(
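The new helper above maps the legacy boolean flags to the activation quantization dtype that moe_kernel_quantize_input expects; a small illustration, using only the branches shown in its body:

# Illustration of the flag-to-dtype mapping implemented by _get_config_quant_dtype.
assert _get_config_quant_dtype(
    use_fp8_w8a8=True, use_int8_w8a8=False, use_mxfp4_w4a4=False
) == torch.float8_e4m3fn
assert _get_config_quant_dtype(
    use_fp8_w8a8=False, use_int8_w8a8=True, use_mxfp4_w4a4=False
) == torch.int8
# int8_w8a16 / int4_w4a16 activations are quantized by the caller, so every
# remaining combination falls through to None (no activation quantization).
assert _get_config_quant_dtype(
    use_fp8_w8a8=False, use_int8_w8a8=False, use_mxfp4_w4a4=False
) is None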
@@ -1508,7 +1328,6 @@ def fused_experts_impl(
    topk_ids: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    is_act_and_mul: bool = True,
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
@@ -1557,17 +1376,18 @@ def fused_experts_impl(
    # https://github.com/vllm-project/vllm/issues/5938
    CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
    M = min(num_tokens, CHUNK_SIZE)
    config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        use_mxfp4_w4a4=use_mxfp4_w4a4,
        dtype=hidden_states.dtype)

    qtype = get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        use_mxfp4_w4a4=use_mxfp4_w4a4)
    config_dtype = _get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        use_mxfp4_w4a4=use_mxfp4_w4a4,
        dtype=hidden_states.dtype)

    # Note: for use_int8_w8a16 or use_int4_w4a16, the activations are
    # quantized prior to calling fused_experts.
    quant_dtype = _get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_mxfp4_w4a4=use_mxfp4_w4a4)

    get_config_func = functools.partial(
        try_get_optimal_moe_config,
@@ -1640,7 +1460,7 @@ def fused_experts_impl(
        qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
            A=curr_hidden_states,
            A_scale=a1_scale,
            quant_dtype=qtype,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_channel_quant,
            block_shape=block_shape)

@@ -1671,30 +1491,29 @@ def fused_experts_impl(
            B_bias=w1_bias)

        # Activation function with multiplication
        if activation == "silu" and is_act_and_mul:
        if activation == "silu":
            torch.ops._C.silu_and_mul(intermediate_cache2,
                intermediate_cache1.view(-1, N))
        elif activation == "gelu" and is_act_and_mul:
        elif activation == "gelu":
            torch.ops._C.gelu_and_mul(intermediate_cache2,
                intermediate_cache1.view(-1, N))
        elif activation == "swigluoai" and is_act_and_mul:
        elif activation == "swigluoai":
            # alpha = 1.702, limit = 7.0
            torch.ops._C.swigluoai_and_mul(intermediate_cache2,
                intermediate_cache1.view(-1, N))
        # Activation function without multiplication
        elif activation == "silu":
        elif activation == SILU_NO_MUL:
            intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
        elif activation == "gelu":
        elif activation == GELU_NO_MUL:
            intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))

        else:
            raise ValueError(f"Unsupported FusedMoe activation: {activation}, "
                f"with is_act_and_mul={is_act_and_mul}.")
            raise ValueError(f"Unsupported FusedMoe activation: {activation}.")

        qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
            A=intermediate_cache2,
            A_scale=a2_scale,
            quant_dtype=qtype,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_channel_quant,
            block_shape=block_shape)

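With is_act_and_mul gone from this path, ungated activations are now selected by name: the SILU_NO_MUL / GELU_NO_MUL markers defined above are matched in the dispatch, and they come from activation_without_mul(). A hedged usage sketch (the exact marker string is defined in fused_moe.utils and not shown in this diff; tensors below are placeholders):

# Hedged sketch: request the ungated SiLU path by passing the marker name
# instead of a plain "silu".
from vllm.model_executor.layers.fused_moe.utils import activation_without_mul

out = fused_experts(
    hidden_states, w1, w2, topk_weights, topk_ids,
    activation=activation_without_mul("silu"),  # F.silu, no gating multiply
    quant_config=None,  # unquantized path
)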
@@ -1726,164 +1545,13 @@ def fused_experts_impl(
    return out_hidden_states


def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    inplace: bool = False,
    activation: str = "silu",
    is_act_and_mul: bool = True,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    use_mxfp4_w4a4: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: Optional[torch.Tensor] = None,
    w2_zp: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    This function computes a Mixture of Experts (MoE) layer using two sets of
    weights, w1 and w2, and top-k gating mechanism.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
    - w1 (torch.Tensor): The first set of expert weights.
    - w2 (torch.Tensor): The second set of expert weights.
    - gating_output (torch.Tensor): The output of the gating operation
        (before softmax).
    - topk (int): The number of top-k experts to select.
    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
    - inplace (bool): If True, perform the operation in-place.
        Defaults to False.
    - activation (str): The activation function to apply after the first
        MoE layer.
    - is_act_and_mul (bool): If True, use activation-and-mul function for
        activation (self-gated activation), otherwise use activation function
        for activation (ungated activation).
    - num_expert_group: Optional[int]: additional parameter for grouped_topk
    - topk_group: Optional[int]: additional parameter for grouped_topk
    - use_grouped_topk: If True, use grouped_topk instead of fused_topk
        note: Deepseekv2 model uses grouped_topk
    - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
        products for w1 and w2. Defaults to False.
    - use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
        products for w1 and w2. Defaults to False.
    - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
        activation to compute the inner products for w1 and w2.
        Defaults to False.
    - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
        activation to compute the inner products for w1 and w2.
        Defaults to False.
    - use_mxfp4_w4a4 (bool): If True, use matmul of OCP MXFP4 weight and
        OCP MXFP4 activation to compute the inner products for w1 and w2.
        Defaults to False.
    - global_num_experts (int): The total number of experts in the global
        expert space.
    - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
        from the global expert space to the local expert space of the expert
        parallel shard.
    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
    - a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1.
    - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
    - block_shape: (Optional[list[int]]): Optional block size for block-wise
        quantization.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
    if not is_act_and_mul:
        assert inplace is False, (
            "is_act_and_mul=False is not supported with inplace=True")

    if use_grouped_topk:
        assert num_expert_group is not None and topk_group is not None
        topk_weights, topk_ids = grouped_topk(hidden_states, gating_output,
            topk, renormalize,
            num_expert_group, topk_group)
    elif custom_routing_function is None:
        topk_weights, topk_ids, token_expert_indices = fused_topk(
            hidden_states, gating_output, topk, renormalize)
    else:
        topk_weights, topk_ids = custom_routing_function(
            hidden_states, gating_output, topk, renormalize)

    return fused_experts(hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace=inplace,
        activation=activation,
        is_act_and_mul=is_act_and_mul,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        use_mxfp4_w4a4=use_mxfp4_w4a4,
        per_channel_quant=per_channel_quant,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        w1_zp=w1_zp,
        w2_zp=w2_zp,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
        w1_bias=w1_bias,
        w2_bias=w2_bias)


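The routing-plus-experts composition performed by the wrapper above can be expressed against the new interface; a hedged sketch using only names visible in this diff (fused_topk, fused_experts, FusedMoEQuantConfig), with tensor shapes assumed to follow the usual fused-MoE layout (w1: [E, 2N, K], w2: [E, K, N]):

# Hedged sketch: E experts, K hidden size, N intermediate size (assumed layout).
import torch

E, K, N, num_tokens, topk = 8, 1024, 2048, 16, 2
hidden_states = torch.randn(num_tokens, K, dtype=torch.bfloat16, device="cuda")
w1 = torch.randn(E, 2 * N, K, dtype=torch.bfloat16, device="cuda")
w2 = torch.randn(E, K, N, dtype=torch.bfloat16, device="cuda")
gating_output = torch.randn(num_tokens, E, dtype=torch.float32, device="cuda")

# Route tokens, then run the experts; quant_config=None selects the
# unquantized config per the fused_experts change above.
topk_weights, topk_ids, _ = fused_topk(hidden_states, gating_output, topk, True)
out = fused_experts(hidden_states, w1, w2, topk_weights, topk_ids,
    activation="silu", quant_config=None)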
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):

    def __init__(
        self,
        use_fp8_w8a8: bool = False,
        use_int8_w8a8: bool = False,
        use_int8_w8a16: bool = False,
        use_int4_w4a16: bool = False,
        use_mxfp4_w4a4: bool = False,
        per_act_token_quant: bool = False,
        block_shape: Optional[list[int]] = None,
        quant_config: FusedMoEQuantConfig,
    ):
        super().__init__(
            FusedMoEQuantConfig.make(
                use_fp8_w8a8=use_fp8_w8a8,
                use_int8_w8a8=use_int8_w8a8,
                use_int8_w8a16=use_int8_w8a16,
                use_int4_w4a16=use_int4_w4a16,
                use_mxfp4_w4a4=use_mxfp4_w4a4,
                per_act_token_quant=per_act_token_quant,
                block_shape=block_shape,
            ))

        self.use_fp8_w8a8 = use_fp8_w8a8
        self.use_int4_w4a16 = use_int4_w4a16
        self.use_int8_w8a8 = use_int8_w8a8
        self.use_int8_w8a16 = use_int8_w8a16
        self.use_mxfp4_w4a4 = use_mxfp4_w4a4
        super().__init__(quant_config)

    @property
    def activation_formats(
@@ -1929,19 +1597,14 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
        activation: str,
        global_num_experts: int,
        expert_map: Optional[torch.Tensor],
        w1_scale: Optional[torch.Tensor],
        w2_scale: Optional[torch.Tensor],
        w1_zp: Optional[torch.Tensor],
        w2_zp: Optional[torch.Tensor],
        a1q_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
        apply_router_weight_on_input: bool,
    ):
        # Check constraints.
        if self.use_int4_w4a16:
        if self.quant_config.use_int4_w4a16:
            assert hidden_states.size(-1) // 2 == w1.size(2), (
                "Hidden size mismatch")
        else:
@@ -1964,17 +1627,11 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
        if global_num_experts == -1:
            global_num_experts = E

        config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
            use_int8_w8a16=self.use_int8_w8a16,
            use_int4_w4a16=self.use_int4_w4a16,
            use_mxfp4_w4a4=self.use_mxfp4_w4a4,
            dtype=hidden_states.dtype)

        config = try_get_optimal_moe_config(
            w1.size(),
            w2.size(),
            top_k_num,
            config_dtype,
            self.quant_config.config_name(hidden_states.dtype),
            num_tokens,
            block_shape=self.block_shape,
        )
@@ -2008,8 +1665,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
            w1,
            intermediate_cache1,
            a1q_scale,
            w1_scale,
            w1_zp,
            self.w1_scale,
            self.w1_zp,
            None,  # topk_weights
            sorted_token_ids,
            expert_ids,
@@ -2018,13 +1675,13 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
            top_k_num,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=self.use_fp8_w8a8,
            use_int8_w8a8=self.use_int8_w8a8,
            use_int8_w8a16=self.use_int8_w8a16,
            use_int4_w4a16=self.use_int4_w4a16,
            use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
            use_int8_w8a8=self.quant_config.use_int8_w8a8,
            use_int8_w8a16=self.quant_config.use_int8_w8a16,
            use_int4_w4a16=self.quant_config.use_int4_w4a16,
            per_channel_quant=self.per_act_token_quant,
            block_shape=self.block_shape,
            B_bias=None  # TODO support B_bias
            B_bias=self.w1_bias,
        )

        self.activation(activation, intermediate_cache2,
@@ -2033,7 +1690,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
        a2q_scale: Optional[torch.Tensor] = None

        qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
            intermediate_cache2, a2_scale, self.quant_dtype,
            intermediate_cache2, self.a2_scale, self.quant_dtype,
            self.per_act_token_quant, self.block_shape)

        invoke_fused_moe_kernel(
@@ -2041,8 +1698,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
            w2,
            intermediate_cache3,
            a2q_scale,
            w2_scale,
            w2_zp,
            self.w2_scale,
            self.w2_zp,
            topk_weights,
            sorted_token_ids,
            expert_ids,
@@ -2051,36 +1708,21 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
            1,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=self.use_fp8_w8a8,
            use_int8_w8a8=self.use_int8_w8a8,
            use_int8_w8a16=self.use_int8_w8a16,
            use_int4_w4a16=self.use_int4_w4a16,
            use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
            use_int8_w8a8=self.quant_config.use_int8_w8a8,
            use_int8_w8a16=self.quant_config.use_int8_w8a16,
            use_int4_w4a16=self.quant_config.use_int4_w4a16,
            per_channel_quant=self.per_act_token_quant,
            block_shape=self.block_shape,
            B_bias=None  # TODO support B_bias
            B_bias=self.w2_bias,
        )

        ops.moe_sum(intermediate_cache3, output)


def modular_triton_fused_moe(
    use_fp8_w8a8: bool,
    use_int8_w8a8: bool,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    use_mxfp4_w4a4: bool,
    per_act_token_quant: bool,
    block_shape: Optional[list[int]] = None,
) -> mk.FusedMoEModularKernel:
    quant_config: FusedMoEQuantConfig) -> mk.FusedMoEModularKernel:
    return mk.FusedMoEModularKernel(
        MoEPrepareAndFinalizeNoEP(),
        TritonExperts(
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            use_mxfp4_w4a4=use_mxfp4_w4a4,
            per_act_token_quant=per_act_token_quant,
            block_shape=block_shape,
        ),
        TritonExperts(quant_config),
    )

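Construction of the modular Triton kernel now takes the config object directly; a hedged usage sketch (import path assumed to be this file, config values illustrative only):

# Hedged sketch: build the modular Triton MoE kernel from a single
# FusedMoEQuantConfig instead of individual use_* flags.
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_moe import (
    modular_triton_fused_moe)

quant_config = FusedMoEQuantConfig.make(
    use_fp8_w8a8=True,
    per_act_token_quant=False,
    block_shape=[128, 128],  # illustrative block-quant shape
)
mk_kernel = modular_triton_fused_moe(quant_config)
# mk_kernel wraps MoEPrepareAndFinalizeNoEP() + TritonExperts(quant_config).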