[Core] FlashInfer CUTLASS fused MoE backend (NVFP4) (#20037)

Signed-off-by: shuw <shuw@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Shu Wang
2025-07-17 23:32:45 -05:00
committed by GitHub
parent b38baabcf9
commit c7d8724e78
22 changed files with 1093 additions and 269 deletions

View File

@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from math import prod
from typing import Optional, final
from typing import Any, Optional, final
import torch
@@ -150,16 +150,12 @@ class FusedMoEPrepareAndFinalize(ABC):
@abstractmethod
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
topk_ids: torch.Tensor, num_experts: int,
expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]]
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
@@ -190,15 +186,11 @@ class FusedMoEPrepareAndFinalize(ABC):
raise NotImplementedError
@abstractmethod
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: TopKWeightAndReduce,
) -> None:
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: TopKWeightAndReduce,
extra_finalize_args: Optional[dict[str, Any]]) -> None:
"""
Perform any combine plus apply weights and perform a reduction on the
fused experts output.
@@ -376,6 +368,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
workspace2: torch.Tensor,
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]],
):
"""
This function computes the intermediate result of a Mixture of Experts
@@ -460,21 +453,19 @@ class FusedMoEModularKernel(torch.nn.Module):
f"{fused_experts.__class__.__name__}."
f"{fused_experts.activation_formats[0]}")
def _do_fused_experts(self, fused_out: Optional[torch.Tensor],
a1: torch.Tensor, a1q: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
activation: str, global_num_experts: int,
local_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool) -> torch.Tensor:
def _do_fused_experts(
self, fused_out: Optional[torch.Tensor], a1: torch.Tensor,
a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
activation: str, global_num_experts: int, local_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]) -> torch.Tensor:
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
@@ -517,7 +508,8 @@ class FusedMoEModularKernel(torch.nn.Module):
workspace13=workspace13,
workspace2=workspace2,
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input)
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args)
return fused_out
@@ -541,6 +533,7 @@ class FusedMoEModularKernel(torch.nn.Module):
a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]],
) -> torch.Tensor:
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
@@ -568,7 +561,8 @@ class FusedMoEModularKernel(torch.nn.Module):
a1q_scale=a1q_scale,
a2_scale=a2_scale,
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input)
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args)
# Chunking required case
assert num_chunks > 1
@@ -624,6 +618,15 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_num_tokens=c_expert_num_tokens,
expert_num_tokens_cpu=c_expert_num_tokens_cpu)
m = None
if extra_expert_args is not None and 'm' in extra_expert_args:
m = extra_expert_args.get('m')
if extra_expert_args is not None:
chunked_extra_expert_args = extra_expert_args
else:
chunked_extra_expert_args = {}
for chunk_idx in range(num_chunks):
c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = (
slice_input_tensors(chunk_idx))
@@ -634,6 +637,11 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_tokens_meta, c_topk_ids, local_num_experts,
expert_map)
s = chunk_idx * CHUNK_SIZE
e = min(s + CHUNK_SIZE, M)
if m is not None:
chunked_extra_expert_args['m'] = e - s
self._do_fused_experts(
fused_out=slice_output_tensor(chunk_idx),
a1=a1,
@@ -653,7 +661,8 @@ class FusedMoEModularKernel(torch.nn.Module):
a1q_scale=c_a1q_scale,
a2_scale=c_a2_scale,
expert_tokens_meta=c_expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input)
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=chunked_extra_expert_args)
return fused_out
@@ -675,6 +684,9 @@ class FusedMoEModularKernel(torch.nn.Module):
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
extra_expert_args: Optional[dict] = None,
extra_prepare_args: Optional[dict] = None,
extra_finalize_args: Optional[dict] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets
@@ -707,6 +719,12 @@ class FusedMoEModularKernel(torch.nn.Module):
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is
1.
- extra_expert_args (Optional[dict]): Extra keyword arguments to pass to
fused_experts.apply.
- extra_prepare_args (Optional[dict]): Extra keyword arguments to pass
to prepare.
- extra_finalize_args (Optional[dict]): Extra keyword arguments to pass
to finalize.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
@@ -730,6 +748,7 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
extra_prepare_args,
)
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
@@ -766,11 +785,13 @@ class FusedMoEModularKernel(torch.nn.Module):
a1q_scale=a1q_scale,
a2_scale=a2_scale,
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input)
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args)
self.prepare_finalize.finalize(
output, fused_out, topk_weights, topk_ids,
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl())
self.fused_experts.finalize_weight_and_reduce_impl(),
extra_finalize_args)
return output