[Core] FlashInfer CUTLASS fused MoE backend (NVFP4) (#20037)
Signed-off-by: shuw <shuw@nvidia.com> Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
""" CUTLASS based Fused MoE kernels."""
|
||||
from typing import Callable, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -14,7 +14,8 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
|
||||
_resize_cache)
|
||||
_resize_cache,
|
||||
extract_required_args)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -298,7 +299,8 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool):
|
||||
apply_router_weight_on_input: bool,
|
||||
extra_expert_args: Optional[dict[str, Any]]):
|
||||
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
|
||||
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
|
||||
|
||||
@@ -431,23 +433,28 @@ FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
|
||||
def cutlass_moe_fp4(a: torch.Tensor,
|
||||
a1_gscale: torch.Tensor,
|
||||
w1_fp4: torch.Tensor,
|
||||
w1_blockscale: torch.Tensor,
|
||||
w1_alphas: torch.Tensor,
|
||||
a2_gscale: torch.Tensor,
|
||||
w2_fp4: torch.Tensor,
|
||||
w2_blockscale: torch.Tensor,
|
||||
w2_alphas: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
device: torch.device,
|
||||
apply_router_weight_on_input: bool = False):
|
||||
def run_cutlass_moe_fp4(
|
||||
output: torch.Tensor,
|
||||
a: torch.Tensor,
|
||||
a1_gscale: torch.Tensor,
|
||||
w1_fp4: torch.Tensor,
|
||||
w1_blockscale: torch.Tensor,
|
||||
w1_alphas: torch.Tensor,
|
||||
a2_gscale: torch.Tensor,
|
||||
w2_fp4: torch.Tensor,
|
||||
w2_blockscale: torch.Tensor,
|
||||
w2_alphas: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
device: torch.device,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
MoE implementation for FP4 Inputs
|
||||
|
||||
@@ -487,16 +494,16 @@ def cutlass_moe_fp4(a: torch.Tensor,
|
||||
|
||||
assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match",
|
||||
" between weights.")
|
||||
assert (k_a // 2 == half_k_w1
|
||||
assert (k_a == half_k_w1 * 2
|
||||
and k == k_w2), ("Hidden size mismatch between a, w1 and w2")
|
||||
assert (nx2_w1 == n * 2 and half_n_w2 == n // 2), ("mismatch in "
|
||||
"expected `n`")
|
||||
assert (nx2_w1 == n * 2 and half_n_w2 * 2 == n), ("mismatch in "
|
||||
"expected `n`")
|
||||
assert (m == m_a), "input shape mismatch"
|
||||
assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
|
||||
assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
|
||||
assert (topk_weights.size(0) == m and topk_ids.size(0)
|
||||
== m), ("topk must be provided for each row of a")
|
||||
|
||||
topk = topk_ids.size(1)
|
||||
out_dtype = a.dtype
|
||||
num_topk = topk_ids.size(1)
|
||||
|
||||
@@ -523,7 +530,6 @@ def cutlass_moe_fp4(a: torch.Tensor,
|
||||
blockscale_offsets)
|
||||
|
||||
a = ops.shuffle_rows(a, a_map)
|
||||
|
||||
rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant(
|
||||
a,
|
||||
a1_gscale,
|
||||
@@ -531,34 +537,220 @@ def cutlass_moe_fp4(a: torch.Tensor,
|
||||
blockscale_offsets,
|
||||
num_topk,
|
||||
)
|
||||
|
||||
c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale,
|
||||
w1_blockscale, w1_alphas, problem_sizes1,
|
||||
expert_offsets[:-1], blockscale_offsets[:-1],
|
||||
out_dtype, device)
|
||||
c1 = _resize_cache(workspace13, (m * topk, n * 2))
|
||||
c2 = _resize_cache(workspace2, (m * topk, n))
|
||||
c3 = _resize_cache(workspace13, (m * topk, k))
|
||||
ops.cutlass_fp4_moe_mm(c1, rep_a_fp4, w1_fp4, rep_a_blockscale,
|
||||
w1_blockscale, w1_alphas, problem_sizes1,
|
||||
expert_offsets[:-1], blockscale_offsets[:-1])
|
||||
del rep_a_fp4, rep_a_blockscale
|
||||
# hidden size dimension is split to one halfpytho sized tensor.
|
||||
intermediate = torch.empty((m * num_topk, w1_fp4.size(1) // 2),
|
||||
device=device,
|
||||
dtype=out_dtype)
|
||||
|
||||
torch.ops._C.silu_and_mul(intermediate, c1)
|
||||
|
||||
torch.ops._C.silu_and_mul(c2, c1)
|
||||
int_fp4, int_blockscale = ops.scaled_fp4_experts_quant(
|
||||
intermediate, a2_gscale, expert_offsets, blockscale_offsets, num_topk)
|
||||
c2, a2_gscale, expert_offsets, blockscale_offsets, num_topk)
|
||||
|
||||
c2 = ops.cutlass_fp4_moe_mm(int_fp4, w2_fp4, int_blockscale, w2_blockscale,
|
||||
w2_alphas, problem_sizes2, expert_offsets[:-1],
|
||||
blockscale_offsets[:-1], out_dtype, device)
|
||||
ops.cutlass_fp4_moe_mm(c3, int_fp4, w2_fp4, int_blockscale, w2_blockscale,
|
||||
w2_alphas, problem_sizes2, expert_offsets[:-1],
|
||||
blockscale_offsets[:-1])
|
||||
del int_fp4, int_blockscale
|
||||
|
||||
c2 = ops.shuffle_rows(c2, c_map)
|
||||
c3 = ops.shuffle_rows(c3, c_map)
|
||||
|
||||
assert output.dtype == out_dtype
|
||||
if not apply_router_weight_on_input:
|
||||
out = (c2.view(m, num_topk, k) *
|
||||
topk_weights.view(m, num_topk, 1).to(out_dtype)).sum(dim=1)
|
||||
output.copy_(
|
||||
(c3.view(m, num_topk, k) *
|
||||
topk_weights.view(m, num_topk, 1).to(out_dtype)).sum(dim=1),
|
||||
non_blocking=True)
|
||||
else:
|
||||
out = c2.view(m, num_topk, k).sum(dim=1)
|
||||
return out.to(dtype=out_dtype)
|
||||
output.copy_(c3.view(m, num_topk, k).sum(dim=1), non_blocking=True)
|
||||
return
|
||||
|
||||
|
||||
class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_experts_per_worker: int,
|
||||
out_dtype: torch.dtype,
|
||||
per_act_token_quant: bool,
|
||||
per_out_ch_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
use_batched_format: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
FusedMoEQuantConfig(
|
||||
quant_dtype=torch.uint8,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
per_out_ch_quant=per_out_ch_quant,
|
||||
block_shape=block_shape,
|
||||
))
|
||||
self.max_experts_per_worker = max_experts_per_worker
|
||||
self.out_dtype = out_dtype
|
||||
self.use_batched_format = use_batched_format
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
self
|
||||
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||
if self.use_batched_format:
|
||||
return (mk.FusedMoEActivationFormat.BatchedExperts,
|
||||
mk.FusedMoEActivationFormat.BatchedExperts)
|
||||
else:
|
||||
return (mk.FusedMoEActivationFormat.Standard,
|
||||
mk.FusedMoEActivationFormat.Standard)
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return True
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
# Let PrepareAndFinalize::finalize() decide the impl.
|
||||
return TopKWeightAndReduceDelegate()
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
aq: torch.Tensor,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
workspace1: tuple[int, ...] = ()
|
||||
workspace2: tuple[int, ...] = ()
|
||||
output: tuple[int, ...] = ()
|
||||
if self.use_batched_format:
|
||||
padded_M = aq.size(1)
|
||||
workspace1 = (self.max_experts_per_worker, padded_M, max(N, K))
|
||||
workspace2 = (self.max_experts_per_worker, padded_M, (N // 2))
|
||||
output = (self.max_experts_per_worker, padded_M, K)
|
||||
else:
|
||||
workspace1 = (M * topk, max(2 * N, K))
|
||||
workspace2 = (M * topk, N)
|
||||
output = (M, K)
|
||||
return (workspace1, workspace2, output,
|
||||
self.out_dtype if self.out_dtype is not None else a.dtype)
|
||||
|
||||
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor], w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor, w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: torch.Tensor, workspace13: Optional[torch.Tensor],
|
||||
workspace2: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
extra_expert_args: Optional[dict[str, Any]]):
|
||||
required_keys = [
|
||||
"g1_alphas", "g2_alphas", "a1_gscale", "a2_gscale", "m", "n", "k",
|
||||
"e", "device"
|
||||
]
|
||||
(g1_alphas, g2_alphas, a1_gscale, a2_gscale, m, n, k, e,
|
||||
device) = extract_required_args(extra_expert_args, required_keys)
|
||||
run_cutlass_moe_fp4(
|
||||
output=output,
|
||||
a=hidden_states,
|
||||
a1_gscale=a1_gscale,
|
||||
w1_fp4=w1,
|
||||
w1_blockscale=w1_scale,
|
||||
w1_alphas=g1_alphas,
|
||||
a2_gscale=a2_gscale,
|
||||
w2_fp4=w2,
|
||||
w2_blockscale=w2_scale,
|
||||
w2_alphas=g2_alphas,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
workspace13=workspace13,
|
||||
workspace2=workspace2,
|
||||
m=m,
|
||||
n=n,
|
||||
k=k,
|
||||
e=e,
|
||||
device=device,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
|
||||
def cutlass_moe_fp4(
|
||||
a: torch.Tensor,
|
||||
w1_fp4: torch.Tensor,
|
||||
w2_fp4: torch.Tensor,
|
||||
w1_blockscale: torch.Tensor,
|
||||
w2_blockscale: torch.Tensor,
|
||||
g1_alphas: torch.Tensor,
|
||||
g2_alphas: torch.Tensor,
|
||||
a1_gscale: torch.Tensor,
|
||||
a2_gscale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
device: torch.device,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False) -> torch.Tensor:
|
||||
assert expert_map is None, ("Expert Parallelism / expert_map "
|
||||
"is currently not supported for "
|
||||
"ModelOptNvFp4FusedMoE's cutlass_moe_fp4.")
|
||||
fn = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
CutlassExpertsFp4(
|
||||
max_experts_per_worker=e,
|
||||
out_dtype=a.dtype,
|
||||
per_act_token_quant=False,
|
||||
per_out_ch_quant=False,
|
||||
use_batched_format=False,
|
||||
),
|
||||
)
|
||||
extra_expert_args = {
|
||||
'g1_alphas': g1_alphas,
|
||||
'g2_alphas': g2_alphas,
|
||||
'a1_gscale': a1_gscale,
|
||||
'a2_gscale': a2_gscale,
|
||||
'm': m,
|
||||
'n': n,
|
||||
'k': k,
|
||||
'e': e,
|
||||
'device': device,
|
||||
}
|
||||
|
||||
# NVFP4 requires two levels of quantization, which involves computing some
|
||||
# scaling factors dynamically. This makes it incompatible with the typical
|
||||
# prepare -> MoE -> finalize pipeline. Move the quantization logic into the
|
||||
# MoE body.
|
||||
extra_prepare_args = {
|
||||
'skip_quant': True,
|
||||
}
|
||||
# Similar reason as above.
|
||||
extra_finalize_args = {
|
||||
'skip_weight_reduce': True,
|
||||
}
|
||||
return fn(
|
||||
hidden_states=a,
|
||||
w1=w1_fp4,
|
||||
w2=w2_fp4,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
w1_scale=w1_blockscale,
|
||||
w2_scale=w2_blockscale,
|
||||
a1_scale=None,
|
||||
a2_scale=None,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
extra_expert_args=extra_expert_args,
|
||||
extra_prepare_args=extra_prepare_args,
|
||||
extra_finalize_args=extra_finalize_args,
|
||||
)
|
||||
|
||||
|
||||
def _valid_cutlass_block_scaled_grouped_gemm(
|
||||
|
||||
Reference in New Issue
Block a user