Major refactor to eliminate all post-load hacks: - deepseek_v4.py: use upstream model with NVFP4 weight mapper only (gate_proj→w1, up_proj→w3, down_proj→w2, .self_attn→.attn, .mlp→.ffn) - Add CuTeDSLMoEExperts as a FusedMoEExpertsModular subclass that wraps our CuTeDSL runner as a proper vLLM MoE backend - Register CUTEDSL backend in the NVFP4 oracle - Use ModelOptNvFp4Config for quantization dispatch (not DeepseekV4FP8Config) - ModelOptNvFp4LinearMethod handles NVFP4 attention/shared expert projections - Remove nvfp4_cutedsl.py, cutedsl_quant_method.py, utils.py from Dockerfile - CuTeDSL runner moved to cutedsl/runner.py for clean imports - cos_sin_cache float32 fix in deepseek_v4_attention.py No more monkey-patching, no _convert_nvfp4_post_load, no CuTeDSLNvfp4Method.
536 lines
18 KiB
Python
536 lines
18 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from enum import Enum
|
|
|
|
import torch
|
|
|
|
import vllm.envs as envs
|
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
|
from vllm.config.kernel import MoEBackend
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
|
maybe_make_prepare_finalize,
|
|
)
|
|
from vllm.model_executor.layers.fused_moe.config import (
|
|
FusedMoEConfig,
|
|
FusedMoEQuantConfig,
|
|
nvfp4_moe_quant_config,
|
|
nvfp4_w4a16_moe_quant_config,
|
|
)
|
|
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
|
prepare_nvfp4_moe_layer_for_fi_or_cutlass,
|
|
prepare_nvfp4_moe_layer_for_flashinfer_cutedsl,
|
|
)
|
|
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
|
FlashinferMoeBackend,
|
|
get_flashinfer_moe_backend,
|
|
)
|
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
|
prepare_nvfp4_moe_layer_for_marlin,
|
|
)
|
|
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (
|
|
kE2M1ToFloat_handle,
|
|
)
|
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|
QuantKey,
|
|
)
|
|
|
|
# Module-level logger obtained from vLLM's init_logger helper.
logger = init_logger(__name__)
|
|
|
|
|
|
class NvFp4MoeBackend(Enum):
    """Kernel backends able to run an NvFP4-quantized fused MoE layer.

    Every member's value is identical to its name. Member order is part of
    the public surface (iteration order), so keep it stable.
    """

    # Backends whose experts classes wrap FlashInfer kernels.
    FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM"
    FLASHINFER_CUTLASS = "FLASHINFER_CUTLASS"
    FLASHINFER_CUTEDSL = "FLASHINFER_CUTEDSL"
    FLASHINFER_CUTEDSL_BATCHED = "FLASHINFER_CUTEDSL_BATCHED"

    # Other kernel backends.
    VLLM_CUTLASS = "VLLM_CUTLASS"
    MARLIN = "MARLIN"
    CUTEDSL = "CUTEDSL"

    # Triton-based quantization emulation path.
    EMULATION = "EMULATION"
|
|
|
|
|
|
# Backends implemented on top of FlashInfer kernels. select_nvfp4_moe_backend
# uses this list to honour VLLM_USE_FLASHINFER_MOE_FP4 (restricting the search
# to these backends, or removing them from it).
FLASHINFER_NVFP4_MOE_BACKENDS: list[NvFp4MoeBackend] = [
    NvFp4MoeBackend.FLASHINFER_TRTLLM,
    NvFp4MoeBackend.FLASHINFER_CUTLASS,
    NvFp4MoeBackend.FLASHINFER_CUTEDSL,
    NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED,
]

# Backends served by the CuTeDSL experts implementation (CuTeDSLMoEExperts).
CUTEDSL_NVFP4_MOE_BACKENDS: list[NvFp4MoeBackend] = [NvFp4MoeBackend.CUTEDSL]

# Translates FlashInfer's backend enum (as reported by
# get_flashinfer_moe_backend()) into NvFp4MoeBackend. It is consulted when
# VLLM_FLASHINFER_MOE_BACKEND is set. There is no FlashInfer-side key for the
# batched CuTeDSL variant.
fi_2_vllm_backend_map: dict[FlashinferMoeBackend, NvFp4MoeBackend] = {
    FlashinferMoeBackend.CUTLASS: NvFp4MoeBackend.FLASHINFER_CUTLASS,
    FlashinferMoeBackend.TENSORRT_LLM: NvFp4MoeBackend.FLASHINFER_TRTLLM,
    FlashinferMoeBackend.CUTEDSL: NvFp4MoeBackend.FLASHINFER_CUTEDSL,
}
|
|
|
|
|
|
def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool:
    """Report whether ``backend`` can quantize with global scaling factors.

    "Global" means scaling factors that cover every expert in Expert Parallel
    mode, including experts that are not resident on the current rank. All
    FlashInfer backends and the CuTeDSL backend qualify.
    """
    global_sf_backends = (
        *FLASHINFER_NVFP4_MOE_BACKENDS,
        *CUTEDSL_NVFP4_MOE_BACKENDS,
    )
    return backend in global_sf_backends
|
|
|
|
|
|
def backend_to_kernel_cls(
|
|
backend: NvFp4MoeBackend,
|
|
) -> list[type[mk.FusedMoEExperts]]:
|
|
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
|
|
from vllm.model_executor.layers.fused_moe.experts.trtllm_nvfp4_moe import (
|
|
TrtLlmNvFp4ExpertsModular,
|
|
TrtLlmNvFp4ExpertsMonolithic,
|
|
)
|
|
|
|
# NOTE: prefer Monolthic > Modular, so return Monolithic first.
|
|
return [
|
|
TrtLlmNvFp4ExpertsMonolithic,
|
|
TrtLlmNvFp4ExpertsModular,
|
|
]
|
|
|
|
elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
|
|
from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutlass_moe import ( # noqa: E501
|
|
FlashInferExperts,
|
|
)
|
|
|
|
return [FlashInferExperts]
|
|
|
|
elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL:
|
|
from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_moe import ( # noqa: E501
|
|
FlashInferCuteDSLExperts,
|
|
)
|
|
|
|
return [FlashInferCuteDSLExperts]
|
|
|
|
elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED:
|
|
from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_batched_moe import ( # noqa: E501
|
|
FlashInferCuteDSLBatchedExperts,
|
|
)
|
|
|
|
return [FlashInferCuteDSLBatchedExperts]
|
|
|
|
elif backend == NvFp4MoeBackend.CUTEDSL:
|
|
from vllm.model_executor.layers.fused_moe.experts.cutedsl_moe import ( # noqa: E501
|
|
CuTeDSLMoEExperts,
|
|
)
|
|
|
|
return [CuTeDSLMoEExperts]
|
|
|
|
elif backend == NvFp4MoeBackend.VLLM_CUTLASS:
|
|
from vllm.model_executor.layers.fused_moe.experts.cutlass_moe import (
|
|
CutlassExpertsFp4,
|
|
)
|
|
|
|
return [CutlassExpertsFp4]
|
|
|
|
elif backend == NvFp4MoeBackend.MARLIN:
|
|
from vllm.model_executor.layers.fused_moe.experts.marlin_moe import (
|
|
MarlinExperts,
|
|
)
|
|
|
|
return [MarlinExperts]
|
|
elif backend == NvFp4MoeBackend.EMULATION:
|
|
from vllm.model_executor.layers.fused_moe.experts.nvfp4_emulation_moe import (
|
|
Nvfp4QuantizationEmulationTritonExperts,
|
|
)
|
|
|
|
return [Nvfp4QuantizationEmulationTritonExperts]
|
|
else:
|
|
raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}")
|
|
|
|
|
|
def map_nvfp4_backend(runner_backend: MoEBackend) -> NvFp4MoeBackend:
    """Map user's MoEBackend to NvFp4MoeBackend.

    Raises:
        ValueError: if ``runner_backend`` has no NvFP4 MoE counterpart.
    """
    name_to_backend = {
        "cutlass": NvFp4MoeBackend.VLLM_CUTLASS,
        "flashinfer_trtllm": NvFp4MoeBackend.FLASHINFER_TRTLLM,
        "flashinfer_cutlass": NvFp4MoeBackend.FLASHINFER_CUTLASS,
        "flashinfer_cutedsl": NvFp4MoeBackend.FLASHINFER_CUTEDSL,
        "cutedsl": NvFp4MoeBackend.CUTEDSL,
        "marlin": NvFp4MoeBackend.MARLIN,
        "emulation": NvFp4MoeBackend.EMULATION,
    }
    if runner_backend in name_to_backend:
        return name_to_backend[runner_backend]

    raise ValueError(
        f"moe_backend='{runner_backend}' is not supported for NvFP4 MoE. "
        f"Expected one of {list(name_to_backend.keys())}."
    )
|
|
|
|
|
|
def select_nvfp4_moe_backend(
    config: FusedMoEConfig,
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]:
    """
    Select the primary NvFP4 MoE backend and the experts class implementing it.

    Resolution order:
      1. An explicit ``config.moe_backend`` (anything other than "auto") is
         validated and used; if unsupported, ValueError is raised.
      2. If VLLM_USE_FLASHINFER_MOE_FP4 is set:
           * to a false value, FlashInfer backends are removed from the
             automatic search in step 4;
           * to a true value with VLLM_FLASHINFER_MOE_BACKEND set, that
             FlashInfer backend is validated and used;
           * otherwise the first FlashInfer backend supporting the config is
             used, or NotImplementedError is raised.
      3. VLLM_TEST_FORCE_FP8_MARLIN forces (and validates) the Marlin backend.
      4. Otherwise, the first backend in priority order supporting the config.

    For the batched activation format, an explicitly requested FlashInfer
    CuTeDSL backend (steps 1 and 2) is promoted to its batched variant.

    Note: Shape-specific fallbacks may still occur at runtime.

    Raises:
        ValueError: an explicitly requested backend does not support the config.
        NotImplementedError: no candidate backend supports the config.
    """

    # NOTE: the kernels are selected in the following order.
    AVAILABLE_BACKENDS = [
        NvFp4MoeBackend.FLASHINFER_TRTLLM,
        NvFp4MoeBackend.FLASHINFER_CUTEDSL,
        NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED,
        NvFp4MoeBackend.CUTEDSL,
        NvFp4MoeBackend.FLASHINFER_CUTLASS,
        NvFp4MoeBackend.VLLM_CUTLASS,
        NvFp4MoeBackend.MARLIN,
        NvFp4MoeBackend.EMULATION,
    ]

    use_batched = config.moe_parallel_config.use_batched_activation_format
    activation_format = (
        mk.FusedMoEActivationFormat.BatchedExperts
        if use_batched
        else mk.FusedMoEActivationFormat.Standard
    )

    def _make_log_backend(backend: NvFp4MoeBackend) -> str:
        # Reads AVAILABLE_BACKENDS at call time, so backends removed below
        # are not listed.
        available_backend_strs = [b.value for b in AVAILABLE_BACKENDS]
        return (
            f"Using '{backend.value}' NvFp4 MoE backend out "
            f"of potential backends: {available_backend_strs}."
        )

    def _make_log_unsupported(backend: NvFp4MoeBackend, reason: str | None) -> str:
        if reason:
            return (
                f"NvFp4 MoE backend '{backend.value}' does not support the "
                f"deployment configuration since {reason}."
            )
        return (
            f"NvFp4 MoE backend '{backend.value}' does not support the "
            "deployment configuration."
        )

    def _with_batched_variant(backend: NvFp4MoeBackend) -> NvFp4MoeBackend:
        # FlashInfer CuTeDSL has a separate experts class for the batched
        # activation format; use it whenever that format is active.
        if (
            activation_format == mk.FusedMoEActivationFormat.BatchedExperts
            and backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL
        ):
            return NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED
        return backend

    def _first_supported(
        backends: list[NvFp4MoeBackend],
    ) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]] | None:
        # Return the first (backend, experts class) accepting the config, in
        # the given order, or None if none does. Rejections are debug-logged.
        for backend in backends:
            for k_cls in backend_to_kernel_cls(backend):
                supported, reason = k_cls.is_supported_config(
                    k_cls, config, weight_key, activation_key, activation_format
                )
                if supported:
                    logger.info_once(_make_log_backend(backend))
                    return backend, k_cls
                logger.debug_once(_make_log_unsupported(backend, reason))
        return None

    def _return_or_raise(
        backend: NvFp4MoeBackend,
    ) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]:
        # Validate a backend that was requested explicitly. There is no
        # fallback, so an unsupported configuration is an error.
        reason: str | None = None
        for k_cls in backend_to_kernel_cls(backend):
            supported, reason = k_cls.is_supported_config(
                k_cls, config, weight_key, activation_key, activation_format
            )
            if supported:
                logger.info_once(_make_log_backend(backend))
                return backend, k_cls
        raise ValueError(_make_log_unsupported(backend, reason))

    # Handle explicit moe_backend from user.
    runner_backend = config.moe_backend
    if runner_backend != "auto":
        requested_backend = _with_batched_variant(map_nvfp4_backend(runner_backend))
        return _return_or_raise(requested_backend)

    if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP4"):
        if not envs.VLLM_USE_FLASHINFER_MOE_FP4:
            # If the user rejects FlashInfer remove those backends.
            for b in FLASHINFER_NVFP4_MOE_BACKENDS:
                AVAILABLE_BACKENDS.remove(b)

        elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
            # If user is explicit about backend, validate it. As with an
            # explicit moe_backend, CuTeDSL maps to its batched variant when
            # the batched activation format is in use.
            fi_backend = fi_2_vllm_backend_map[get_flashinfer_moe_backend()]
            return _return_or_raise(_with_batched_variant(fi_backend))

        else:
            # If the user is not explicit about the backend, try each.
            selection = _first_supported(FLASHINFER_NVFP4_MOE_BACKENDS)
            if selection is not None:
                return selection
            raise NotImplementedError(
                "Found VLLM_USE_FLASHINFER_MOE_FP4=1, but no "
                "FlashInfer NVFP4 MoE backend supports the configuration."
            )

    if envs.VLLM_TEST_FORCE_FP8_MARLIN:
        return _return_or_raise(NvFp4MoeBackend.MARLIN)

    # Select kernels in order of backend.
    selection = _first_supported(AVAILABLE_BACKENDS)
    if selection is not None:
        return selection

    raise NotImplementedError(
        "No NvFp4 MoE backend supports the deployment configuration."
    )
|
|
|
|
|
|
def convert_to_nvfp4_moe_kernel_format(
    nvfp4_backend: NvFp4MoeBackend,
    layer: torch.nn.Module,
    w13: torch.Tensor,
    w13_scale: torch.Tensor,
    w13_scale_2: torch.Tensor,
    a13_scale: torch.Tensor | None,
    w2: torch.Tensor,
    w2_scale: torch.Tensor,
    w2_scale_2: torch.Tensor,
    a2_scale: torch.Tensor | None,
    is_act_and_mul: bool,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor | None,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor | None,
]:
    """Convert NvFP4 MoE weights and scales into the layout ``nvfp4_backend`` uses.

    Returns ``(w13, w13_scale, w13_scale_2, a13_scale, w2, w2_scale,
    w2_scale_2, a2_scale)``, in the same order as the arguments.

    The returned activation scales follow the per-backend convention that
    ``make_nvfp4_moe_quant_config`` expects:
      * MARLIN: ``None`` (the w4a16 path has no activation quantization).
      * EMULATION: reduced to a single value and already inverted; the quant
        config uses them unchanged.
      * every other backend (including CUTEDSL): the quant config computes
        ``1.0 / a*_scale`` from the values returned here, so they must not be
        pre-inverted.

    Raises:
        ValueError: EMULATION without activation scales, or an unknown backend.
    """
    if nvfp4_backend == NvFp4MoeBackend.CUTEDSL:
        # The CuTeDSL experts convert weights in their own
        # process_weights_after_loading, so everything passes through
        # unchanged. Activation scales are deliberately NOT inverted here:
        # make_nvfp4_moe_quant_config already computes 1.0 / a*_scale for this
        # backend, and inverting here as well would cancel that out.
        pass

    elif nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL:
        (
            w13,
            w13_scale,
            w13_scale_2,
            a13_scale,
            w2,
            w2_scale,
            w2_scale_2,
            a2_scale,
        ) = prepare_nvfp4_moe_layer_for_flashinfer_cutedsl(
            layer=layer,
            w13=w13,
            w13_scale=w13_scale,
            w13_scale_2=w13_scale_2,
            a13_scale=a13_scale,
            w2=w2,
            w2_scale=w2_scale,
            w2_scale_2=w2_scale_2,
            a2_scale=a2_scale,
        )

    elif (
        nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS
        or nvfp4_backend == NvFp4MoeBackend.VLLM_CUTLASS
    ):
        (
            w13,
            w13_scale,
            w13_scale_2,
            a13_scale,
            w2,
            w2_scale,
            w2_scale_2,
            a2_scale,
        ) = prepare_nvfp4_moe_layer_for_fi_or_cutlass(
            backend=nvfp4_backend,
            layer=layer,
            w13=w13,
            w13_scale=w13_scale,
            w13_scale_2=w13_scale_2,
            a13_scale=a13_scale,
            w2=w2,
            w2_scale=w2_scale,
            w2_scale_2=w2_scale_2,
            a2_scale=a2_scale,
            is_act_and_mul=is_act_and_mul,
        )

    elif nvfp4_backend == NvFp4MoeBackend.MARLIN:
        # Marlin runs weight-only (w4a16), so activation scales are dropped.
        a13_scale = None
        a2_scale = None
        (
            w13,
            w13_scale,
            w13_scale_2,
            w2,
            w2_scale,
            w2_scale_2,
        ) = prepare_nvfp4_moe_layer_for_marlin(
            layer=layer,
            w13=w13,
            w13_scale=w13_scale,
            w13_scale_2=w13_scale_2,
            w2=w2,
            w2_scale=w2_scale,
            w2_scale_2=w2_scale_2,
            is_act_and_mul=is_act_and_mul,
        )

    elif nvfp4_backend == NvFp4MoeBackend.EMULATION:
        # Move the E2M1 lookup table to the device now, because
        # `.to(device)` is not allowed during CUDA graph capture.
        kE2M1ToFloat_handle.val = kE2M1ToFloat_handle.val.to(w13.device)

        if a13_scale is None or a2_scale is None:
            raise ValueError(
                "Activation global scales should not be None, got"
                f" a13_scale={a13_scale}, a2_scale={a2_scale}"
            )

        if torch.unique(a13_scale).numel() != 1 or torch.unique(a2_scale).numel() != 1:
            logger.warning_once(
                "In NVFP4 linear, the activation global scale for inputs are different"
                " for MOE w13 (gate_up_proj) layer or MOE w2 (down_proj). Using"
                " a13_scale = a13_scale.max() and a2_scale = a2_scale.max()."
            )

        # 1. We take the max following e.g. quantization/utils/flashinfer_fp4_moe.py.
        # 2. moe_kernel_quantize_input -> ref_nvfp4_quant_dequant
        #    use the inverse scale directly (large global scale).
        # NOTE: Before this point, `a13_scale` and `a2_scale` are such that:
        # `FP8_MAX = activation[expert_id].abs().max() * global_scale[expert_id]`,
        # and `global_scale[expert_id]` are small (~1e-4).
        # Taking the largest global scale likely results in overflowing the FP8 range
        # for other experts - other selection strategies may be used.
        a13_scale = 1.0 / a13_scale.max().to(torch.float32)
        a2_scale = 1.0 / a2_scale.max().to(torch.float32)

    else:
        raise ValueError(f"Unknown NvFp4 backend for MoE: {nvfp4_backend}")

    return (
        w13,
        w13_scale,
        w13_scale_2,
        a13_scale,
        w2,
        w2_scale,
        w2_scale_2,
        a2_scale,
    )
|
|
|
|
|
|
def make_nvfp4_moe_quant_config(
    backend: NvFp4MoeBackend,
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w13_scale_2: torch.Tensor,
    w2_scale_2: torch.Tensor,
    a13_scale: torch.Tensor | None,
    a2_scale: torch.Tensor | None,
) -> FusedMoEQuantConfig:
    """Build the FusedMoEQuantConfig for an NvFP4 MoE layer running on ``backend``.

    Expects scales in the form produced by
    ``convert_to_nvfp4_moe_kernel_format``:
      * MARLIN: weight-only (w4a16) config; the activation scales are ignored
        and may be ``None``.
      * EMULATION: the activation scales are passed through unchanged as the
        activation global scales.
      * any other backend: the activation global scales are
        ``1.0 / a13_scale`` and ``1.0 / a2_scale``.

    Raises:
        ValueError: if a backend on the default path receives ``None``
            activation scales (previously an opaque TypeError from ``1.0 / None``).
    """
    if backend == NvFp4MoeBackend.MARLIN:
        return nvfp4_w4a16_moe_quant_config(
            g1_alphas=w13_scale_2,
            g2_alphas=w2_scale_2,
            w1_scale=w13_scale,
            w2_scale=w2_scale,
        )

    if backend == NvFp4MoeBackend.EMULATION:
        return nvfp4_moe_quant_config(
            g1_alphas=w13_scale_2,
            g2_alphas=w2_scale_2,
            a1_gscale=a13_scale,
            a2_gscale=a2_scale,
            w1_scale=w13_scale,
            w2_scale=w2_scale,
        )

    if a13_scale is None or a2_scale is None:
        raise ValueError(
            f"NvFp4 MoE backend '{backend.value}' requires activation scales, "
            f"got a13_scale={a13_scale}, a2_scale={a2_scale}."
        )

    # NOTE(rob): this is a hack until the MoE kernels
    # create their own quant configs. TRTLLM kernel
    # does not accept swizzled input quant scales; the same
    # applies to the backends listed alongside it here.
    unswizzled_scale_backends = (
        NvFp4MoeBackend.FLASHINFER_TRTLLM,
        NvFp4MoeBackend.FLASHINFER_CUTEDSL,
        NvFp4MoeBackend.CUTEDSL,
    )

    # Pass w13_scale_2 / w2_scale_2 directly as g1/g2_alphas.
    # The expert's process_weights_after_loading will fuse activation
    # scales in-place. Since the quant config references the same tensor
    # as the registered parameter, EPLB rearrangement stays in sync.
    return nvfp4_moe_quant_config(
        g1_alphas=w13_scale_2,
        g2_alphas=w2_scale_2,
        a1_gscale=(1.0 / a13_scale),
        a2_gscale=(1.0 / a2_scale),
        w1_scale=w13_scale,
        w2_scale=w2_scale,
        is_scale_swizzled=backend not in unswizzled_scale_backends,
    )
|
|
|
|
|
|
def make_nvfp4_moe_kernel(
    moe_quant_config: FusedMoEQuantConfig,
    moe_config: FusedMoEConfig,
    experts_cls: type[mk.FusedMoEExperts],
    routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEKernel:
    """Assemble a FusedMoEKernel from a prepare/finalize stage and an experts object.

    The prepare/finalize stage comes from ``maybe_make_prepare_finalize``, in
    its monolithic form when ``experts_cls`` subclasses
    ``mk.FusedMoEExpertsMonolithic``. When that stage uses the batched
    activation format, the experts also receive the per-rank token limit and
    the dispatcher count it reports.
    """
    wants_monolithic = issubclass(experts_cls, mk.FusedMoEExpertsMonolithic)
    prepare_finalize = maybe_make_prepare_finalize(
        moe=moe_config,
        quant_config=moe_quant_config,
        routing_tables=routing_tables,
        allow_new_interface=True,
        use_monolithic=wants_monolithic,
    )
    assert prepare_finalize is not None
    logger.info_once("Using %s", prepare_finalize.__class__.__name__)

    experts_kwargs = {
        "moe_config": moe_config,
        "quant_config": moe_quant_config,
    }
    batched = (
        prepare_finalize.activation_format
        == mk.FusedMoEActivationFormat.BatchedExperts
    )
    if batched:
        token_limit = prepare_finalize.max_num_tokens_per_rank()
        assert token_limit is not None
        experts_kwargs["max_num_tokens"] = token_limit
        experts_kwargs["num_dispatchers"] = prepare_finalize.num_dispatchers()
    experts = experts_cls(**experts_kwargs)

    # TODO(rob): update inplace logic to be part of the kernel.
    return mk.FusedMoEKernel(prepare_finalize, experts, inplace=False)
|