Feature: Support Relu2 in FusedMoE fp8 cutlass path (#27261)
This commit is contained in:
@@ -148,8 +148,14 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool | None,
|
||||
):
|
||||
assert activation == "silu", (
|
||||
"Only activation silu is supported in FlashInferExperts"
|
||||
from flashinfer.fused_moe.core import ActivationType
|
||||
|
||||
activation_str_to_value_map = {
|
||||
"silu": ActivationType.Swiglu, # This is the default
|
||||
"relu2_no_mul": ActivationType.Relu2,
|
||||
}
|
||||
assert activation in activation_str_to_value_map, (
|
||||
f"{activation=} missing from {activation_str_to_value_map.keys()=}"
|
||||
)
|
||||
|
||||
# Select quantization metadata based on FP8 format/path
|
||||
@@ -215,6 +221,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
ep_size=self.ep_size,
|
||||
ep_rank=self.ep_rank,
|
||||
output=output,
|
||||
activation_type=activation_str_to_value_map[activation],
|
||||
# Informs FlashInfer to use the block-scale decoding path when True
|
||||
use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale,
|
||||
)
|
||||
|
||||
@@ -354,12 +354,18 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
||||
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
|
||||
if (
|
||||
envs.VLLM_USE_FLASHINFER_MOE_FP8
|
||||
and has_flashinfer_moe()
|
||||
and self.moe.is_act_and_mul
|
||||
):
|
||||
if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
|
||||
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
|
||||
if (
|
||||
self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
|
||||
and not self.moe.is_act_and_mul
|
||||
):
|
||||
logger.info_once(
|
||||
"Non-gated MoE is not supported for min-latency mode,"
|
||||
"falling back to high-throughput mode"
|
||||
)
|
||||
self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
|
||||
|
||||
logger.info_once(
|
||||
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
|
||||
)
|
||||
@@ -557,10 +563,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
|
||||
if self.flashinfer_moe_backend is not None:
|
||||
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
|
||||
register_moe_scaling_factors(layer)
|
||||
if self.moe.is_act_and_mul:
|
||||
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
|
||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
|
||||
register_moe_scaling_factors(layer)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
@@ -570,13 +577,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
return fp8_w8a8_moe_quant_config(
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
g1_alphas=(layer.w13_weight_scale * layer.w13_input_scale).squeeze(),
|
||||
g1_alphas=layer.output1_scales_gate_scalar.squeeze(),
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
g2_alphas=(layer.w2_weight_scale * layer.w2_input_scale).squeeze(),
|
||||
g2_alphas=layer.output2_scales_scalar.squeeze(),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a1_gscale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
a2_gscale=1.0 / layer.w2_input_scale,
|
||||
a2_gscale=layer.w2_input_scale_inv,
|
||||
per_act_token_quant=False,
|
||||
)
|
||||
|
||||
@@ -642,9 +649,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
|
||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||
assert not renormalize
|
||||
assert activation == "silu", (
|
||||
f"Expected 'silu' activation but got {activation}"
|
||||
assert activation in ("silu", "relu2_no_mul"), (
|
||||
"Expected activation to be in ('silu', 'relu2_no_mul'),"
|
||||
f"but got {activation}"
|
||||
)
|
||||
return flashinfer_cutlass_moe_fp8(
|
||||
x,
|
||||
|
||||
Reference in New Issue
Block a user