refactor: Change scaling factors calculation for flashinfer FusedMoE (#22812)

Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
amirkl94
2025-08-15 09:19:31 +03:00
committed by GitHub
parent 0fe85087a9
commit b4cef5e6c7
4 changed files with 60 additions and 25 deletions

View File

@@ -1189,10 +1189,10 @@ def flashinfer_fused_moe_per_tensor_scale_fp8(
hidden_states: torch.Tensor,
input_scale: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm1_weights_scale: torch.Tensor,
activation_scale: torch.Tensor,
gemm2_weights: torch.Tensor,
gemm2_weights_scale: torch.Tensor,
output1_scales_scalar: torch.Tensor,
output1_scales_gate_scalar: torch.Tensor,
output2_scales_scalar: torch.Tensor,
num_experts: int,
top_k: int,
num_expert_group: Optional[int],
@@ -1206,17 +1206,12 @@ def flashinfer_fused_moe_per_tensor_scale_fp8(
num_expert_group = num_expert_group if num_expert_group is not None else 0
topk_group = topk_group if topk_group is not None else 0
quant_hidden_states, input_scale = moe_kernel_quantize_input(
quant_hidden_states, _ = moe_kernel_quantize_input(
hidden_states,
input_scale,
quant_dtype=torch.float8_e4m3fn,
per_act_token_quant=False)
output1_scales_scalar = gemm1_weights_scale * input_scale * (
1.0 / activation_scale)
output1_scales_gate_scalar = gemm1_weights_scale * input_scale
output2_scales_scalar = activation_scale * gemm2_weights_scale
from vllm.utils.flashinfer import (
flashinfer_trtllm_fp8_per_tensor_scale_moe)
return flashinfer_trtllm_fp8_per_tensor_scale_moe(
@@ -1244,24 +1239,24 @@ def flashinfer_fused_moe_per_tensor_scale_fp8(
def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor,
routing_bias: Optional[torch.Tensor],
hidden_states: torch.Tensor,
input_scale: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm2_weights: torch.Tensor,
output1_scales_scalar: torch.Tensor,
output1_scales_gate_scalar: torch.Tensor,
gemm2_weights: torch.Tensor,
output2_scales_scalar: torch.Tensor,
num_experts: int,
top_k: int,
num_expert_group: int,
topk_group: int,
num_expert_group: Optional[int],
topk_group: Optional[int],
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
routed_scaling_factor: float = 1.0,
use_routing_scales_on_input: bool = False,
tile_tokens_dim: int = 8,
routing_method_type: int = 0) -> torch.Tensor:
use_routing_scales_on_input: bool,
routing_method_type: int,
routed_scaling_factor: float = 1.0) -> torch.Tensor:
pass