refactor: Change scaling factors calculation for flashinfer FusedMoE (#22812)

Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
amirkl94
2025-08-15 09:19:31 +03:00
committed by GitHub
parent 0fe85087a9
commit b4cef5e6c7
4 changed files with 60 additions and 25 deletions

View File

@@ -24,8 +24,8 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8, rotate_flashinfer_fp8_moe_weights,
swap_w13_to_w31)
apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors,
rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@@ -694,6 +694,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
w2_weight = layer.w2_weight.data
w2_weight_scale_inv = layer.w2_weight_scale_inv.data
if not self.block_quant:
register_moe_scaling_factors(layer)
rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
else:
w13_weight = layer.w13_weight.data

View File

@@ -25,8 +25,8 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_kernel,
flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8, rotate_flashinfer_fp8_moe_weights,
swap_w13_to_w31)
apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors,
rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
apply_fp4_marlin_linear, is_fp4_marlin_supported,
prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin)
@@ -430,6 +430,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
layer.w2_weight)
register_moe_scaling_factors(layer)
def apply(
self,

View File

@@ -82,6 +82,12 @@ def apply_flashinfer_per_tensor_scale_fp8(
apply_router_weight_on_input: bool,
) -> torch.Tensor:
from flashinfer.fused_moe import RoutingMethodType
assert layer.output1_scales_scalar is not None, (
"Expected output1_scales_scalar to be initialized")
assert layer.output1_scales_scalar is not None, (
"Expected output1_scales_gate_scalar to be initialized")
assert layer.output1_scales_scalar is not None, (
"Expected output2_scales_scalar to be initialized")
from vllm.model_executor.models.llama4 import Llama4MoE
assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \
@@ -92,10 +98,10 @@ def apply_flashinfer_per_tensor_scale_fp8(
hidden_states=hidden_states,
input_scale=layer.w13_input_scale,
gemm1_weights=layer.w13_weight,
gemm1_weights_scale=layer.w13_weight_scale,
gemm2_weights=layer.w2_weight,
gemm2_weights_scale=layer.w2_weight_scale,
activation_scale=layer.w2_input_scale,
output1_scales_scalar=layer.output1_scales_scalar,
output1_scales_gate_scalar=layer.output1_scales_gate_scalar,
output2_scales_scalar=layer.output2_scales_scalar,
num_experts=global_num_experts,
top_k=top_k,
num_expert_group=num_expert_group,
@@ -105,4 +111,36 @@ def apply_flashinfer_per_tensor_scale_fp8(
local_num_experts=layer.local_num_experts,
use_routing_scales_on_input=apply_router_weight_on_input,
routing_method_type=RoutingMethodType.Llama4,
)
)
def get_moe_scaling_factors(
input_scale: torch.Tensor,
gemm1_weights_scale: torch.Tensor,
activation_scale: torch.Tensor,
gemm2_weights_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
output1_scales_scalar = gemm1_weights_scale * input_scale * (
1.0 / activation_scale)
output1_scales_gate_scalar = gemm1_weights_scale * input_scale
output2_scales_scalar = activation_scale * gemm2_weights_scale
return output1_scales_scalar, output1_scales_gate_scalar, \
output2_scales_scalar
def register_moe_scaling_factors(layer: torch.nn.Module) -> None:
output1_scales, output1_gate_scales, output2_scales = \
get_moe_scaling_factors(
layer.w13_input_scale, layer.w13_weight_scale,
layer.w2_input_scale, layer.w2_weight_scale
)
layer.register_parameter(
'output1_scales_scalar',
torch.nn.Parameter(output1_scales, requires_grad=False))
layer.register_parameter(
'output1_scales_gate_scalar',
torch.nn.Parameter(output1_gate_scales, requires_grad=False))
layer.register_parameter(
'output2_scales_scalar',
torch.nn.Parameter(output2_scales, requires_grad=False))