[Quantization] Add FlashInfer CuteDSL batched experts backend for NVFP4 MoE (#38251)
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
@@ -17,7 +17,7 @@ from flashinfer import fp4_quantize
|
||||
from torch.nn import functional as F
|
||||
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_moe import (
|
||||
from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_batched_moe import ( # noqa: E501
|
||||
flashinfer_cutedsl_moe_masked,
|
||||
)
|
||||
from vllm.utils.flashinfer import (
|
||||
|
||||
@@ -0,0 +1,353 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kNvfp4Dynamic,
|
||||
kNvfp4Static,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import (
|
||||
flashinfer_cutedsl_grouped_gemm_nt_masked,
|
||||
has_flashinfer_cutedsl_grouped_gemm_nt_masked,
|
||||
scaled_fp4_grouped_quantize,
|
||||
silu_and_mul_scaled_nvfp4_experts_quantize,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashInferCuteDSLBatchedExperts(mk.FusedMoEExpertsModular):
    """Batched-experts NVFP4 MoE backend using FlashInfer CuteDSL kernels.

    Consumes activations in the ``BatchedExperts`` format — hidden states
    laid out as ``(local_num_experts, max_tokens, hidden)`` plus a
    per-expert valid-token count — and delegates the expert computation to
    :func:`flashinfer_cutedsl_moe_masked`.
    """

    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
        max_num_tokens: int,
        num_dispatchers: int,
    ):
        super().__init__(
            moe_config=moe_config,
            quant_config=quant_config,
            max_num_tokens=max_num_tokens,
            num_dispatchers=num_dispatchers,
        )
        # Fixed grammar in the message ("are" -> "is"), matching the
        # equivalent assertion in FlashInferCuteDSLExperts.
        assert quant_config.quant_dtype == "nvfp4", (
            "Only nvfp4 quantization is currently supported."
        )
        self.out_dtype = moe_config.in_dtype

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Fold the static input scales into the per-expert global weight
        # scales so the kernel only needs a single alpha per GEMM.
        layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale)
        layer.w2_weight_scale_2.data.mul_(layer.w2_input_scale)

    @staticmethod
    def activation_format() -> mk.FusedMoEActivationFormat:
        return mk.FusedMoEActivationFormat.BatchedExperts

    @staticmethod
    def _supports_current_device() -> bool:
        # Requires a CUDA device of compute-capability family 10.0
        # (Blackwell) with the CuteDSL masked grouped GEMM available.
        p = current_platform
        return (
            p.is_cuda()
            and p.is_device_capability_family(100)
            and has_flashinfer_cutedsl_grouped_gemm_nt_masked()
        )

    @staticmethod
    def _supports_no_act_and_mul() -> bool:
        return False

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        # Only static NVFP4 weights with dynamic NVFP4 activations.
        SUPPORTED_W_A = [
            (kNvfp4Static, kNvfp4Dynamic),
        ]
        return (weight_key, activation_key) in SUPPORTED_W_A

    @staticmethod
    def _supports_activation(activation: MoEActivation) -> bool:
        return activation == MoEActivation.SILU

    @staticmethod
    def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
        return True

    def supports_expert_map(self) -> bool:
        return False

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        # Let PrepareAndFinalize::finalize() decide the impl.
        return TopKWeightAndReduceDelegate()

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        activation: MoEActivation,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """
        Compute the shapes for the temporary and final outputs of the two
        gemms and activation in the fused expert function. Since the gemms
        are independent, the workspace for the first gemm can be shared
        with the workspace for the last gemm.

        Returns a tuple of:
        - workspace13 shape tuple: must be large enough to hold the
          result of either expert gemm.
        - workspace2 shape tuple: must be large enough to hold the
          result of the activation function.
        - output shape tuple: must be exact size of the final gemm output.

        Note: in order for activation chunking to work, the first dimension
        of each tuple must be the number of tokens.
        """
        # When DeepEP-LL dispatches pre-quantized NVFP4, K is the packed
        # (uint8) width, so the unpacked hidden size is K * 2.
        K_dim = K * 2 if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else K
        output_shape = (local_num_experts, M, K_dim)
        workspace2 = (local_num_experts, M, N)
        workspace1 = output_shape
        return (workspace1, workspace2, output_shape)

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,  # Not used
        workspace13: torch.Tensor | None,
        workspace2: torch.Tensor | None,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool | None,
    ):
        """Run the batched NVFP4 expert computation into ``output``.

        Writes the result in place; weighting/reduction is delegated to
        the prepare/finalize stage (see finalize_weight_and_reduce_impl).
        """
        # Fixed grammar in the message ("are" -> "is"), matching the
        # equivalent assertion in FlashInferCuteDSLExperts.
        assert self.quant_dtype == "nvfp4", (
            "Only nvfp4 quantization is currently supported."
        )
        # Ensure w1_scale and w2_scale are not None before calling view
        assert self.w1_scale is not None and self.w2_scale is not None, (
            "w1_scale and w2_scale must not be None for FlashInferExperts"
        )
        assert expert_tokens_meta is not None
        expert_num_tokens = expert_tokens_meta.expert_num_tokens
        assert hidden_states.ndim == 3
        assert self.w1_scale.ndim == 3
        assert self.w2_scale.ndim == 3

        # With DeepEP-LL NVFP4 dispatch the activations arrive already
        # quantized (data + scales), so no input global scale is passed
        # and the (tensor, scale) tuple goes straight to the kernel.
        input_global_scale = (
            None if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else self.a1_gscale
        )
        flashinfer_hidden_states = (
            (hidden_states, a1q_scale)
            if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH
            else hidden_states
        )
        flashinfer_cutedsl_moe_masked(
            hidden_states=flashinfer_hidden_states,
            input_global_scale=input_global_scale,
            w1=w1,
            w1_blockscale=self.w1_scale,
            w1_alpha=self.g1_alphas,
            w2=w2,
            a2_global_scale=self.a2_gscale,
            w2_blockscale=self.w2_scale,
            w2_alpha=self.g2_alphas,
            masked_m=expert_num_tokens,
            workspace=workspace2,
            out=output,
        )
|
||||
|
||||
|
||||
def get_cute_dtype(input: torch.Tensor) -> str:
    """Return the CuteDSL dtype name for *input*'s torch dtype.

    Raises:
        ValueError: if the dtype is not one of bf16/fp16/fp32.
    """
    dtype_names = {
        torch.bfloat16: "bfloat16",
        torch.float16: "float16",
        torch.float32: "float32",
    }
    name = dtype_names.get(input.dtype)
    if name is None:
        raise ValueError(f"Unsupported cute dtype {input.dtype}")
    return name
|
||||
|
||||
|
||||
def flashinfer_cutedsl_moe_masked(
    hidden_states: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
    input_global_scale: torch.Tensor,
    w1: torch.Tensor,
    w1_blockscale: torch.Tensor,
    w1_alpha,
    w2: torch.Tensor,
    a2_global_scale: torch.Tensor,
    w2_blockscale: torch.Tensor,
    w2_alpha,
    masked_m: torch.Tensor,
    workspace: torch.Tensor,
    out: torch.Tensor,
):
    """
    Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
    kernels.

    Args:
        hidden_states: Either of the following case
            * torch.Tensor: [num_experts, m, k], bf16 (un-quantized input;
              ``input_global_scale`` required)
            * tuple[torch.Tensor, torch.Tensor]: pre-quantized input as
              [num_experts, m, k // 2] uint8 data plus
              [num_experts, m, k // 16] float8_e4m3fn scales
              (``input_global_scale`` must be None)
        input_global_scale (torch.Tensor): (l,)
        w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
        w1_blockscale (torch.Tensor): blockscale factors, e4m3,
        w1_alpha (torch.Tensor): (l,)
        w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8
        a2_global_scale (torch.Tensor): (l,)
        w2_blockscale (torch.Tensor): blockscale factors, e4m3,
        w2_alpha (torch.Tensor): (l,)
        masked_m (torch.Tensor): per-expert number of valid tokens
        workspace (torch.Tensor): scratch for the gate/up (gemm1) output
        out (torch.Tensor): result buffer, written in place

    Notes:
        - Assumes max(masked_m) <= m.
    """

    # === Assertions on dtypes ===
    assert w1.dtype == torch.uint8, f"w1 must be uint8, got {w1.dtype}"
    assert w1_blockscale.dtype == torch.float8_e4m3fn, (
        f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"
    )
    assert w1_alpha.dtype == torch.float32, (
        f"w1_alpha must be float32, got {w1_alpha.dtype}"
    )
    assert w2.dtype == torch.uint8, f"w2 must be uint8, got {w2.dtype}"
    assert a2_global_scale.dtype == torch.float32, (
        f"a2_global_scale must be float32, got {a2_global_scale.dtype}"
    )
    assert w2_blockscale.dtype == torch.float8_e4m3fn, (
        f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}"
    )
    assert w2_alpha.dtype == torch.float32, (
        f"w2_alpha must be float32, got {w2_alpha.dtype}"
    )

    # === Assertions on shapes ===
    n = w2.shape[-1] * 2  # intermediate dimension
    if isinstance(hidden_states, tuple):
        # Input is already NVFP4-quantized (data, scales); a global scale
        # must NOT be supplied in this case. (The original message had the
        # condition backwards.)
        assert input_global_scale is None, (
            "input_global_scale must be None when hidden_states are "
            "already quantized"
        )

        # Reinterpret the packed buffers with their true element types.
        aq = hidden_states[0].view(torch.uint8)
        aq_sf = hidden_states[1].view(torch.float8_e4m3fn)
        num_experts, m, k_by_2 = aq.shape
        k = k_by_2 * 2
        # Kernel expects logical [m, k_packed, l] layout.
        aq = aq.permute(1, 2, 0)
    else:
        num_experts, m, k = hidden_states.shape

        assert input_global_scale.dtype == torch.float32, (
            f"input_global_scale must be float32, got {input_global_scale.dtype}"
        )
        assert input_global_scale.shape == (num_experts,), (
            f"input_global_scale must be (l,), got {input_global_scale.shape}"
        )

        # Quantize the bf16/fp16 input to NVFP4 per expert group.
        aq, aq_sf = scaled_fp4_grouped_quantize(
            hidden_states,
            masked_m,
            input_global_scale,
        )

    assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
    assert w1.shape[-1] * 2 == k, (
        f"w1 last dim * 2 must equal k, got {w1.shape[-1]} vs k={k}"
    )
    assert w2.shape[-2:] == (
        k,
        n // 2,
    ), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n // 2)}"

    assert w1_alpha.shape == (num_experts,), (
        f"w1_alpha must be (l,), got {w1_alpha.shape}"
    )
    assert a2_global_scale.shape == (num_experts,), (
        f"a2_global_scale must be (l,), got {a2_global_scale.shape}"
    )
    assert w2_alpha.shape == (num_experts,), (
        f"w2_alpha must be (l,), got {w2_alpha.shape}"
    )

    workspace = workspace.permute(1, 2, 0)  # requirement of kernel
    sf_vec_size = 16
    assert aq_sf.dtype == torch.float8_e4m3fn
    assert aq.dtype == torch.uint8
    ab_dtype = "float4_e2m1fn"
    sf_dtype = "float8_e4m3fn"

    # Pre-quantized path always produces bf16 intermediates; otherwise
    # keep the input's dtype.
    if isinstance(hidden_states, tuple):
        c_dtype = "bfloat16"
    else:
        c_dtype = get_cute_dtype(hidden_states)

    # Gemm1: x @ w13^T -> gate/up activations, written into workspace.
    flashinfer_cutedsl_grouped_gemm_nt_masked(
        (aq, aq_sf),
        (w1.permute(1, 2, 0), w1_blockscale),
        workspace,
        masked_m,
        ab_dtype=ab_dtype,
        sf_dtype=sf_dtype,
        c_dtype=c_dtype,
        sf_vec_size=sf_vec_size,
        alpha=w1_alpha.view(1, 1, num_experts),
        alpha_dtype=get_cute_dtype(w1_alpha),
    )  # in logical [m, n, l]

    # SILU and quantization: fused silu(gate) * up followed by NVFP4 quant.
    diq, diq_sf = silu_and_mul_scaled_nvfp4_experts_quantize(
        workspace.permute(2, 0, 1),
        masked_m,
        a2_global_scale,
    )

    # Gemm2: activated intermediates @ w2^T -> final output, in place.
    out = out.permute(1, 2, 0)  # requirement of kernel
    flashinfer_cutedsl_grouped_gemm_nt_masked(
        (diq, diq_sf),
        (w2.permute(1, 2, 0), w2_blockscale),
        out,
        masked_m,
        ab_dtype=ab_dtype,
        sf_dtype=sf_dtype,
        c_dtype=c_dtype,
        sf_vec_size=sf_vec_size,
        alpha=w2_alpha.view(1, 1, num_experts),
        alpha_dtype=get_cute_dtype(w2_alpha),
    )  # in logical [m, k, l]
    # Local rebind back to [l, m, k]; the caller's buffer was already
    # written through the permuted view above.
    out = out.permute(2, 0, 1)
|
||||
@@ -4,8 +4,6 @@
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
@@ -13,7 +11,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate,
|
||||
TopKWeightAndReduceNoOP,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
@@ -22,33 +20,42 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import (
|
||||
flashinfer_cutedsl_grouped_gemm_nt_masked,
|
||||
has_flashinfer_cutedsl_grouped_gemm_nt_masked,
|
||||
scaled_fp4_grouped_quantize,
|
||||
silu_and_mul_scaled_nvfp4_experts_quantize,
|
||||
flashinfer_cute_dsl_fused_moe_nvfp4,
|
||||
has_flashinfer_cutedsl_moe_nvfp4,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
|
||||
"""
|
||||
CuteDSL NvFP4 MoE experts using the FlashInfer functional API.
|
||||
|
||||
Uses Standard activation format (non-batched). The kernel handles
|
||||
routing, expert computation, and reduction internally.
|
||||
Supports expert parallelism natively.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
max_num_tokens: int,
|
||||
num_dispatchers: int,
|
||||
):
|
||||
super().__init__(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=num_dispatchers,
|
||||
)
|
||||
assert quant_config.quant_dtype == "nvfp4", (
|
||||
"Only nvfp4 quantization are currently supported."
|
||||
"Only nvfp4 quantization is currently supported."
|
||||
)
|
||||
self.out_dtype = moe_config.in_dtype
|
||||
self.hidden_dim = moe_config.hidden_dim
|
||||
self.intermediate_size_per_partition = (
|
||||
moe_config.intermediate_size_per_partition
|
||||
)
|
||||
self.topk = moe_config.experts_per_token
|
||||
self.local_num_experts = moe_config.num_local_experts
|
||||
self.global_num_experts = moe_config.num_experts
|
||||
self.ep_rank = moe_config.moe_parallel_config.ep_rank
|
||||
self.local_expert_offset = self.ep_rank * self.local_num_experts
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale)
|
||||
@@ -56,7 +63,7 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
|
||||
|
||||
@staticmethod
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.BatchedExperts
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
@staticmethod
|
||||
def _supports_current_device() -> bool:
|
||||
@@ -64,7 +71,7 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
|
||||
return (
|
||||
p.is_cuda()
|
||||
and p.is_device_capability_family(100)
|
||||
and has_flashinfer_cutedsl_grouped_gemm_nt_masked()
|
||||
and has_flashinfer_cutedsl_moe_nvfp4()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -86,15 +93,16 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
|
||||
return activation == MoEActivation.SILU
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
def _supports_parallel_config(
|
||||
moe_parallel_config: FusedMoEParallelConfig,
|
||||
) -> bool:
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
# Let PrepareAndFinalize::finalize() decide the impl.
|
||||
return TopKWeightAndReduceDelegate()
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
@@ -107,29 +115,12 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: MoEActivation,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# We use global_num_experts due to how moe_align_block_size handles
|
||||
# expert_maps.
|
||||
"""
|
||||
Compute the shapes for the temporary and final outputs of the two gemms
|
||||
and activation in the fused expert function. Since the gemms are
|
||||
independent, the workspace for the first gemm can be shared with the
|
||||
workspace for the last gemm.
|
||||
|
||||
Returns a tuple of:
|
||||
- workspace13 shape tuple: must be large enough to hold the
|
||||
result of either expert gemm.
|
||||
- workspace2 shape tuple: must be large enough to hold the
|
||||
result of the activation function.
|
||||
- output shape tuple: must be exact size of the final gemm output.
|
||||
- Workspace type: The dtype to use for the workspace tensors.
|
||||
- Note: in order for activation chunking to work, the first dimension
|
||||
of each tuple must be the number of tokens.
|
||||
"""
|
||||
K_dim = K * 2 if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else K
|
||||
output_shape = (local_num_experts, M, K_dim)
|
||||
workspace2 = (local_num_experts, M, N)
|
||||
workspace1 = output_shape
|
||||
return (workspace1, workspace2, output_shape)
|
||||
workspace1 = (0,)
|
||||
workspace2 = (0,)
|
||||
# K is packed (K//2 for uint8), so output uses hidden_dim.
|
||||
assert self.hidden_dim == K * 2
|
||||
output = (M, self.hidden_dim)
|
||||
return (workspace1, workspace2, output)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@@ -143,210 +134,39 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
a2_scale: torch.Tensor | None, # Not used
|
||||
a2_scale: torch.Tensor | None,
|
||||
workspace13: torch.Tensor | None,
|
||||
workspace2: torch.Tensor | None,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool | None,
|
||||
):
|
||||
assert self.quant_dtype == "nvfp4", (
|
||||
"Only nvfp4 quantization are currently supported."
|
||||
)
|
||||
# Ensure w1_scale and w2_scale are not None before calling view
|
||||
assert self.w1_scale is not None and self.w2_scale is not None, (
|
||||
"w1_scale and w2_scale must not be None for FlashInferExperts"
|
||||
)
|
||||
assert expert_tokens_meta is not None
|
||||
expert_num_tokens = expert_tokens_meta.expert_num_tokens
|
||||
assert hidden_states.ndim == 3
|
||||
assert self.w1_scale.ndim == 3
|
||||
assert self.w2_scale.ndim == 3
|
||||
assert self.quant_dtype == "nvfp4"
|
||||
assert a1q_scale is not None
|
||||
assert self.w1_scale is not None
|
||||
assert self.w2_scale is not None
|
||||
|
||||
input_global_scale = (
|
||||
None if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else self.a1_gscale
|
||||
)
|
||||
flashinfer_hidden_states = (
|
||||
(hidden_states, a1q_scale)
|
||||
if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH
|
||||
else hidden_states
|
||||
)
|
||||
flashinfer_cutedsl_moe_masked(
|
||||
hidden_states=flashinfer_hidden_states,
|
||||
input_global_scale=input_global_scale,
|
||||
w1=w1,
|
||||
w1_blockscale=self.w1_scale,
|
||||
w1_alpha=self.g1_alphas,
|
||||
w2=w2,
|
||||
a2_global_scale=self.a2_gscale,
|
||||
w2_blockscale=self.w2_scale,
|
||||
w2_alpha=self.g2_alphas,
|
||||
masked_m=expert_num_tokens,
|
||||
workspace=workspace2,
|
||||
out=output,
|
||||
)
|
||||
# a1q_scale is (M, K//16) float8_e4m3fn from fp4_quantize.
|
||||
# The functional API expects x_sf with trailing dim: (M, K//16, 1).
|
||||
x_sf = a1q_scale.unsqueeze(-1)
|
||||
|
||||
from vllm.utils.flashinfer import _is_fi_autotuning, autotune
|
||||
|
||||
def get_cute_dtype(input: torch.Tensor) -> str:
|
||||
if input.dtype == torch.bfloat16:
|
||||
return "bfloat16"
|
||||
elif input.dtype == torch.float16:
|
||||
return "float16"
|
||||
elif input.dtype == torch.float32:
|
||||
return "float32"
|
||||
else:
|
||||
raise ValueError(f"Unsupported cute dtype {input.dtype}")
|
||||
|
||||
|
||||
def flashinfer_cutedsl_moe_masked(
|
||||
hidden_states: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
input_global_scale: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w1_blockscale: torch.Tensor,
|
||||
w1_alpha,
|
||||
w2: torch.Tensor,
|
||||
a2_global_scale: torch.Tensor,
|
||||
w2_blockscale: torch.Tensor,
|
||||
w2_alpha,
|
||||
masked_m: torch.Tensor,
|
||||
workspace: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
|
||||
kernels.
|
||||
|
||||
Args:
|
||||
hidden_states: Either of the following case
|
||||
* torch.Tensor: [num_experts, m, k], bf16
|
||||
* tuple[torch.Tensor, torch.Tensor]: [num_experts, m, k // 2],
|
||||
uint8, [num_experts, m, k // 16], float8_e4m3fn
|
||||
input_global_scale (torch.Tensor): (l,)
|
||||
w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
|
||||
w1_blockscale (torch.Tensor): blockscale factors, e4m3,
|
||||
w1_alpha (torch.Tensor): (l,)
|
||||
w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8
|
||||
a2_global_scale (torch.Tensor): (l,)
|
||||
w2_blockscale (torch.Tensor): blockscale factors, e4m3,
|
||||
w2_alpha (torch.Tensor): (l,)
|
||||
masked_m (torch.Tensor): Masked dimension indices
|
||||
workspace (torch.Tensor): For gateup_output
|
||||
|
||||
Notes:
|
||||
- Assumes max(masked_m) <= m.
|
||||
"""
|
||||
|
||||
# === Assertions on dtypes ===
|
||||
assert w1.dtype == torch.uint8, f"w1 must be uint8, got {w1.dtype}"
|
||||
assert w1_blockscale.dtype == torch.float8_e4m3fn, (
|
||||
f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"
|
||||
)
|
||||
assert w1_alpha.dtype == torch.float32, (
|
||||
f"w1_alpha must be float32, got {w1_alpha.dtype}"
|
||||
)
|
||||
assert w2.dtype == torch.uint8, f"w2 must be uint8, got {w2.dtype}"
|
||||
assert a2_global_scale.dtype == torch.float32, (
|
||||
f"a2_global_scale must be float32, got {a2_global_scale.dtype}"
|
||||
)
|
||||
assert w2_blockscale.dtype == torch.float8_e4m3fn, (
|
||||
f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}"
|
||||
)
|
||||
assert w2_alpha.dtype == torch.float32, (
|
||||
f"w2_alpha must be float32, got {w2_alpha.dtype}"
|
||||
)
|
||||
|
||||
# === Assertions on shapes ===
|
||||
n = w2.shape[-1] * 2 # intermediate dimension
|
||||
if isinstance(hidden_states, tuple):
|
||||
assert input_global_scale is None, (
|
||||
"input_global_scale is needed when input needs quant"
|
||||
)
|
||||
|
||||
aq = hidden_states[0].view(torch.uint8)
|
||||
aq_sf = hidden_states[1].view(torch.float8_e4m3fn)
|
||||
# m, k_by_2, num_experts = aq.shape
|
||||
num_experts, m, k_by_2 = aq.shape
|
||||
k = k_by_2 * 2
|
||||
aq = aq.permute(1, 2, 0)
|
||||
else:
|
||||
num_experts, m, k = hidden_states.shape
|
||||
|
||||
assert input_global_scale.dtype == torch.float32, (
|
||||
f"input_global_scale must be float32, got {input_global_scale.dtype}"
|
||||
)
|
||||
assert input_global_scale.shape == (num_experts,), (
|
||||
f"input_global_scale must be (l,), got {input_global_scale.shape}"
|
||||
)
|
||||
|
||||
aq, aq_sf = scaled_fp4_grouped_quantize(
|
||||
hidden_states,
|
||||
masked_m,
|
||||
input_global_scale,
|
||||
)
|
||||
|
||||
assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
|
||||
assert w1.shape[-1] * 2 == k, (
|
||||
f"w1 last dim * 2 must equal k, got {w1.shape[-1]} vs k={k}"
|
||||
)
|
||||
assert w2.shape[-2:] == (
|
||||
k,
|
||||
n // 2,
|
||||
), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n // 2)}"
|
||||
|
||||
assert w1_alpha.shape == (num_experts,), (
|
||||
f"w1_alpha must be (l,), got {w1_alpha.shape}"
|
||||
)
|
||||
assert a2_global_scale.shape == (num_experts,), (
|
||||
f"a2_global_scale must be (l,), got {a2_global_scale.shape}"
|
||||
)
|
||||
assert w2_alpha.shape == (num_experts,), (
|
||||
f"w2_alpha must be (l,), got {w2_alpha.shape}"
|
||||
)
|
||||
|
||||
workspace = workspace.permute(1, 2, 0) # requirement of kernel
|
||||
sf_vec_size = 16
|
||||
assert aq_sf.dtype == torch.float8_e4m3fn
|
||||
assert aq.dtype == torch.uint8
|
||||
ab_dtype = "float4_e2m1fn"
|
||||
sf_dtype = "float8_e4m3fn"
|
||||
|
||||
if isinstance(hidden_states, tuple):
|
||||
c_dtype = "bfloat16"
|
||||
else:
|
||||
c_dtype = get_cute_dtype(hidden_states)
|
||||
|
||||
# Gemm1
|
||||
flashinfer_cutedsl_grouped_gemm_nt_masked(
|
||||
(aq, aq_sf),
|
||||
(w1.permute(1, 2, 0), w1_blockscale),
|
||||
workspace,
|
||||
masked_m,
|
||||
ab_dtype=ab_dtype,
|
||||
sf_dtype=sf_dtype,
|
||||
c_dtype=c_dtype,
|
||||
sf_vec_size=sf_vec_size,
|
||||
alpha=w1_alpha.view(1, 1, num_experts),
|
||||
alpha_dtype=get_cute_dtype(w1_alpha),
|
||||
) # in logical [m, n, l]
|
||||
|
||||
# SILU and quantization
|
||||
diq, diq_sf = silu_and_mul_scaled_nvfp4_experts_quantize(
|
||||
workspace.permute(2, 0, 1),
|
||||
masked_m,
|
||||
a2_global_scale,
|
||||
)
|
||||
|
||||
# Gemm2
|
||||
out = out.permute(1, 2, 0) # requirement of kernel
|
||||
flashinfer_cutedsl_grouped_gemm_nt_masked(
|
||||
(diq, diq_sf),
|
||||
(w2.permute(1, 2, 0), w2_blockscale),
|
||||
out,
|
||||
masked_m,
|
||||
ab_dtype=ab_dtype,
|
||||
sf_dtype=sf_dtype,
|
||||
c_dtype=c_dtype,
|
||||
sf_vec_size=sf_vec_size,
|
||||
alpha=w2_alpha.view(1, 1, num_experts),
|
||||
alpha_dtype=get_cute_dtype(w2_alpha),
|
||||
) # in logical [m, k, l]
|
||||
out = out.permute(2, 0, 1)
|
||||
with autotune(_is_fi_autotuning):
|
||||
flashinfer_cute_dsl_fused_moe_nvfp4(
|
||||
x=hidden_states,
|
||||
x_sf=x_sf,
|
||||
token_selected_experts=topk_ids.to(torch.int32),
|
||||
token_final_scales=topk_weights.float(),
|
||||
w1_weight=w1,
|
||||
w1_weight_sf=self.w1_scale,
|
||||
w1_alpha=self.g1_alphas,
|
||||
fc2_input_scale=self.a2_gscale,
|
||||
w2_weight=w2,
|
||||
w2_weight_sf=self.w2_scale,
|
||||
w2_alpha=self.g2_alphas,
|
||||
num_experts=self.global_num_experts,
|
||||
top_k=self.topk,
|
||||
num_local_experts=self.local_num_experts,
|
||||
local_expert_offset=self.local_expert_offset,
|
||||
moe_output=output,
|
||||
)
|
||||
|
||||
@@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
||||
prepare_nvfp4_moe_layer_for_fi_or_cutlass,
|
||||
prepare_nvfp4_moe_layer_for_flashinfer_cutedsl,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
FlashinferMoeBackend,
|
||||
@@ -41,6 +42,7 @@ class NvFp4MoeBackend(Enum):
|
||||
FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM"
|
||||
FLASHINFER_CUTLASS = "FLASHINFER_CUTLASS"
|
||||
FLASHINFER_CUTEDSL = "FLASHINFER_CUTEDSL"
|
||||
FLASHINFER_CUTEDSL_BATCHED = "FLASHINFER_CUTEDSL_BATCHED"
|
||||
VLLM_CUTLASS = "VLLM_CUTLASS"
|
||||
MARLIN = "MARLIN"
|
||||
|
||||
@@ -49,6 +51,7 @@ FLASHINFER_NVFP4_MOE_BACKENDS = [
|
||||
NvFp4MoeBackend.FLASHINFER_TRTLLM,
|
||||
NvFp4MoeBackend.FLASHINFER_CUTLASS,
|
||||
NvFp4MoeBackend.FLASHINFER_CUTEDSL,
|
||||
NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED,
|
||||
]
|
||||
|
||||
fi_2_vllm_backend_map: dict[FlashinferMoeBackend, NvFp4MoeBackend] = {
|
||||
@@ -95,6 +98,13 @@ def backend_to_kernel_cls(
|
||||
|
||||
return [FlashInferCuteDSLExperts]
|
||||
|
||||
elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED:
|
||||
from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_batched_moe import ( # noqa: E501
|
||||
FlashInferCuteDSLBatchedExperts,
|
||||
)
|
||||
|
||||
return [FlashInferCuteDSLBatchedExperts]
|
||||
|
||||
elif backend == NvFp4MoeBackend.VLLM_CUTLASS:
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
CutlassExpertsFp4,
|
||||
@@ -143,6 +153,7 @@ def select_nvfp4_moe_backend(
|
||||
AVAILABLE_BACKENDS = [
|
||||
NvFp4MoeBackend.FLASHINFER_TRTLLM,
|
||||
NvFp4MoeBackend.FLASHINFER_CUTEDSL,
|
||||
NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED,
|
||||
NvFp4MoeBackend.FLASHINFER_CUTLASS,
|
||||
NvFp4MoeBackend.VLLM_CUTLASS,
|
||||
NvFp4MoeBackend.MARLIN,
|
||||
@@ -198,6 +209,12 @@ def select_nvfp4_moe_backend(
|
||||
runner_backend = config.moe_backend
|
||||
if runner_backend != "auto":
|
||||
requested_backend = map_nvfp4_backend(runner_backend)
|
||||
# For batched activation format, use batched variant if available.
|
||||
if (
|
||||
activation_format == mk.FusedMoEActivationFormat.BatchedExperts
|
||||
and requested_backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL
|
||||
):
|
||||
requested_backend = NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED
|
||||
return _return_or_raise(
|
||||
requested_backend, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
@@ -288,7 +305,28 @@ def convert_to_nvfp4_moe_kernel_format(
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
]:
|
||||
if (
|
||||
if nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL:
|
||||
(
|
||||
w13,
|
||||
w13_scale,
|
||||
w13_scale_2,
|
||||
a13_scale,
|
||||
w2,
|
||||
w2_scale,
|
||||
w2_scale_2,
|
||||
a2_scale,
|
||||
) = prepare_nvfp4_moe_layer_for_flashinfer_cutedsl(
|
||||
layer=layer,
|
||||
w13=w13,
|
||||
w13_scale=w13_scale,
|
||||
w13_scale_2=w13_scale_2,
|
||||
a13_scale=a13_scale,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
w2_scale_2=w2_scale_2,
|
||||
a2_scale=a2_scale,
|
||||
)
|
||||
elif (
|
||||
nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS
|
||||
or nvfp4_backend == NvFp4MoeBackend.VLLM_CUTLASS
|
||||
):
|
||||
@@ -380,7 +418,13 @@ def make_nvfp4_moe_quant_config(
|
||||
# NOTE(rob): this is a hack until the MoE kernels
|
||||
# create their own quant configs. TRTLLM kernel
|
||||
# does not accept swizzled input quant scales.
|
||||
is_nvfp4_scale_swizzled=(backend != NvFp4MoeBackend.FLASHINFER_TRTLLM),
|
||||
is_nvfp4_scale_swizzled=(
|
||||
backend
|
||||
not in (
|
||||
NvFp4MoeBackend.FLASHINFER_TRTLLM,
|
||||
NvFp4MoeBackend.FLASHINFER_CUTEDSL,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -60,6 +60,100 @@ def reorder_w1w3_to_w3w1(
|
||||
)
|
||||
|
||||
|
||||
def interleave_linear_and_gate(
|
||||
x: torch.Tensor,
|
||||
group_size: int = 64,
|
||||
dim: int = -1,
|
||||
) -> torch.Tensor:
|
||||
"""Interleave gate and linear weight rows for CuteDSL wrapper."""
|
||||
sizes = x.size()
|
||||
dim = dim % x.dim()
|
||||
assert sizes[dim] % (group_size * 2) == 0, (
|
||||
f"dim {dim} size {sizes[dim]} must be divisible by {group_size * 2}"
|
||||
)
|
||||
prev_sizes = sizes[:dim]
|
||||
post_sizes = sizes[dim + 1 :]
|
||||
x = x.view(*prev_sizes, 2, sizes[dim] // (group_size * 2), group_size, *post_sizes)
|
||||
x = x.transpose(dim, dim + 1).contiguous().view(*sizes)
|
||||
return x
|
||||
|
||||
|
||||
def _nvfp4_scale_to_mma_layout(scale: torch.Tensor) -> torch.Tensor:
    """Convert a per-expert blockscale tensor: linear -> swizzled -> MMA layout.

    ``scale`` has shape ``(E, M, K_sf)`` after swizzling; the returned tensor
    is in the MMA layout expected by the CuteDSL kernels.
    """
    from flashinfer.cute_dsl.utils import convert_sf_to_mma_layout

    scale = swizzle_blockscale(scale)
    num_experts, m_padded, k_sf_padded = scale.shape
    return convert_sf_to_mma_layout(
        scale.reshape(num_experts * m_padded, k_sf_padded),
        m=m_padded,
        # Each scale-factor column covers sf_vec_size (=16) elements of K.
        k=k_sf_padded * 16,
        num_groups=num_experts,
        sf_vec_size=16,
    )


def prepare_nvfp4_moe_layer_for_flashinfer_cutedsl(
    layer: "FusedMoE",
    w13: torch.Tensor,
    w13_scale: torch.Tensor,
    w13_scale_2: torch.Tensor,
    a13_scale: torch.Tensor,
    w2: torch.Tensor,
    w2_scale: torch.Tensor,
    w2_scale_2: torch.Tensor,
    a2_scale: torch.Tensor,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Prepare weights for the CuteDSL wrapper-based NvFP4 MoE backend.

    Swaps the two halves of ``w13`` (and its scales) along dim 1, interleaves
    the gate/linear rows, and converts both weight scale-factor tensors to the
    MMA layout expected by CuteDslMoEWrapper. Activation scales are collapsed
    to a single per-layer max replicated across experts.

    NOTE(review): ``layer`` is not read here; it appears to be kept for
    signature parity with the sibling ``prepare_*`` helpers — confirm.
    """
    # Global scaling factors (same as other FlashInfer backends).
    num_experts = w13.shape[0]
    a13_scale = a13_scale.max().to(torch.float32).expand(num_experts)
    a2_scale = a2_scale.max().to(torch.float32).expand(num_experts)

    # Swap the two contiguous halves of w13 and its scales along dim 1.
    half = w13.shape[1] // 2
    w13 = torch.cat([w13[:, half:], w13[:, :half]], dim=1)
    w13_scale = torch.cat([w13_scale[:, half:], w13_scale[:, :half]], dim=1)

    # Interleave up/gate rows for w13 weights and scales.
    w13 = interleave_linear_and_gate(w13, group_size=64, dim=1)
    w13_scale = interleave_linear_and_gate(w13_scale, group_size=64, dim=1)

    # Both weight scales need the same linear -> swizzled -> MMA conversion;
    # the duplicated sequence is factored into _nvfp4_scale_to_mma_layout.
    w13_scale = _nvfp4_scale_to_mma_layout(w13_scale)
    w2_scale = _nvfp4_scale_to_mma_layout(w2_scale)

    return (
        w13,
        w13_scale,
        w13_scale_2,
        a13_scale,
        w2,
        w2_scale,
        w2_scale_2,
        a2_scale,
    )
|
||||
|
||||
|
||||
def prepare_static_weights_for_trtllm_fp4_moe(
|
||||
# args_dequant,
|
||||
# args,
|
||||
@@ -221,7 +315,7 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
|
||||
NvFp4MoeBackend.VLLM_CUTLASS,
|
||||
NvFp4MoeBackend.FLASHINFER_CUTLASS,
|
||||
NvFp4MoeBackend.FLASHINFER_TRTLLM,
|
||||
NvFp4MoeBackend.FLASHINFER_CUTEDSL,
|
||||
NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED,
|
||||
]
|
||||
|
||||
# Reorder [w1, w3] to [w3, w1] for FI NVFP4 MoE kernels.
|
||||
|
||||
@@ -128,6 +128,12 @@ scaled_fp4_grouped_quantize = _lazy_import_wrapper(
|
||||
# FlashInfer entry points resolved via _lazy_import_wrapper — presumably the
# wrapper defers the actual import until first call (per its name), so this
# module stays importable without FlashInfer installed; confirm against the
# wrapper's definition.
nvfp4_block_scale_interleave = _lazy_import_wrapper(
    "flashinfer.fp4_quantization", "block_scale_interleave"
)
# CuteDSL fused MoE kernel for NVFP4 (used by the CuteDSL batched backend).
flashinfer_cute_dsl_fused_moe_nvfp4 = _lazy_import_wrapper(
    "flashinfer", "cute_dsl_fused_moe_nvfp4"
)
# Converts swizzled NVFP4 scale factors into the MMA layout CuteDSL expects.
flashinfer_convert_sf_to_mma_layout = _lazy_import_wrapper(
    "flashinfer.cute_dsl.utils", "convert_sf_to_mma_layout"
)
trtllm_fp4_block_scale_moe = _lazy_import_wrapper(
    "flashinfer", "trtllm_fp4_block_scale_moe"
)
|
||||
@@ -252,6 +258,15 @@ def has_flashinfer_cutedsl_grouped_gemm_nt_masked() -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@functools.cache
def has_flashinfer_cutedsl_moe_nvfp4() -> bool:
    """Return ``True`` if FlashInfer cute_dsl_fused_moe_nvfp4 is available."""
    if has_flashinfer_cutedsl():
        flashinfer_mod = _get_submodule("flashinfer")
        return flashinfer_mod is not None and hasattr(
            flashinfer_mod, "cute_dsl_fused_moe_nvfp4"
        )
    return False
|
||||
|
||||
|
||||
@functools.cache
|
||||
def has_nvidia_artifactory() -> bool:
|
||||
"""Return `True` if NVIDIA's artifactory is accessible.
|
||||
@@ -768,6 +783,8 @@ __all__ = [
|
||||
"silu_and_mul_scaled_nvfp4_experts_quantize",
|
||||
"scaled_fp4_grouped_quantize",
|
||||
"nvfp4_block_scale_interleave",
|
||||
"flashinfer_cute_dsl_fused_moe_nvfp4",
|
||||
"flashinfer_convert_sf_to_mma_layout",
|
||||
"trtllm_fp4_block_scale_moe",
|
||||
"autotune",
|
||||
"has_flashinfer_moe",
|
||||
@@ -776,6 +793,7 @@ __all__ = [
|
||||
"has_flashinfer_nvlink_one_sided",
|
||||
"has_flashinfer_cutlass_fused_moe",
|
||||
"has_flashinfer_cutedsl_grouped_gemm_nt_masked",
|
||||
"has_flashinfer_cutedsl_moe_nvfp4",
|
||||
"has_flashinfer_fp8_blockscale_gemm",
|
||||
"has_nvidia_artifactory",
|
||||
"supports_trtllm_attention",
|
||||
|
||||
Reference in New Issue
Block a user