[MoE] Nvfp4 Masked Gemm: Add flashinfer grouped_gemm_nt_masked (#25990)
Signed-off-by: Shu Wang. <shuw@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
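This commit wires FlashInfer's CuteDSL grouped_gemm_nt_masked kernel into vLLM's NVFP4 fused-MoE path: it adds a new flashinfer_cutedsl_moe.py module with a FlashInferCuteDSLExperts implementation, a CUTEDSL value for FlashinferMoeBackend, and a "masked_gemm" option for VLLM_FLASHINFER_MOE_BACKEND. A minimal, hypothetical opt-in sketch (assuming an NVFP4-quantized MoE checkpoint on a compute-capability-10.0 GPU; model setup omitted):

import os

# Both variables appear in this diff; they must be set before vLLM reads them.
os.environ["VLLM_USE_FLASHINFER_MOE_FP4"] = "1"            # enable the FlashInfer NV-FP4 MoE path
os.environ["VLLM_FLASHINFER_MOE_BACKEND"] = "masked_gemm"  # maps to FlashinferMoeBackend.CUTEDSL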
@@ -6,6 +6,7 @@ import deep_ep
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
@@ -27,6 +28,8 @@ logger = init_logger(__name__)
DEEPEP_QUANT_BLOCK_SIZE = 128
DEEPEP_QUANT_BLOCK_SHAPE = [DEEPEP_QUANT_BLOCK_SIZE, DEEPEP_QUANT_BLOCK_SIZE]

logger = init_logger(__name__)


def dequant_fp8(
    expert_x_fp8: torch.Tensor, expert_x_scales: torch.Tensor
@@ -187,16 +190,25 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):

        # TODO (varun): Optimization - Use a batched version of quant
        x = x.view((-1, hidden_dim))
        q_dtype = quant_config.quant_dtype

        if envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm":
            logger.info_once(
                "Skip quantization when using FlashInfer CUTEDSL(masked_gemm) "
                "for ModelOptNvFp4FusedMoE."
            )
            q_dtype = None

        x, x_scales = moe_kernel_quantize_input(
            x,
            quant_config.a1_scale,
            quant_config.quant_dtype,
            q_dtype,
            quant_config.per_act_token_quant,
            quant_config.block_shape,
        )
        x = x.view((num_experts, -1, hidden_dim))

        if quant_config.quant_dtype is not None:
        if q_dtype is not None:
            assert x_scales is not None
            x_scales = normalize_batched_scales_shape(x_scales, num_experts)

vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py (new file, 346 lines)
@@ -0,0 +1,346 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
    TopKWeightAndReduceDelegate,
)
from vllm.utils.flashinfer import (
    flashinfer_cutedsl_grouped_gemm_nt_masked,
    has_flashinfer_cutedsl_grouped_gemm_nt_masked,
    scaled_fp4_grouped_quantize,
    silu_and_mul_scaled_nvfp4_experts_quantize,
)

logger = init_logger(__name__)

def is_valid_flashinfer_cutedsl_fused_moe(
    hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor
) -> bool:
    """
    Check if the given problem size is supported by the FlashInfer CuteDSL MoE
    kernel.
    """
    if not has_flashinfer_cutedsl_grouped_gemm_nt_masked():
        logger.debug_once(
            "FlashInferCuteDSLExperts disabled: "
            "flashinfer_cutedsl_fused_moe not available."
        )
        return False
    # Data type checks
    if (
        w1.dtype != torch.uint8
        or w2.dtype != torch.uint8
        or hidden_states.dtype not in [torch.float32, torch.float16, torch.bfloat16]
    ):
        logger.debug_once(
            "FlashInferCuteDSLExperts disabled: w1/w2 must be torch.uint8 "
            f"(got w1={w1.dtype}, w2={w2.dtype}), hidden_states must be "
            f"float32, float16, or bfloat16 (got {hidden_states.dtype})."
        )
        return False
    return True

class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
    def __init__(
        self,
        out_dtype: torch.dtype,
        quant_config: FusedMoEQuantConfig,
    ):
        super().__init__(quant_config)
        assert quant_config.quant_dtype == "nvfp4", (
            "Only nvfp4 quantization is currently supported."
        )
        self.out_dtype = out_dtype

    @property
    def activation_formats(
        self,
    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
        return (
            mk.FusedMoEActivationFormat.BatchedExperts,
            mk.FusedMoEActivationFormat.BatchedExperts,
        )

    def supports_expert_map(self) -> bool:
        return False

    def supports_chunking(self) -> bool:
        # This refers to TP chunking; DP chunking is handled separately.
        # TODO(shuw@nvidia.com): Set to False to be consistent with
        # batched_deep_gemm_moe
        return False

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        # Let PrepareAndFinalize::finalize() decide the impl.
        return TopKWeightAndReduceDelegate()

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        # We use global_num_experts due to how moe_align_block_size handles
        # expert_maps.
        """
        Compute the shapes for the temporary and final outputs of the two gemms
        and activation in the fused expert function. Since the gemms are
        independent, the workspace for the first gemm can be shared with the
        workspace for the last gemm.

        Returns a tuple of:
        - workspace13 shape tuple: must be large enough to hold the
          result of either expert gemm.
        - workspace2 shape tuple: must be large enough to hold the
          result of the activation function.
        - output shape tuple: must be exact size of the final gemm output.
        - Workspace type: The dtype to use for the workspace tensors.
        - Note: in order for activation chunking to work, the first dimension
          of each tuple must be the number of tokens.
        """
        output_shape = (local_num_experts, M, K)
        workspace2 = (local_num_experts, M, N)
        workspace1 = output_shape
        return (workspace1, workspace2, output_shape)

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,  # Not used
        workspace13: torch.Tensor | None,
        workspace2: torch.Tensor | None,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool | None,
    ):
        assert self.quant_dtype == "nvfp4", (
            "Only nvfp4 quantization is currently supported."
        )
        # Ensure w1_scale and w2_scale are not None before calling view
        assert self.w1_scale is not None and self.w2_scale is not None, (
            "w1_scale and w2_scale must not be None for FlashInferExperts"
        )
        assert expert_tokens_meta is not None
        expert_num_tokens = expert_tokens_meta.expert_num_tokens
        assert hidden_states.ndim == 3
        assert self.w1_scale.ndim == 3
        assert self.w2_scale.ndim == 3
        flashinfer_cutedsl_moe_masked(
            hidden_states=hidden_states,
            input_global_scale=self.a1_gscale,
            w1=w1,
            w1_blockscale=self.w1_scale,
            w1_alpha=self.g1_alphas,
            w2=w2,
            a2_global_scale=self.a2_gscale,
            w2_blockscale=self.w2_scale,
            w2_alpha=self.g2_alphas,
            masked_m=expert_num_tokens,
            workspace=workspace2,
            out=output,
        )

def get_cute_dtype(input: torch.Tensor) -> str:
    if input.dtype == torch.bfloat16:
        return "bfloat16"
    elif input.dtype == torch.float16:
        return "float16"
    elif input.dtype == torch.float32:
        return "float32"
    else:
        raise ValueError(f"Unsupported cute dtype {input.dtype}")

def flashinfer_cutedsl_moe_masked(
    hidden_states: torch.Tensor,
    input_global_scale: torch.Tensor,
    w1: torch.Tensor,
    w1_blockscale: torch.Tensor,
    w1_alpha,
    w2: torch.Tensor,
    a2_global_scale: torch.Tensor,
    w2_blockscale: torch.Tensor,
    w2_alpha,
    masked_m: torch.Tensor,
    workspace: torch.Tensor,
    out: torch.Tensor,
):
    """
    Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
    kernels.

    Args:
        hidden_states (torch.Tensor): [num_experts, m, k], bf16
        input_global_scale (torch.Tensor): (l,)
        w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
        w1_blockscale (torch.Tensor): blockscale factors, e4m3,
        w1_alpha (torch.Tensor): (l,)
        w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8
        a2_global_scale (torch.Tensor): (l,)
        w2_blockscale (torch.Tensor): blockscale factors, e4m3,
        w2_alpha (torch.Tensor): (l,)
        masked_m (torch.Tensor): Masked dimension indices
        workspace (torch.Tensor): For gateup_output

    Notes:
        - Assumes max(masked_m) <= m.
    """

    # === Assertions on dtypes ===
    assert input_global_scale.dtype == torch.float32, (
        f"input_global_scale must be float32, got {input_global_scale.dtype}"
    )
    assert w1.dtype == torch.uint8, f"w1 must be uint8, got {w1.dtype}"
    assert w1_blockscale.dtype == torch.float8_e4m3fn, (
        f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"
    )
    assert w1_alpha.dtype == torch.float32, (
        f"w1_alpha must be float32, got {w1_alpha.dtype}"
    )
    assert w2.dtype == torch.uint8, f"w2 must be uint8, got {w2.dtype}"
    assert a2_global_scale.dtype == torch.float32, (
        f"a2_global_scale must be float32, got {a2_global_scale.dtype}"
    )
    assert w2_blockscale.dtype == torch.float8_e4m3fn, (
        f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}"
    )
    assert w2_alpha.dtype == torch.float32, (
        f"w2_alpha must be float32, got {w2_alpha.dtype}"
    )

    # === Assertions on shapes ===
    n = w2.shape[-1] * 2  # intermediate dimension
    num_experts, m, k = hidden_states.shape

    assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
    assert w1.shape[-1] * 2 == k, (
        f"w1 last dim * 2 must equal k, got {w1.shape[-1]} vs k={k}"
    )
    assert w2.shape[-2:] == (
        k,
        n // 2,
    ), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n // 2)}"

    assert input_global_scale.shape == (num_experts,), (
        f"input_global_scale must be (l,), got {input_global_scale.shape}"
    )
    assert w1_alpha.shape == (num_experts,), (
        f"w1_alpha must be (l,), got {w1_alpha.shape}"
    )
    assert a2_global_scale.shape == (num_experts,), (
        f"a2_global_scale must be (l,), got {a2_global_scale.shape}"
    )
    assert w2_alpha.shape == (num_experts,), (
        f"w2_alpha must be (l,), got {w2_alpha.shape}"
    )

    aq, aq_sf = scaled_fp4_grouped_quantize(
        hidden_states,
        masked_m,
        input_global_scale,
    )

    workspace = workspace.permute(1, 2, 0)  # requirement of kernel
    sf_vec_size = 16
    assert aq_sf.dtype == torch.float8_e4m3fn
    assert aq.dtype == torch.uint8
    ab_dtype = "float4_e2m1fn"
    sf_dtype = "float8_e4m3fn"

    c_dtype = get_cute_dtype(hidden_states)

    # Gemm1
    flashinfer_cutedsl_grouped_gemm_nt_masked(
        (aq, aq_sf),
        (w1.permute(1, 2, 0), w1_blockscale),
        workspace,
        masked_m,
        ab_dtype=ab_dtype,
        sf_dtype=sf_dtype,
        c_dtype=c_dtype,
        sf_vec_size=sf_vec_size,
        alpha=w1_alpha.view(1, 1, num_experts),
        alpha_dtype=get_cute_dtype(w1_alpha),
    )  # in logical [m, n, l]

    # SILU and quantization
    diq, diq_sf = silu_and_mul_scaled_nvfp4_experts_quantize(
        workspace.permute(2, 0, 1),
        masked_m,
        a2_global_scale,
    )

    # Gemm2
    out = out.permute(1, 2, 0)  # requirement of kernel
    flashinfer_cutedsl_grouped_gemm_nt_masked(
        (diq, diq_sf),
        (w2.permute(1, 2, 0), w2_blockscale),
        out,
        masked_m,
        ab_dtype=ab_dtype,
        sf_dtype=sf_dtype,
        c_dtype=c_dtype,
        sf_vec_size=sf_vec_size,
        alpha=w2_alpha.view(1, 1, num_experts),
        alpha_dtype=get_cute_dtype(w2_alpha),
    )  # in logical [m, k, l]
    out = out.permute(2, 0, 1)

def flashinfer_cutedsl_moe_fp4(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    quant_config: FusedMoEQuantConfig,
    inplace: bool = False,
    activation: str = "silu",
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
    from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
        create_flashinfer_prepare_finalize,
    )

    fused_experts = mk.FusedMoEModularKernel(
        create_flashinfer_prepare_finalize(use_dp=False),  # could be swapped later
        FlashInferCuteDSLExperts(
            out_dtype=hidden_states.dtype,
            quant_config=quant_config,
        ),
    )

    return fused_experts(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=inplace,
        activation=activation,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )

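As a reading aid for the docstring above, a hypothetical shape walk-through of flashinfer_cutedsl_moe_masked (sizes invented purely for illustration; l is num_experts and only the first masked_m[i] rows of each expert slot are valid):

# Hypothetical sizes, chosen only to illustrate the layout documented above.
l, m, k, n = 8, 64, 4096, 14336   # experts, max tokens per expert, hidden, intermediate

# hidden_states: [l, m, k] bf16; masked_m: [l], with masked_m[i] <= m valid rows per expert.
# w1: [l, 2 * n, k // 2] uint8 (two fp4 values packed per byte; gate and up projections stacked).
# w2: [l, k, n // 2] uint8.
# workspace: sized by workspace_shapes() to hold the gemm1 (gate/up) output for every expert;
#            silu_and_mul_scaled_nvfp4_experts_quantize re-quantizes it to fp4 for gemm2.
# out: [l, m, k], written in place by gemm2; rows beyond masked_m[i] are not meaningful.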
@@ -1468,7 +1468,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
        gemm1_weight = layer.w13_weight.data
        gemm1_weight_scale = layer.w13_weight_scale.data

        if self.allow_flashinfer:
        if (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
        ):
            gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
                gemm1_weight, gemm1_weight_scale, dim=-2
            )
@@ -1746,17 +1749,26 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                workspace=layer.workspace,
            )

        elif (
            self.allow_flashinfer
            and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
        ):
            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
                flashinfer_cutlass_moe_fp4,
        elif self.allow_flashinfer:
            assert self.flashinfer_moe_backend in (
                FlashinferMoeBackend.CUTLASS,
                FlashinferMoeBackend.CUTEDSL,
            )
            if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
                from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
                    flashinfer_cutlass_moe_fp4,
                )

                flashinfer_fn_moe_fp4 = flashinfer_cutlass_moe_fp4
            else:
                from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (  # noqa: E501
                    flashinfer_cutedsl_moe_fp4,
                )

                flashinfer_fn_moe_fp4 = flashinfer_cutedsl_moe_fp4

            assert self.moe_quant_config is not None

            return flashinfer_cutlass_moe_fp4(
            return flashinfer_fn_moe_fp4(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
@@ -10,6 +10,9 @@ from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
    FlashInferCuteDSLExperts,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
    FlashInferExperts,
)
@@ -17,10 +20,14 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize im
    create_flashinfer_prepare_finalize,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.flashinfer import (
    has_flashinfer_cutedsl_grouped_gemm_nt_masked,
    has_flashinfer_cutlass_fused_moe,
)

__all__ = [
    "is_flashinfer_fp4_cutlass_moe_available",
    "is_flashinfer_fp4_cutedsl_moe_available",
    "reorder_w1w3_to_w3w1",
    "build_flashinfer_fp4_cutlass_moe_prepare_finalize",
]
@@ -36,6 +43,16 @@ def is_flashinfer_fp4_cutlass_moe_available() -> bool:
    )


def is_flashinfer_fp4_cutedsl_moe_available() -> bool:
    """Return ``True`` when FlashInfer CUTEDSL NV-FP4 kernels can be used."""
    return (
        envs.VLLM_USE_FLASHINFER_MOE_FP4
        and has_flashinfer_cutedsl_grouped_gemm_nt_masked()
        and current_platform.is_cuda()
        and current_platform.is_device_capability(100)
    )


def reorder_w1w3_to_w3w1(
    weight: torch.Tensor, scale: torch.Tensor, dim: int = -2
) -> tuple[torch.Tensor, torch.Tensor]:
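A minimal sketch of a hypothetical call site for the new helper; per the conditions above it only returns True when the FP4 MoE path is enabled, the CuteDSL grouped GEMM is importable, and the device is CUDA with compute capability 10.0:

from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
    is_flashinfer_fp4_cutedsl_moe_available,
)

if is_flashinfer_fp4_cutedsl_moe_available():
    # safe to construct FlashInferCuteDSLExperts / take the masked_gemm path
    ...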
@@ -72,15 +89,21 @@ def select_nvfp4_gemm_impl(
    """Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""

    if allow_flashinfer:
        return FlashInferExperts(
            out_dtype=moe.in_dtype,
            quant_config=moe_quant_config,
            ep_rank=moe.moe_parallel_config.ep_rank,
            ep_size=moe.moe_parallel_config.ep_size,
            tp_rank=moe.moe_parallel_config.tp_rank,
            tp_size=moe.moe_parallel_config.tp_size,
            use_dp=moe.moe_parallel_config.dp_size > 1,
        )
        if envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm":
            return FlashInferCuteDSLExperts(
                out_dtype=moe.in_dtype,
                quant_config=moe_quant_config,
            )
        elif envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput":
            return FlashInferExperts(
                out_dtype=moe.in_dtype,
                quant_config=moe_quant_config,
                ep_rank=moe.moe_parallel_config.ep_rank,
                ep_size=moe.moe_parallel_config.ep_size,
                tp_rank=moe.moe_parallel_config.tp_rank,
                tp_size=moe.moe_parallel_config.tp_size,
                use_dp=moe.moe_parallel_config.dp_size > 1,
            )

    # native cutlass experts currently don't support DP; TP case won't call this
    raise ValueError(
@@ -25,6 +25,7 @@ logger = init_logger(__name__)
class FlashinferMoeBackend(Enum):
    TENSORRT_LLM = "TensorRT-LLM"
    CUTLASS = "CUTLASS"
    CUTEDSL = "CUTEDSL"


def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
@@ -273,19 +274,21 @@ def flashinfer_cutlass_moe_fp8(


def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
    flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
    # Prefer CUTLASS on SM90 to cover both SM90/SM100 generations
    if flashinfer_moe_backend == "throughput" or current_platform.is_device_capability(
        90
    ):
        return FlashinferMoeBackend.CUTLASS
    elif flashinfer_moe_backend == "latency":
        return FlashinferMoeBackend.TENSORRT_LLM
    backend_map = {
        "throughput": FlashinferMoeBackend.CUTLASS,
        "latency": FlashinferMoeBackend.TENSORRT_LLM,
        "masked_gemm": FlashinferMoeBackend.CUTEDSL,
    }

    flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
    if flashinfer_moe_backend in backend_map:
        return backend_map[flashinfer_moe_backend]
    elif current_platform.is_device_capability(90):
        return FlashinferMoeBackend.CUTLASS

    allowed_backends = ["throughput", "latency"]
    raise ValueError(
        f"Unknown flashinfer moe backend: {flashinfer_moe_backend}"
        f" expected one of {allowed_backends}"
        f"Unknown flashinfer moe backend: {flashinfer_moe_backend!r}. "
        f"Expected one of {list(backend_map.keys())}."
    )

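For reference, the env-var-to-backend mapping after this change (a summary of backend_map above, not new behavior):

# VLLM_FLASHINFER_MOE_BACKEND="throughput"  -> FlashinferMoeBackend.CUTLASS
# VLLM_FLASHINFER_MOE_BACKEND="latency"     -> FlashinferMoeBackend.TENSORRT_LLM
# VLLM_FLASHINFER_MOE_BACKEND="masked_gemm" -> FlashinferMoeBackend.CUTEDSL
# Any other value still falls back to CUTLASS on SM90 devices; otherwise a ValueError is raised.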
@@ -5,6 +5,7 @@ from dataclasses import dataclass
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
    is_flashinfer_fp4_cutedsl_moe_available,
    is_flashinfer_fp4_cutlass_moe_available,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
@@ -32,7 +33,10 @@ def detect_nvfp4_moe_support(class_name: str = "") -> NvFp4Support:
    """Detect platform support for NV-FP4 fused-MoE path"""
    cutlass_supported = cutlass_fp4_supported()

    allow_flashinfer = cutlass_supported and is_flashinfer_fp4_cutlass_moe_available()
    allow_flashinfer = cutlass_supported and (
        is_flashinfer_fp4_cutlass_moe_available()
        or is_flashinfer_fp4_cutedsl_moe_available()
    )

    if allow_flashinfer:
        _logger.info_once(