refactor: Change scaling factors calculation for flashinfer FusedMoE (#22812)
Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -82,6 +82,12 @@ def apply_flashinfer_per_tensor_scale_fp8(
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> torch.Tensor:
|
||||
from flashinfer.fused_moe import RoutingMethodType
|
||||
assert layer.output1_scales_scalar is not None, (
|
||||
"Expected output1_scales_scalar to be initialized")
|
||||
assert layer.output1_scales_scalar is not None, (
|
||||
"Expected output1_scales_gate_scalar to be initialized")
|
||||
assert layer.output1_scales_scalar is not None, (
|
||||
"Expected output2_scales_scalar to be initialized")
|
||||
|
||||
from vllm.model_executor.models.llama4 import Llama4MoE
|
||||
assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \
|
||||
@@ -92,10 +98,10 @@ def apply_flashinfer_per_tensor_scale_fp8(
|
||||
hidden_states=hidden_states,
|
||||
input_scale=layer.w13_input_scale,
|
||||
gemm1_weights=layer.w13_weight,
|
||||
gemm1_weights_scale=layer.w13_weight_scale,
|
||||
gemm2_weights=layer.w2_weight,
|
||||
gemm2_weights_scale=layer.w2_weight_scale,
|
||||
activation_scale=layer.w2_input_scale,
|
||||
output1_scales_scalar=layer.output1_scales_scalar,
|
||||
output1_scales_gate_scalar=layer.output1_scales_gate_scalar,
|
||||
output2_scales_scalar=layer.output2_scales_scalar,
|
||||
num_experts=global_num_experts,
|
||||
top_k=top_k,
|
||||
num_expert_group=num_expert_group,
|
||||
@@ -105,4 +111,36 @@ def apply_flashinfer_per_tensor_scale_fp8(
|
||||
local_num_experts=layer.local_num_experts,
|
||||
use_routing_scales_on_input=apply_router_weight_on_input,
|
||||
routing_method_type=RoutingMethodType.Llama4,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def get_moe_scaling_factors(
|
||||
input_scale: torch.Tensor,
|
||||
gemm1_weights_scale: torch.Tensor,
|
||||
activation_scale: torch.Tensor,
|
||||
gemm2_weights_scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
output1_scales_scalar = gemm1_weights_scale * input_scale * (
|
||||
1.0 / activation_scale)
|
||||
output1_scales_gate_scalar = gemm1_weights_scale * input_scale
|
||||
output2_scales_scalar = activation_scale * gemm2_weights_scale
|
||||
|
||||
return output1_scales_scalar, output1_scales_gate_scalar, \
|
||||
output2_scales_scalar
|
||||
|
||||
|
||||
def register_moe_scaling_factors(layer: torch.nn.Module) -> None:
|
||||
output1_scales, output1_gate_scales, output2_scales = \
|
||||
get_moe_scaling_factors(
|
||||
layer.w13_input_scale, layer.w13_weight_scale,
|
||||
layer.w2_input_scale, layer.w2_weight_scale
|
||||
)
|
||||
layer.register_parameter(
|
||||
'output1_scales_scalar',
|
||||
torch.nn.Parameter(output1_scales, requires_grad=False))
|
||||
layer.register_parameter(
|
||||
'output1_scales_gate_scalar',
|
||||
torch.nn.Parameter(output1_gate_scales, requires_grad=False))
|
||||
layer.register_parameter(
|
||||
'output2_scales_scalar',
|
||||
torch.nn.Parameter(output2_scales, requires_grad=False))
|
||||
|
||||
Reference in New Issue
Block a user