[MoE Refactor][16/N] Apply Refactor to NVFP4 (#31692)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
Robert Shaw
2026-01-07 22:46:27 -05:00
committed by GitHub
parent 8dd2419fa9
commit 9f6dcb71ae
15 changed files with 777 additions and 681 deletions

View File

@@ -72,7 +72,6 @@ if HAS_TRITON:
CutlassBatchedExpertsFp8,
CutlassExpertsFp8,
CutlassExpertsW4A8Fp8,
cutlass_moe_fp4,
cutlass_moe_w4a8_fp8,
)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
@@ -95,7 +94,6 @@ if HAS_TRITON:
"fused_experts",
"get_config_file_name",
"GroupedTopk",
"cutlass_moe_fp4",
"cutlass_moe_w4a8_fp8",
"CutlassExpertsFp8",
"CutlassBatchedExpertsFp8",

View File

@@ -336,6 +336,10 @@ class FusedMoEQuantConfig:
def use_int4_w4a16(self) -> bool:
return self._a1.dtype is None and self._w1.dtype == "int4"
@property
def use_nvfp4_w4a16(self) -> bool:
return self._a1.dtype is None and self._w1.dtype == "nvfp4"
@property
def ocp_mx_scheme(self) -> str | None:
if not hasattr(self, "_ocp_mx_scheme"):
@@ -690,6 +694,25 @@ def nvfp4_moe_quant_config(
)
def nvfp4_w4a16_moe_quant_config(
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
) -> FusedMoEQuantConfig:
"""
Construct a quant config for 16-but activations and nvp4 weights.
"""
return FusedMoEQuantConfig.make(
quant_dtype=None,
w1_scale=w1_scale,
w2_scale=w2_scale,
g1_alphas=g1_alphas,
g2_alphas=g2_alphas,
weight_dtype="nvfp4",
)
def int4_w4a16_moe_quant_config(
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,

View File

@@ -706,68 +706,6 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
)
def cutlass_moe_fp4(
a: torch.Tensor,
w1_fp4: torch.Tensor,
w2_fp4: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_config: FusedMoEQuantConfig,
m: int,
n: int,
k: int,
e: int,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
assert expert_map is None, (
"Expert Parallelism / expert_map "
"is currently not supported for "
"ModelOptNvFp4FusedMoE's cutlass_moe_fp4."
)
# TODO(bnell): this feels a bit hacky
# NVFP4 requires two levels of quantization, which involves
# computing some scaling factors dynamically. This makes it
# incompatible with the typical prepare -> MoE -> finalize
# pipeline. Move the quantization logic into the MoE body.
quant_config = FusedMoEQuantConfig.make(
quant_dtype=None, # skip quantization in prepare/finalize
per_act_token_quant=quant_config.per_act_token_quant,
per_out_ch_quant=quant_config.per_out_ch_quant,
block_shape=quant_config.block_shape,
g1_alphas=quant_config.g1_alphas,
g2_alphas=quant_config.g2_alphas,
a1_gscale=quant_config.a1_gscale,
a2_gscale=quant_config.a2_gscale,
w1_scale=quant_config.w1_scale,
w2_scale=quant_config.w2_scale,
)
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp4(
max_experts_per_worker=e,
out_dtype=a.dtype,
quant_config=quant_config,
use_batched_format=False,
),
)
return fn(
hidden_states=a,
w1=w1_fp4,
w2=w2_fp4,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
activation="silu",
global_num_experts=e,
expert_map=None,
apply_router_weight_on_input=apply_router_weight_on_input,
)
# W4A8
def run_cutlass_moe_w4a8_fp8(
output: torch.Tensor,

View File

@@ -335,42 +335,3 @@ def flashinfer_cutedsl_moe_masked(
alpha_dtype=get_cute_dtype(w2_alpha),
) # in logical [m, k, l]
out = out.permute(2, 0, 1)
def flashinfer_cutedsl_moe_fp4(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_config: FusedMoEQuantConfig,
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
create_flashinfer_prepare_finalize,
)
fused_experts = mk.FusedMoEModularKernel(
create_flashinfer_prepare_finalize(use_dp=False), # could be swapped later
FlashInferCuteDSLExperts(
out_dtype=hidden_states.dtype,
quant_config=quant_config,
),
)
return fused_experts(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)

View File

@@ -355,21 +355,17 @@ def create_flashinfer_prepare_finalize(
use_deepseek_fp8_block_scale: bool = False,
) -> FlashInferCutlassMoEPrepareAndFinalize | MoEPrepareAndFinalizeNoEP:
"""Factory function to create the appropriate FlashInfer implementation."""
# TODO(rob): migrate non-DP cases to MoEPrepareAndFinalizeNoEP
# once we complete the FP8 refactor.
if use_nvfp4:
if enable_alltoallv:
return FlashInferAllToAllMoEPrepareAndFinalize(use_dp)
else:
return FlashInferAllGatherMoEPrepareAndFinalize(use_dp)
# FP8 DP path currently supported via AllGather.
if use_dp:
if enable_alltoallv:
assert use_nvfp4
return FlashInferAllToAllMoEPrepareAndFinalize(use_dp)
return FlashInferAllGatherMoEPrepareAndFinalize(
use_dp=True,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
)
else:
# NOTE(rob): CUTLASS FP8 block quant executes the input
# quantzation and grouped gemm in a single kernel.
return MoEPrepareAndFinalizeNoEP(defer_input_quant=use_deepseek_fp8_block_scale)
# CUTLASS FP8 BLOCK and CUTLASS NVFP4 apply input quantization
# in a single call with the MoE experts kernel.
defer_input_quant = use_deepseek_fp8_block_scale or use_nvfp4
return MoEPrepareAndFinalizeNoEP(defer_input_quant=defer_input_quant)

View File

@@ -540,9 +540,10 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
# TODO (varun) : Enable activation quantization
assert (
quant_config.use_mxfp4_w4a16
or quant_config.use_nvfp4_w4a16
or quant_config.use_int4_w4a16
or quant_config.use_fp8_w8a16
), "Supports only mxfp4_w4a16, int4_w4a16 or fp8_w8a16"
), "Supports only {mxfp,nvfp,int}4_w4a16 or fp8_w8a16"
self.w13_g_idx = w13_g_idx
self.w2_g_idx = w2_g_idx
self.w13_g_idx_sort_indices = w13_g_idx_sort_indices
@@ -555,7 +556,7 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
# uint4b8 will be set for int4 weight and float4_e2m1f will be used for mxfp4
if self.quant_config.use_int4_w4a16:
return scalar_types.uint4b8.id
elif self.quant_config.use_mxfp4_w4a16:
elif self.quant_config.use_mxfp4_w4a16 or self.quant_config.use_nvfp4_w4a16:
return scalar_types.float4_e2m1f.id
elif (
self.quant_config.use_fp8_w8a16
@@ -692,6 +693,8 @@ class MarlinExperts(MarlinExpertsBase):
gating_output=None,
topk_weights=topk_weights,
topk_ids=topk_ids,
global_scale1=self.g1_alphas,
global_scale2=self.g2_alphas,
quant_type_id=self.quant_type_id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,

View File

@@ -38,9 +38,6 @@ from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimula
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
is_flashinfer_supporting_global_sf,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe
from vllm.utils.math_utils import cdiv, round_up
@@ -1125,14 +1122,9 @@ class FusedMoE(CustomOp):
global_expert_id = expert_id
expert_id = self._map_global_expert_id_to_local_expert_id(global_expert_id)
allow_flashinfer = getattr(self.quant_method, "allow_flashinfer", False)
moe_backend = getattr(self.quant_method, "flashinfer_moe_backend", None)
use_global_sf = (
allow_flashinfer
and is_flashinfer_supporting_global_sf(moe_backend)
getattr(self.quant_method, "use_global_sf", False)
and "input_scale" in weight_name
and quant_method_name == "ModelOptNvFp4FusedMoE"
)
if expert_id == -1 and not use_global_sf:

View File

@@ -0,0 +1,280 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
nvfp4_moe_quant_config,
nvfp4_w4a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp4,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
is_flashinfer_fp4_cutedsl_moe_available,
is_flashinfer_fp4_cutlass_moe_available,
prepare_nvfp4_moe_layer_for_fi_or_cutlass,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
get_flashinfer_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
is_fp4_marlin_supported,
prepare_nvfp4_moe_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported,
)
logger = init_logger(__name__)
class NvFp4MoeBackend(Enum):
FLASHINFER_CUTLASS = "FlashInfer CUTLASS"
FLASHINFER_TRTLLM = "FlashInfer TRTLLM"
FLASHINFER_CUTEDSL = "FlashInfer CUTEDSL"
VLLM_CUTLASS = "vLLM CUTASS"
MARLIN = "vLLM MARLIN"
FLASHINFER_NVFP4_MOE_BACKENDS = [
NvFp4MoeBackend.FLASHINFER_CUTLASS,
NvFp4MoeBackend.FLASHINFER_TRTLLM,
NvFp4MoeBackend.FLASHINFER_CUTEDSL,
]
fi_2_vllm_backend_map: dict[FlashinferMoeBackend, NvFp4MoeBackend] = {
FlashinferMoeBackend.CUTLASS: NvFp4MoeBackend.FLASHINFER_CUTLASS,
FlashinferMoeBackend.TENSORRT_LLM: NvFp4MoeBackend.FLASHINFER_TRTLLM,
FlashinferMoeBackend.CUTEDSL: NvFp4MoeBackend.FLASHINFER_CUTEDSL,
}
def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool:
# Checks whether `backend` supports quantizing with scaling factors
# of all experts in Expert Parallel Mode when all experts are not
# on the same rank.
return backend in [
NvFp4MoeBackend.FLASHINFER_CUTLASS,
NvFp4MoeBackend.FLASHINFER_TRTLLM,
]
def select_nvfp4_moe_backend() -> NvFp4MoeBackend:
def _make_log_backend(backend: NvFp4MoeBackend):
return f"Using {backend.value} backend for NvFp4 MoE"
if cutlass_fp4_supported() and not envs.VLLM_TEST_FORCE_FP8_MARLIN:
allow_flashinfer = (
is_flashinfer_fp4_cutlass_moe_available()
or is_flashinfer_fp4_cutedsl_moe_available()
)
if allow_flashinfer and envs.VLLM_USE_FLASHINFER_MOE_FP4:
backend = fi_2_vllm_backend_map[get_flashinfer_moe_backend()]
else:
backend = NvFp4MoeBackend.VLLM_CUTLASS
elif is_fp4_marlin_supported():
backend = NvFp4MoeBackend.MARLIN
else:
raise ValueError("No NvFp4 kernel backend available for NvFp4 MoE.")
# Log warning if FI backend requested but not available.
if (
backend not in FLASHINFER_NVFP4_MOE_BACKENDS
and envs.VLLM_USE_FLASHINFER_MOE_FP4
):
logger.warning_once(
"Requested FlashInfer backend for NvFp4 MoE, but it's not available. "
"Falling back to %s for NvFp4 MoE",
backend.value,
scope="local",
)
else:
logger.info_once(_make_log_backend(backend), scope="local")
return backend
def convert_to_nvfp4_moe_kernel_format(
nvfp4_backend: NvFp4MoeBackend,
layer: torch.nn.Module,
w13: torch.Tensor,
w13_scale: torch.Tensor,
w13_scale_2: torch.Tensor,
a13_scale: torch.Tensor | None,
w2: torch.Tensor,
w2_scale: torch.Tensor,
w2_scale_2: torch.Tensor,
a2_scale: torch.Tensor | None,
is_act_and_mul: bool,
) -> tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
if (
nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS
or nvfp4_backend == NvFp4MoeBackend.VLLM_CUTLASS
):
(
w13,
w13_scale,
w13_scale_2,
a13_scale,
w2,
w2_scale,
w2_scale_2,
a2_scale,
) = prepare_nvfp4_moe_layer_for_fi_or_cutlass(
backend=nvfp4_backend,
layer=layer,
w13=w13,
w13_scale=w13_scale,
w13_scale_2=w13_scale_2,
a13_scale=a13_scale,
w2=w2,
w2_scale=w2_scale,
w2_scale_2=w2_scale_2,
a2_scale=a2_scale,
is_act_and_mul=is_act_and_mul,
)
elif nvfp4_backend == NvFp4MoeBackend.MARLIN:
a13_scale = None
a2_scale = None
(
w13,
w13_scale,
w13_scale_2,
w2,
w2_scale,
w2_scale_2,
) = prepare_nvfp4_moe_layer_for_marlin(
layer=layer,
w13=w13,
w13_scale=w13_scale,
w13_scale_2=w13_scale_2,
w2=w2,
w2_scale=w2_scale,
w2_scale_2=w2_scale_2,
)
else:
raise ValueError(f"Unknown NvFp4 backend for MoE: {nvfp4_backend}")
return (
w13,
w13_scale,
w13_scale_2,
a13_scale,
w2,
w2_scale,
w2_scale_2,
a2_scale,
)
def make_nvfp4_moe_quant_config(
backend: NvFp4MoeBackend,
w13_scale: torch.Tensor,
w2_scale: torch.Tensor,
w13_scale_2: torch.Tensor,
w2_scale_2: torch.Tensor,
a13_scale: torch.Tensor,
a2_scale: torch.Tensor,
) -> FusedMoEQuantConfig | None:
UNSUPPORTED = [NvFp4MoeBackend.FLASHINFER_TRTLLM]
if backend in UNSUPPORTED:
return None
elif backend == NvFp4MoeBackend.MARLIN:
return nvfp4_w4a16_moe_quant_config(
g1_alphas=w13_scale_2,
g2_alphas=w2_scale_2,
w1_scale=w13_scale,
w2_scale=w2_scale,
)
g1_alphas = a13_scale * w13_scale_2
g2_alphas = a2_scale * w2_scale_2
return nvfp4_moe_quant_config(
g1_alphas=g1_alphas,
g2_alphas=g2_alphas,
a1_gscale=(1.0 / a13_scale),
a2_gscale=(1.0 / a2_scale),
w1_scale=w13_scale,
w2_scale=w2_scale,
)
def make_nvfp4_moe_kernel(
backend: NvFp4MoeBackend,
quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
) -> mk.FusedMoEModularKernel | None:
assert moe_config.dp_size == 1
UNSUPPORTED_BACKENDS = [
# TRTLLM does not use the modular kernl abstraction.
NvFp4MoeBackend.FLASHINFER_TRTLLM,
# CUTEDSL is used with BATCHED (masked) format only.
# TODO: add here once we support dp/ep via the oracle.
NvFp4MoeBackend.FLASHINFER_CUTEDSL,
]
if backend in UNSUPPORTED_BACKENDS:
return None
elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
FlashInferExperts(
out_dtype=moe_config.in_dtype,
quant_config=quant_config,
ep_rank=moe_config.ep_rank,
ep_size=moe_config.ep_size,
tp_rank=moe_config.tp_rank,
tp_size=moe_config.tp_size,
use_dp=False,
use_deepseek_fp8_block_scale=False,
),
)
elif backend == NvFp4MoeBackend.VLLM_CUTLASS:
return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
CutlassExpertsFp4(
out_dtype=moe_config.in_dtype,
# TODO(rob): see what impact this has on expert map?
max_experts_per_worker=moe_config.num_experts,
quant_config=quant_config,
),
)
elif backend == NvFp4MoeBackend.MARLIN:
return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
MarlinExperts(quant_config=quant_config),
)
else:
raise ValueError(f"Unknown NvFp4 MoE backend: {backend}")

View File

@@ -11,7 +11,6 @@ from compressed_tensors.quantization import (
QuantizationArgs,
QuantizationStrategy,
)
from torch.nn.parameter import Parameter
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
@@ -34,12 +33,8 @@ from vllm.model_executor.layers.fused_moe.config import (
int4_w4afp8_moe_quant_config,
int8_w8a8_moe_quant_config,
int8_w8a16_moe_quant_config,
nvfp4_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.cpu_fused_moe import select_experts
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
is_valid_flashinfer_cutlass_fused_moe,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
BatchedMarlinExperts,
MarlinExperts,
@@ -51,6 +46,15 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
make_fp8_moe_kernel,
select_fp8_moe_backend,
)
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
FLASHINFER_NVFP4_MOE_BACKENDS,
NvFp4MoeBackend,
convert_to_nvfp4_moe_kernel_format,
is_global_sf_supported_for_nvfp4_backend,
make_nvfp4_moe_kernel,
make_nvfp4_moe_quant_config,
select_nvfp4_moe_backend,
)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
WNA16_SUPPORTED_BITS,
WNA16_SUPPORTED_TYPES_MAP,
@@ -58,14 +62,9 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compress
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
flashinfer_trtllm_fp4_moe,
prepare_static_weights_for_trtllm_fp4_moe,
reorder_w1w3_to_w3w1,
flashinfer_trtllm_fp4_routed_moe,
select_nvfp4_gemm_impl,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
get_flashinfer_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_input_tensor_strategy_moe,
process_fp8_weight_tensor_strategy_moe,
@@ -77,20 +76,15 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new,
marlin_moe_permute_scales,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_fp4_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
convert_bf16_scales_to_fp8,
convert_packed_uint4b8_to_signed_int4_inplace,
swizzle_blockscale,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
normalize_e4m3fn_to_e4m3fnuz,
)
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import CpuArchEnum, current_platform
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
@@ -218,31 +212,19 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
def __init__(self, moe: FusedMoEConfig, layer_name: str | None = None):
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
detect_nvfp4_moe_support,
)
if not moe.is_act_and_mul:
raise ValueError(
"CompressedTensorsW4A4Nvfp4MoEMethod does not yet "
"support non gated MoE models."
)
super().__init__(moe)
_nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin
self.group_size = 16
self.layer_name = layer_name
self.marlin_input_dtype = (
get_marlin_input_dtype(layer_name) if self.use_marlin else None
self.nvfp4_backend = select_nvfp4_moe_backend()
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend
)
self.flashinfer_moe_backend = None
if self.allow_flashinfer:
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
logger.info_once(
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
" for CompressedTensorsW4A4Nvfp4MoEMethod."
)
elif self.use_marlin:
logger.info_once("Using Marlin for CompressedTensorsW4A4Nvfp4MoEMethod.")
else:
logger.info_once("Using Cutlass for CompressedTensorsW4A4Nvfp4MoEMethod.")
self.kernel: mk.FusedMoEModularKernel | None = None
def create_weights(
self,
@@ -355,7 +337,13 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
set_weight_attrs(w2_input_scale, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# From packed to weight
"""
Convert NVFP4 MoE weights into kernel format and setup the kernel.
"""
# NOTE(rob): wN_weight_packed -> wN_weight is because ModularKernelMethod
# requires this naming convention. However, the name change breaks
# reloading because the state dict no longer matches disk. Once we
# remove MKM, we should revert this change to ensure compatibility.
layer.w13_weight = torch.nn.Parameter(
layer.w13_weight_packed.data, requires_grad=False
)
@@ -366,144 +354,79 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
)
delattr(layer, "w2_weight_packed")
# reorder GEMM1 weights and block scales for FlashInfer CUTLASS kernel.
if self.allow_flashinfer:
w, s = reorder_w1w3_to_w3w1(
layer.w13_weight.data, layer.w13_weight_scale.data, dim=-2
)
layer.w13_weight = torch.nn.Parameter(w, requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(s, requires_grad=False)
if not torch.allclose(
# Use a single gscale for w13.
if self.moe.is_act_and_mul and not torch.allclose(
layer.w13_weight_global_scale[:, 0], layer.w13_weight_global_scale[:, 1]
):
logger.warning_once(
"w1_weight_global_scale must match w3_weight_global_scale. "
"Accuracy may be affected."
"Accuracy may be affected.",
)
w13_weight_global_scale = layer.w13_weight_global_scale[:, 0].contiguous()
# Take inverse of global scale saved to disk
layer.w13_weight_scale_2 = torch.nn.Parameter(
1 / layer.w13_weight_global_scale[:, 0], requires_grad=False
# Shuffle weights into the NvFp4 kernel format.
(
w13,
w13_scale,
w13_scale_2,
a13_scale,
w2,
w2_scale,
w2_scale_2,
a2_scale,
) = convert_to_nvfp4_moe_kernel_format(
nvfp4_backend=self.nvfp4_backend,
layer=layer,
w13=layer.w13_weight,
w13_scale=layer.w13_weight_scale,
w13_scale_2=(1.0 / w13_weight_global_scale),
a13_scale=(1.0 / layer.w13_input_global_scale),
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
w2_scale_2=(1.0 / layer.w2_weight_global_scale),
a2_scale=(1.0 / layer.w2_input_global_scale),
is_act_and_mul=self.moe.is_act_and_mul,
)
layer.w2_weight_scale_2 = torch.nn.Parameter(
1 / layer.w2_weight_global_scale.data, requires_grad=False
)
replace_parameter(layer, "w13_weight", w13)
replace_parameter(layer, "w13_weight_scale", w13_scale)
replace_parameter(layer, "w2_weight", w2)
replace_parameter(layer, "w2_weight_scale", w2_scale)
layer.w13_weight_scale_2 = w13_scale_2
layer.w2_weight_scale_2 = w2_scale_2
layer.w13_input_scale = a13_scale
layer.w2_input_scale = a2_scale
if self.use_marlin:
prepare_moe_fp4_layer_for_marlin(layer, input_dtype=self.marlin_input_dtype)
return
# w13
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
w13_input_global_scale = (
layer.w13_input_global_scale.min()
.to(torch.float32)
.expand(layer.num_experts)
)
else:
w13_input_global_scale = layer.w13_input_global_scale.min(dim=1).values.to(
torch.float32
)
layer.g1_alphas = torch.nn.Parameter(
((1 / w13_input_global_scale) * layer.w13_weight_scale_2),
requires_grad=False,
)
layer.w13_input_scale_quant = torch.nn.Parameter(
(w13_input_global_scale), requires_grad=False
)
# w2
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
w2_input_global_scale = (
layer.w2_input_global_scale.min()
.to(torch.float32)
.expand(layer.num_experts)
)
else:
w2_input_global_scale = layer.w2_input_global_scale
layer.g2_alphas = torch.nn.Parameter(
((1 / w2_input_global_scale) * layer.w2_weight_scale_2).to(torch.float32),
requires_grad=False,
)
layer.w2_input_scale_quant = torch.nn.Parameter(
(w2_input_global_scale), requires_grad=False
)
# TensorRT-LLM specific processing
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
# Prepare static weights for TRT-LLM kernel
# alternate: prepare_static_weight_layouts_for_trtllm_moe
(
gemm1_weights_fp4_shuffled,
gemm1_scales_fp4_shuffled,
gemm2_weights_fp4_shuffled,
gemm2_scales_fp4_shuffled,
) = prepare_static_weights_for_trtllm_fp4_moe(
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale,
layer.w2_weight_scale,
layer.w2_weight.size(-2), # hidden_size
layer.w13_weight.size(-2) // 2, # intermediate_size
layer.w13_weight.size(0), # num_experts
)
logger.debug_once("Finished shuffling weights for TRT-LLM MOE")
layer.w13_weight = Parameter(
gemm1_weights_fp4_shuffled, requires_grad=False
)
layer.w2_weight = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False)
layer.w13_weight_scale = Parameter(
gemm1_scales_fp4_shuffled, requires_grad=False
)
layer.w2_weight_scale = Parameter(
gemm2_scales_fp4_shuffled, requires_grad=False
)
# Additional parameter needed for TRT-LLM
layer.g1_scale_c = Parameter(
(layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
requires_grad=False,
)
else:
# swizzle weight scales
layer.w13_weight_scale = torch.nn.Parameter(
swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
)
layer.w2_weight_scale = torch.nn.Parameter(
swizzle_blockscale(layer.w2_weight_scale), requires_grad=False
# Initialize the kernel that will be called in apply().
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
use_dp = self.moe.dp_size > 1
if self.moe_quant_config is not None and not use_dp:
self.kernel = make_nvfp4_moe_kernel(
backend=self.nvfp4_backend,
quant_config=self.moe_quant_config,
moe_config=self.moe,
)
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
if self.use_marlin or (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
UNSUPPORTED = [NvFp4MoeBackend.MARLIN, NvFp4MoeBackend.FLASHINFER_TRTLLM]
if self.nvfp4_backend in UNSUPPORTED:
return None
elif not self.allow_flashinfer:
elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
# TP case: avoid convert to ModularKernelMethod - to be refactored.
if self.moe.dp_size == 1:
return None
# For now, fp4 moe only works with the flashinfer dispatcher.
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
self.moe
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
else:
return super().maybe_make_prepare_finalize(routing_tables)
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
@@ -514,7 +437,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
experts = select_nvfp4_gemm_impl(
self.moe,
self.moe_quant_config,
allow_flashinfer=self.allow_flashinfer,
allow_flashinfer=(self.nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS),
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
@@ -522,19 +445,14 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
if (
self.use_marlin
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
return None
return nvfp4_moe_quant_config(
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
w1_scale=layer.w13_weight_scale,
return make_nvfp4_moe_quant_config(
backend=self.nvfp4_backend,
w13_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w13_scale_2=layer.w13_weight_scale_2,
w2_scale_2=layer.w2_weight_scale_2,
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
def apply(
@@ -546,14 +464,9 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
assert layer.activation == "silu", "Only SiLU activation is supported."
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
and not layer.enable_eplb
):
if layer.enable_eplb:
raise NotImplementedError(
"EPLB not supported for `CompressedTensorsW4A4MoEMethod` yet."
)
return flashinfer_trtllm_fp4_moe(
layer=layer,
x=x,
@@ -566,79 +479,41 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
e_score_correction_bias=layer.e_score_correction_bias,
)
# Hidden_states in select_experts is only used to extract metadata
if isinstance(x, tuple):
x_routing, _ = x
else:
x_routing = x
topk_weights, topk_ids = layer.select_experts(
hidden_states=x,
hidden_states=x_routing,
router_logits=router_logits,
)
if self.use_marlin:
return fused_marlin_moe(
# EPLB path
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
assert layer.enable_eplb
return flashinfer_trtllm_fp4_routed_moe(
layer=layer,
x=x,
topk_ids=topk_ids,
topk_weights=topk_weights,
top_k=layer.top_k,
global_num_experts=layer.global_num_experts,
)
else:
assert self.kernel is not None
return self.kernel(
x,
layer.w13_weight,
layer.w2_weight,
None,
None,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits,
topk_weights,
topk_ids,
global_scale1=layer.w13_weight_scale_2,
global_scale2=layer.w2_weight_scale_2,
quant_type_id=scalar_types.float4_e2m1f.id,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
input_dtype=self.marlin_input_dtype,
workspace=layer.workspace,
)
# FlashInfer fused experts path
elif self.allow_flashinfer:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
flashinfer_cutlass_moe_fp4,
)
assert is_valid_flashinfer_cutlass_fused_moe(
x, layer.w13_weight, layer.w2_weight
), "Flashinfer CUTLASS Fused MoE not applicable!"
assert self.moe_quant_config is not None
return flashinfer_cutlass_moe_fp4(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
quant_config=self.moe_quant_config,
inplace=False, # TODO(shuw): fix later, now output is high prec
inplace=False,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
else:
# If no modular kernel is provided, use cutlass_moe_fp4 for TP case
# only (no EP).
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
assert self.moe_quant_config is not None
return cutlass_moe_fp4(
a=x,
w1_fp4=layer.w13_weight,
w2_fp4=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
quant_config=self.moe_quant_config,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
# TODO(bnell): derive these from arguments
m=x.shape[0],
n=layer.w2_weight.shape[2] * 2,
k=x.shape[1],
e=layer.w13_weight.shape[0],
).to(x.dtype)
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):

View File

@@ -15,9 +15,7 @@ from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
nvfp4_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,
@@ -30,6 +28,15 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
make_fp8_moe_quant_config,
select_fp8_moe_backend,
)
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
FLASHINFER_NVFP4_MOE_BACKENDS,
NvFp4MoeBackend,
convert_to_nvfp4_moe_kernel_format,
is_global_sf_supported_for_nvfp4_backend,
make_nvfp4_moe_kernel,
make_nvfp4_moe_quant_config,
select_nvfp4_moe_backend,
)
from vllm.model_executor.layers.linear import (
LinearBase,
LinearMethodBase,
@@ -45,16 +52,11 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
flashinfer_trtllm_fp4_moe,
flashinfer_trtllm_fp4_routed_moe,
prepare_static_weights_for_trtllm_fp4_moe,
reorder_w1w3_to_w3w1,
select_nvfp4_gemm_impl,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
get_flashinfer_moe_backend,
is_flashinfer_supporting_global_sf,
select_cutlass_fp8_gemm_impl,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
@@ -69,7 +71,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
apply_fp4_marlin_linear,
is_fp4_marlin_supported,
prepare_fp4_layer_for_marlin,
prepare_moe_fp4_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
@@ -89,7 +90,6 @@ from vllm.model_executor.parameter import (
PerTensorScaleParameter,
)
from vllm.model_executor.utils import replace_parameter
from vllm.scalar_type import scalar_types
from vllm.utils.flashinfer import (
flashinfer_scaled_fp4_mm,
has_flashinfer,
@@ -1327,43 +1327,32 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
quant_config: ModelOptNvFp4Config,
layer: FusedMoE,
) -> None:
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (
detect_nvfp4_moe_support, # noqa: E501
)
super().__init__(layer.moe_config)
self.quant_config = quant_config
self.layer = layer
_nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin
self.marlin_input_dtype = None
self.flashinfer_moe_backend = None
if self.allow_flashinfer:
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
logger.info_once(
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
" for ModelOptNvFp4FusedMoE."
self.nvfp4_backend = select_nvfp4_moe_backend()
# TODO: move this type of check into the oracle.
if (
not self.moe.is_act_and_mul
and not self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS
):
raise NotImplementedError(
"Non-gated activations are only supported by FlashInfer "
"CUTLASS NvFP4 MoE backend."
)
elif self.use_marlin:
logger.info_once("Using Marlin for ModelOptNvFp4FusedMoE.")
else:
logger.info_once("Using Cutlass for ModelOptNvFp4FusedMoE.")
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend
)
self.kernel: mk.FusedMoEModularKernel | None = None
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
if self.use_marlin or (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
UNSUPPORTED = [NvFp4MoeBackend.MARLIN, NvFp4MoeBackend.FLASHINFER_TRTLLM]
if self.nvfp4_backend in UNSUPPORTED:
return None
elif (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
):
elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
# TP case: avoid convert to ModularKernelMethod - to be refactored.
if self.moe.dp_size == 1:
return None
@@ -1385,7 +1374,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
experts = select_nvfp4_gemm_impl(
self.moe,
self.moe_quant_config,
allow_flashinfer=self.allow_flashinfer,
allow_flashinfer=self.nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS,
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
@@ -1405,11 +1394,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
if not self.quant_config.is_checkpoint_nvfp4_serialized:
raise ValueError(
"NVFP4 quantization was selected, "
" dynamic quantization is not supported."
)
assert self.quant_config.is_checkpoint_nvfp4_serialized
layer.num_experts = num_experts
layer.params_dtype = params_dtype
@@ -1498,14 +1483,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
)
use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
self.flashinfer_moe_backend
global_sf_num_experts = (
global_num_experts if self.use_global_sf else num_experts
)
global_scale_num_experts = global_num_experts if use_global_sf else num_experts
w13_input_scale = PerTensorScaleParameter(
data=torch.empty(
global_scale_num_experts,
global_sf_num_experts,
2 if self.moe.is_act_and_mul else 1,
dtype=torch.float32,
),
@@ -1514,32 +1497,17 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer.register_parameter("w13_input_scale", w13_input_scale)
w2_input_scale = PerTensorScaleParameter(
data=torch.empty(global_scale_num_experts, dtype=torch.float32),
data=torch.empty(global_sf_num_experts, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w2_input_scale", w2_input_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# GEMM 1 processing
gemm1_weight = layer.w13_weight.data
gemm1_weight_scale = layer.w13_weight_scale.data
"""
Convert NVFP4 MoE weights into kernel format and setup the kernel.
"""
if (
self.allow_flashinfer
and (
self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
)
and self.moe.is_act_and_mul
):
gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
gemm1_weight, gemm1_weight_scale, dim=-2
)
layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False)
# Common processing for w13_weight_scale_2
# Use a single gscale for w13.
if self.moe.is_act_and_mul and not torch.allclose(
layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
):
@@ -1547,136 +1515,47 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
"w1_weight_scale_2 must match w3_weight_scale_2. "
"Accuracy may be affected."
)
w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0].contiguous()
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)
# Common processing for input scales and alphas
use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
self.flashinfer_moe_backend
)
if use_global_sf:
# For backends provide by Flashinfer, the input global scales are
# shared across all experts.
w13_input_scale = (
layer.w13_input_scale.max().to(torch.float32).expand(layer.num_experts)
)
else:
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
layer.g1_alphas = Parameter(
(w13_input_scale * w13_weight_scale_2).to(torch.float32),
requires_grad=False,
(
w13,
w13_scale,
w13_scale_2,
a13_scale,
w2,
w2_scale,
w2_scale_2,
a2_scale,
) = convert_to_nvfp4_moe_kernel_format(
nvfp4_backend=self.nvfp4_backend,
layer=layer,
w13=layer.w13_weight,
w13_scale=layer.w13_weight_scale,
w13_scale_2=w13_weight_scale_2,
a13_scale=layer.w13_input_scale,
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
w2_scale_2=layer.w2_weight_scale_2,
a2_scale=layer.w2_input_scale,
is_act_and_mul=self.moe.is_act_and_mul,
)
# This is for quantization, so we need to invert it.
layer.w13_input_scale_quant = Parameter(
(1 / w13_input_scale).to(torch.float32), requires_grad=False
)
replace_parameter(layer, "w13_weight", w13)
replace_parameter(layer, "w13_weight_scale", w13_scale)
replace_parameter(layer, "w13_weight_scale_2", w13_scale_2)
replace_parameter(layer, "w13_input_scale", a13_scale)
replace_parameter(layer, "w2_weight", w2)
replace_parameter(layer, "w2_weight_scale", w2_scale)
replace_parameter(layer, "w2_weight_scale_2", w2_scale_2)
replace_parameter(layer, "w2_input_scale", a2_scale)
# GEMM 2 processing
if use_global_sf:
# For backends provide by Flashinfer, the input global scales are
# shared across all experts.
w2_input_scale = (
layer.w2_input_scale.max().to(torch.float32).expand(layer.num_experts)
)
else:
w2_input_scale = layer.w2_input_scale
layer.g2_alphas = Parameter(
(w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
requires_grad=False,
)
# This is for quantization, so we need to invert it.
layer.w2_input_scale_quant = Parameter(
(1 / w2_input_scale).to(torch.float32), requires_grad=False
)
# TensorRT-LLM specific processing
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
# Prepare static weights for TRT-LLM kernel
# alternate: prepare_static_weight_layouts_for_trtllm_moe
(
gemm1_weights_fp4_shuffled,
gemm1_scales_fp4_shuffled,
gemm2_weights_fp4_shuffled,
gemm2_scales_fp4_shuffled,
) = prepare_static_weights_for_trtllm_fp4_moe(
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale,
layer.w2_weight_scale,
layer.w2_weight.size(-2), # hidden_size
layer.w13_weight.size(-2) // 2, # intermediate_size
layer.w13_weight.size(0), # num_experts
)
logger.debug_once("Finished shuffling weights for TRT-LLM MOE")
layer.w13_weight = Parameter(
gemm1_weights_fp4_shuffled, requires_grad=False
)
layer.w2_weight = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False)
layer.w13_weight_scale = Parameter(
gemm1_scales_fp4_shuffled, requires_grad=False
)
layer.w2_weight_scale = Parameter(
gemm2_scales_fp4_shuffled, requires_grad=False
)
# Additional parameter needed for TRT-LLM
layer.g1_scale_c = Parameter(
(layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
requires_grad=False,
)
elif self.use_marlin:
# Marlin processing
prepare_moe_fp4_layer_for_marlin(layer)
del layer.g1_alphas
del layer.g2_alphas
del layer.w13_input_scale_quant
del layer.w2_input_scale_quant
else:
# Non-TRT-LLM processing (Cutlass or non-flashinfer)
w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale)
layer.w13_weight_scale = Parameter(
w13_blockscale_swizzled, requires_grad=False
)
w13_weight = layer.w13_weight
intermediate_size_pad = w13_blockscale_swizzled.size(1) - w13_weight.size(1)
if intermediate_size_pad:
# padding gated activations will require to split w1 and w3
# and pad them individually
assert not self.moe.is_act_and_mul, (
"The intermediate size required padding, "
"but padding is not implemented for gated activations"
)
layer.w13_weight = Parameter(
torch.nn.functional.pad(
w13_weight, (0, 0, 0, intermediate_size_pad)
),
requires_grad=False,
)
layer.w2_weight = Parameter(
torch.nn.functional.pad(
layer.w2_weight, (0, intermediate_size_pad // 2, 0, 0)
),
requires_grad=False,
)
layer.w2_weight_scale = Parameter(
torch.nn.functional.pad(
layer.w2_weight_scale, (0, intermediate_size_pad // 16)
),
requires_grad=False,
)
w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
layer.w2_weight_scale = Parameter(
w2_blockscale_swizzled, requires_grad=False
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
use_dp = self.moe.dp_size > 1
if self.moe_quant_config is not None and not use_dp:
self.kernel = make_nvfp4_moe_kernel(
backend=self.nvfp4_backend,
quant_config=self.moe_quant_config,
moe_config=self.moe,
)
def prepare_dp_allgather_tensor(
@@ -1688,7 +1567,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
"""Optionally prepare extra tensors to carry through DP allgather/EP."""
import flashinfer
a1_gscale = layer.w13_input_scale_quant
assert self.moe_quant_config is not None
a1_gscale = self.moe_quant_config.a1_gscale
hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize(
hidden_states,
a1_gscale,
@@ -1700,19 +1580,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
if (
self.use_marlin
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
return None
return nvfp4_moe_quant_config(
w1_scale=layer.w13_weight_scale,
return make_nvfp4_moe_quant_config(
backend=self.nvfp4_backend,
w13_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
w13_scale_2=layer.w13_weight_scale_2,
w2_scale_2=layer.w2_weight_scale_2,
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
@property
@@ -1725,18 +1600,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if not self.moe.is_act_and_mul:
assert (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
), (
"Non-gated activations are only supported by the"
" flashinfer CUTLASS backend for modelopt checkpoints"
)
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
and not layer.enable_eplb
):
return flashinfer_trtllm_fp4_moe(
@@ -1762,10 +1627,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
)
# EPLB path
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
assert layer.enable_eplb
return flashinfer_trtllm_fp4_routed_moe(
layer=layer,
x=x,
@@ -1774,81 +1637,20 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
top_k=layer.top_k,
global_num_experts=layer.global_num_experts,
)
if self.use_marlin:
return fused_marlin_moe(
else:
assert self.kernel is not None
return self.kernel(
x,
layer.w13_weight,
layer.w2_weight,
None,
None,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits,
topk_weights,
topk_ids,
global_scale1=layer.w13_weight_scale_2,
global_scale2=layer.w2_weight_scale_2,
quant_type_id=scalar_types.float4_e2m1f.id,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
input_dtype=self.marlin_input_dtype,
)
elif self.allow_flashinfer:
assert self.flashinfer_moe_backend in (
FlashinferMoeBackend.CUTLASS,
FlashinferMoeBackend.CUTEDSL,
)
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
flashinfer_cutlass_moe_fp4,
)
flashinfer_fn_moe_fp4 = flashinfer_cutlass_moe_fp4
else:
from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import ( # noqa: E501
flashinfer_cutedsl_moe_fp4,
)
flashinfer_fn_moe_fp4 = flashinfer_cutedsl_moe_fp4
assert self.moe_quant_config is not None
return flashinfer_fn_moe_fp4(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
quant_config=self.moe_quant_config,
inplace=False,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
else:
# If no modular kernel is provided, use cutlass_moe_fp4 for TP case
# only (no EP).
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
assert self.moe_quant_config is not None
return cutlass_moe_fp4(
a=x,
w1_fp4=layer.w13_weight,
w2_fp4=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
quant_config=self.moe_quant_config,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
# TODO: derive from arguments
m=x.shape[0],
n=layer.w2_weight.shape[2] * 2,
k=x.shape[1],
e=layer.w13_weight.shape[0],
)
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod

View File

@@ -2,10 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utility helpers for NVFP4 + FlashInfer fused-MoE path"""
from typing import TYPE_CHECKING
import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
@@ -20,12 +23,23 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
create_flashinfer_prepare_finalize,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
swizzle_blockscale,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import (
has_flashinfer_cutedsl_grouped_gemm_nt_masked,
has_flashinfer_cutlass_fused_moe,
)
if TYPE_CHECKING:
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
NvFp4MoeBackend,
)
logger = init_logger(__name__)
__all__ = [
"is_flashinfer_fp4_cutlass_moe_available",
"is_flashinfer_fp4_cutedsl_moe_available",
@@ -273,10 +287,9 @@ def flashinfer_trtllm_fp4_moe(
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
else:
# hidden_states is the already quantized
a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
a1_gscale,
layer.a1_gscale,
is_sf_swizzled_layout=False,
)
@@ -369,10 +382,9 @@ def flashinfer_trtllm_fp4_routed_moe(
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
else:
# Quantize input to FP4
a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
a1_gscale,
layer.a1_gscale,
is_sf_swizzled_layout=False,
)
@@ -410,3 +422,93 @@ def flashinfer_trtllm_fp4_routed_moe(
)[0]
return out
def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
backend: "NvFp4MoeBackend",
layer: torch.nn.Module,
w13: torch.Tensor,
w13_scale: torch.Tensor,
w13_scale_2: torch.Tensor,
a13_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
w2_scale_2: torch.Tensor,
a2_scale: torch.Tensor,
is_act_and_mul: bool,
) -> tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
# Delayed import for circular dependency avoidance.
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
NvFp4MoeBackend,
is_global_sf_supported_for_nvfp4_backend,
)
assert backend in [
NvFp4MoeBackend.VLLM_CUTLASS,
NvFp4MoeBackend.FLASHINFER_CUTLASS,
NvFp4MoeBackend.FLASHINFER_TRTLLM,
NvFp4MoeBackend.FLASHINFER_TRTLLM,
]
# Reorder [w1, w3] to [w3, w1] for FI NVFP4 MoE kernels.
if is_act_and_mul and backend in [
NvFp4MoeBackend.FLASHINFER_CUTLASS,
NvFp4MoeBackend.FLASHINFER_TRTLLM,
]:
w13, w13_scale = reorder_w1w3_to_w3w1(w13, w13_scale)
# For some FI kernels, the input scales are shared by all experts.
if is_global_sf_supported_for_nvfp4_backend(backend):
num_experts = w13.shape[0]
a13_scale = a13_scale.max().to(torch.float32).expand(num_experts)
a2_scale = a2_scale.max().to(torch.float32).expand(num_experts)
else:
a13_scale = a13_scale.max(dim=1).values.to(torch.float32)
# Shuffle weights and scales for FI TRTLLM NVFP4 MoE kernels.
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
w13, w13_scale, w2, w2_scale = prepare_static_weights_for_trtllm_fp4_moe(
w13,
w2,
w13_scale,
w2_scale,
w2.size(-2), # hidden_size
w13.size(-2) // 2, # intermediate_size
w13.size(0), # num_experts
)
# We do not need to make this a parameter, because
# it is not used during the weight (re)-loading process.
layer.g1_scale_c = a13_scale * w13_scale_2 / a2_scale
layer.a1_gscale = 1.0 / a13_scale
layer.g1_alphas = a13_scale * w13_scale_2
layer.g2_alphas = a2_scale * w2_scale_2
else:
# Swizzle the block scales for other FI NVFP4 MoE kernels.
w13_scale = swizzle_blockscale(w13_scale)
# Apply padding if needed.
pad_size = w13_scale.size(1) - w13.size(1)
if pad_size > 0:
if is_act_and_mul:
raise NotImplementedError(
"Intermediate size padding for w1 and w3, for %s "
"NvFp4 backend, but this is not currently supported",
backend.value,
)
w13 = torch.nn.functional.pad(w13, (0, 0, 0, pad_size))
w2 = torch.nn.functional.pad(w2, (0, pad_size // 2, 0, 0))
w2_scale = torch.nn.functional.pad(w2_scale, (0, pad_size // 16))
w2_scale = swizzle_blockscale(w2_scale)
return w13, w13_scale, w13_scale_2, a13_scale, w2, w2_scale, w2_scale_2, a2_scale

View File

@@ -8,6 +8,7 @@ import vllm._custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
USE_FP32_REDUCE_DEFAULT,
get_marlin_input_dtype,
marlin_make_workspace_new,
marlin_permute_bias,
marlin_permute_scales,
@@ -226,6 +227,106 @@ def prepare_fp4_layer_for_marlin(
return
def prepare_nvfp4_moe_layer_for_marlin(
layer: torch.nn.Module,
w13: torch.Tensor,
w13_scale: torch.Tensor,
w13_scale_2: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
w2_scale_2: torch.Tensor,
) -> tuple[
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
]:
logger.warning_once(
"Your GPU does not have native support for FP4 computation but "
"FP4 quantization is being used. Weight-only FP4 compression will "
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads."
)
input_dtype = get_marlin_input_dtype(prefix="")
if input_dtype is not None and input_dtype.itemsize == 1:
raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.")
GROUP_SIZE = 16
E = layer.num_experts
K = layer.hidden_size
N = layer.intermediate_size_per_partition
device = w13.device
param_dtype = layer.params_dtype
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
# WORKSPACE
layer.workspace = marlin_make_workspace_new(device, 4)
perm = torch.empty(0, dtype=torch.int, device=device)
# WEIGHT
# Repack weights to marlin format
def repack_weight(weight: torch.Tensor, name: str) -> torch.Tensor:
tensor_list = []
if "w13" in name:
size_n, size_k = N * 2, K
else:
size_n, size_k = K, N
assert weight.shape == (E, size_n, size_k // 2)
for i in range(E):
qweight = weight[i].view(torch.int32).T.contiguous()
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=qweight,
perm=perm,
size_k=size_k,
size_n=size_n,
num_bits=4,
is_a_8bit=is_a_8bit,
)
tensor_list.append(marlin_qweight)
return torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
w13 = repack_weight(w13, "w13")
w2 = repack_weight(w2, "w2")
# WEIGHT SCALES
# Permute scales
def premute_scales(
scales: torch.Tensor, g_scales: torch.Tensor, name: str
) -> tuple[torch.Tensor, torch.Tensor]:
scales = scales.to(param_dtype)
g_scales = g_scales.to(param_dtype)
tensor_list = []
if "w13" in name:
size_n, size_k = N * 2, K
else:
size_n, size_k = K, N
for i in range(E):
scale = scales[i].T
marlin_scales = marlin_permute_scales(
s=scale,
size_k=size_k,
size_n=size_n,
group_size=GROUP_SIZE,
is_a_8bit=is_a_8bit,
)
marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
tensor_list.append(marlin_scales)
scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
g_scales = nvfp4_marlin_process_global_scale(g_scales)
return scales, g_scales
w13_scale, w13_scale_2 = premute_scales(w13_scale, w13_scale_2, "w13")
w2_scale, w2_scale_2 = premute_scales(w2_scale, w2_scale_2, "w2")
return w13, w13_scale, w13_scale_2, w2, w2_scale, w2_scale_2
def prepare_moe_fp4_layer_for_marlin(
layer: torch.nn.Module, input_dtype: torch.dtype | None = None
) -> None: