[NVIDIA] Add SM100 Flashinfer MoE per tensor scale fp8 backend (#21458)
Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
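The commit registers the new backend as a custom op named flashinfer_fused_moe_per_tensor_scale_fp8 (see the diff below). A minimal sketch of a call, assuming the torch.ops.vllm namespace that vLLM's direct_register_custom_op conventionally uses; the tensor shapes, scale values, and routing_method_type value are illustrative assumptions, and actually running this requires FlashInfer on SM100 hardware:

    import torch

    # Illustrative sizes only: tokens, hidden, intermediate, experts, top-k.
    T, H, I, E, K = 4, 1024, 2048, 8, 2
    dev = "cuda"

    hidden_states = torch.randn(T, H, dtype=torch.bfloat16, device=dev)
    routing_logits = torch.randn(T, E, dtype=torch.float32, device=dev)
    # Per-tensor scales: a single scalar per quantized tensor.
    input_scale = torch.tensor([1.0], device=dev)
    activation_scale = torch.tensor([1.0], device=dev)
    # Assumed fused gate+up layout for GEMM1 weights: [E, 2*I, H].
    w13 = torch.randn(E, 2 * I, H, device=dev).to(torch.float8_e4m3fn)
    w2 = torch.randn(E, H, I, device=dev).to(torch.float8_e4m3fn)
    w13_scale = torch.tensor([1.0], device=dev)
    w2_scale = torch.tensor([1.0], device=dev)

    out = torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8(
        routing_logits=routing_logits,
        routing_bias=None,
        hidden_states=hidden_states,
        input_scale=input_scale,
        gemm1_weights=w13,
        gemm1_weights_scale=w13_scale,
        activation_scale=activation_scale,
        gemm2_weights=w2,
        gemm2_weights_scale=w2_scale,
        num_experts=E,
        top_k=K,
        num_expert_group=None,
        topk_group=None,
        intermediate_size=I,
        local_expert_offset=0,
        local_num_experts=E,
        use_routing_scales_on_input=False,
        routing_method_type=1)  # assumed: renormalize-style routing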
@@ -30,6 +30,8 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceNoOP)
 from vllm.model_executor.layers.fused_moe.utils import (
     _resize_cache, moe_kernel_quantize_input, per_token_group_quant_fp8)
+from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+    calculate_tile_tokens_dim)
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
     dequant_mxfp4)
 from vllm.platforms import current_platform
@@ -1065,22 +1067,6 @@ direct_register_custom_op(
 )


-def next_positive_power_of_2(x: int) -> int:
-    if x < 1:
-        return 1
-    return 1 << (x - 1).bit_length()
-
-
-def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
-    # Guess tokens per expert assuming perfect expert distribution first.
-    num_tokens_per_expert = (num_tokens * top_k) // num_experts
-    # And pad the number to the next power of 2.
-    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
-    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
-    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-    return tile_tokens_dim
-
-
 def flashinfer_fused_moe_blockscale_fp8(
     routing_logits: torch.Tensor,
     routing_bias: torch.Tensor,
@@ -1128,8 +1114,8 @@ def flashinfer_fused_moe_blockscale_fp8(
         local_expert_offset=expert_offset,
         local_num_experts=local_num_experts,
         routed_scaling_factor=routed_scaling,
-        tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
-                                             global_num_experts),
+        tile_tokens_dim=calculate_tile_tokens_dim(x.shape[0], top_k,
+                                                  global_num_experts),
         routing_method_type=2,  # DeepSeek-styled routing method
         use_shuffled_weight=False,
     )
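Both call sites now use calculate_tile_tokens_dim from flashinfer_utils in place of the deleted local helper; presumably the shared version keeps the same logic, sizing the CTA tile from the average number of tokens per expert. A sketch under that assumption, with a worked value:

    def next_positive_power_of_2(x: int) -> int:
        if x < 1:
            return 1
        return 1 << (x - 1).bit_length()

    def tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
        # Expected tokens per expert under perfectly uniform routing.
        per_expert = (num_tokens * top_k) // num_experts
        # Round up to a power of two, then clamp to the 8-64 range
        # the kernel supports.
        return min(max(next_positive_power_of_2(per_expert), 8), 64)

    # e.g. 256 tokens, top_k=8, 128 experts: 256*8//128 = 16 -> tile of 16
    assert tile_tokens_dim(256, 8, 128) == 16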
@@ -1164,6 +1150,97 @@ direct_register_custom_op(
 )


+def flashinfer_fused_moe_per_tensor_scale_fp8(
+        routing_logits: torch.Tensor,
+        routing_bias: Optional[torch.Tensor],
+        hidden_states: torch.Tensor,
+        input_scale: torch.Tensor,
+        gemm1_weights: torch.Tensor,
+        gemm1_weights_scale: torch.Tensor,
+        activation_scale: torch.Tensor,
+        gemm2_weights: torch.Tensor,
+        gemm2_weights_scale: torch.Tensor,
+        num_experts: int,
+        top_k: int,
+        num_expert_group: Optional[int],
+        topk_group: Optional[int],
+        intermediate_size: int,
+        local_expert_offset: int,
+        local_num_experts: int,
+        use_routing_scales_on_input: bool,
+        routing_method_type: int,
+        routed_scaling_factor: float = 1.0) -> torch.Tensor:
+    num_expert_group = num_expert_group if num_expert_group is not None else 0
+    topk_group = topk_group if topk_group is not None else 0
+
+    quant_hidden_states, input_scale = moe_kernel_quantize_input(
+        hidden_states,
+        input_scale,
+        quant_dtype=torch.float8_e4m3fn,
+        per_act_token_quant=False)
+
+    output1_scales_scalar = gemm1_weights_scale * input_scale * (
+        1.0 / activation_scale)
+    output1_scales_gate_scalar = gemm1_weights_scale * input_scale
+    output2_scales_scalar = activation_scale * gemm2_weights_scale
+
+    from vllm.utils.flashinfer import (
+        flashinfer_trtllm_fp8_per_tensor_scale_moe)
+    return flashinfer_trtllm_fp8_per_tensor_scale_moe(
+        routing_logits=routing_logits,
+        routing_bias=routing_bias,
+        hidden_states=quant_hidden_states,
+        gemm1_weights=gemm1_weights,
+        output1_scales_scalar=output1_scales_scalar,
+        output1_scales_gate_scalar=output1_scales_gate_scalar,
+        gemm2_weights=gemm2_weights,
+        output2_scales_scalar=output2_scales_scalar,
+        num_experts=num_experts,
+        top_k=top_k,
+        n_group=num_expert_group,
+        topk_group=topk_group,
+        intermediate_size=intermediate_size,
+        local_expert_offset=local_expert_offset,
+        local_num_experts=local_num_experts,
+        routed_scaling_factor=routed_scaling_factor,
+        use_routing_scales_on_input=use_routing_scales_on_input,
+        tile_tokens_dim=calculate_tile_tokens_dim(hidden_states.shape[0],
+                                                  top_k, num_experts),
+        routing_method_type=routing_method_type)
+
+
+def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
+        routing_logits: torch.Tensor,
+        routing_bias: torch.Tensor,
+        hidden_states: torch.Tensor,
+        gemm1_weights: torch.Tensor,
+        output1_scales_scalar: torch.Tensor,
+        output1_scales_gate_scalar: torch.Tensor,
+        gemm2_weights: torch.Tensor,
+        output2_scales_scalar: torch.Tensor,
+        num_experts: int,
+        top_k: int,
+        num_expert_group: int,
+        topk_group: int,
+        intermediate_size: int,
+        local_expert_offset: int,
+        local_num_experts: int,
+        routed_scaling_factor: float = 1.0,
+        use_routing_scales_on_input: bool = False,
+        tile_tokens_dim: int = 8,
+        routing_method_type: int = 0) -> torch.Tensor:
+    pass
+
+
+direct_register_custom_op(
+    op_name="flashinfer_fused_moe_per_tensor_scale_fp8",
+    op_func=flashinfer_fused_moe_per_tensor_scale_fp8,
+    mutates_args=["hidden_states"],
+    fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake,
+    tags=(torch.Tag.needs_fixed_stride_order, ),
+)
+
+
 def outplace_fused_experts(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
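The three fused scales in the new function fall out of per-tensor FP8 dequantization algebra: a GEMM over quantized operands is rescaled by the product of the operands' scales, and requantizing GEMM1's activation output to fp8 folds in a further 1 / activation_scale, while the gate branch keeps the raw GEMM scale. A small numeric sketch of that algebra; the values are illustrative only:

    # Per-tensor FP8: x ~= q_x * input_scale, w1 ~= q_w1 * gemm1_weights_scale,
    # so GEMM1's real-valued output carries input_scale * w1_scale, and
    # requantizing it with activation_scale contributes 1 / activation_scale.
    input_scale, activation_scale = 0.5, 0.25
    w1_scale, w2_scale = 0.02, 0.04

    output1_scales_scalar = w1_scale * input_scale * (1.0 / activation_scale)
    output1_scales_gate_scalar = w1_scale * input_scale  # gate: no requant
    output2_scales_scalar = activation_scale * w2_scale

    print(output1_scales_scalar)       # 0.04
    print(output1_scales_gate_scalar)  # 0.01
    print(output2_scales_scalar)       # 0.01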