refactor: Change scaling factors calculation for flashinfer FusedMoE (#22812)
Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -1189,10 +1189,10 @@ def flashinfer_fused_moe_per_tensor_scale_fp8(
|
||||
hidden_states: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
gemm1_weights: torch.Tensor,
|
||||
gemm1_weights_scale: torch.Tensor,
|
||||
activation_scale: torch.Tensor,
|
||||
gemm2_weights: torch.Tensor,
|
||||
gemm2_weights_scale: torch.Tensor,
|
||||
output1_scales_scalar: torch.Tensor,
|
||||
output1_scales_gate_scalar: torch.Tensor,
|
||||
output2_scales_scalar: torch.Tensor,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
num_expert_group: Optional[int],
|
||||
@@ -1206,17 +1206,12 @@ def flashinfer_fused_moe_per_tensor_scale_fp8(
|
||||
num_expert_group = num_expert_group if num_expert_group is not None else 0
|
||||
topk_group = topk_group if topk_group is not None else 0
|
||||
|
||||
quant_hidden_states, input_scale = moe_kernel_quantize_input(
|
||||
quant_hidden_states, _ = moe_kernel_quantize_input(
|
||||
hidden_states,
|
||||
input_scale,
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_act_token_quant=False)
|
||||
|
||||
output1_scales_scalar = gemm1_weights_scale * input_scale * (
|
||||
1.0 / activation_scale)
|
||||
output1_scales_gate_scalar = gemm1_weights_scale * input_scale
|
||||
output2_scales_scalar = activation_scale * gemm2_weights_scale
|
||||
|
||||
from vllm.utils.flashinfer import (
|
||||
flashinfer_trtllm_fp8_per_tensor_scale_moe)
|
||||
return flashinfer_trtllm_fp8_per_tensor_scale_moe(
|
||||
@@ -1244,24 +1239,24 @@ def flashinfer_fused_moe_per_tensor_scale_fp8(
|
||||
|
||||
def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: torch.Tensor,
|
||||
routing_bias: Optional[torch.Tensor],
|
||||
hidden_states: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
gemm1_weights: torch.Tensor,
|
||||
gemm2_weights: torch.Tensor,
|
||||
output1_scales_scalar: torch.Tensor,
|
||||
output1_scales_gate_scalar: torch.Tensor,
|
||||
gemm2_weights: torch.Tensor,
|
||||
output2_scales_scalar: torch.Tensor,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
num_expert_group: int,
|
||||
topk_group: int,
|
||||
num_expert_group: Optional[int],
|
||||
topk_group: Optional[int],
|
||||
intermediate_size: int,
|
||||
local_expert_offset: int,
|
||||
local_num_experts: int,
|
||||
routed_scaling_factor: float = 1.0,
|
||||
use_routing_scales_on_input: bool = False,
|
||||
tile_tokens_dim: int = 8,
|
||||
routing_method_type: int = 0) -> torch.Tensor:
|
||||
use_routing_scales_on_input: bool,
|
||||
routing_method_type: int,
|
||||
routed_scaling_factor: float = 1.0) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user