[Feat] Support non-gated MoE with Marlin, NVFP4 CUTLASS, FP8, INT8, compressed-tensors (#32257)

Signed-off-by: Tomer Natan <tbarnatan@computelab-frontend-8.nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Tomer Natan <tbarnatan@computelab-frontend-8.nvidia.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Tomer Natan <tbarnatan@ipp1-1429.ipp1a1.colossus.nvidia.com>
This commit is contained in:
TomerBN-Nvidia
2026-01-16 02:15:05 +02:00
committed by GitHub
parent aca5c51487
commit c277fbdf31
17 changed files with 226 additions and 127 deletions

View File

@@ -526,7 +526,7 @@ def test_run_cutlass_moe_fp8(
c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64)
activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
activation = "silu"
a1q, a1q_scale = moe_kernel_quantize_input(
mt.a, mt.a_scale, torch.float8_e4m3fn, per_act_token
)

View File

@@ -1079,6 +1079,86 @@ def test_fused_marlin_moe_with_bias(m):
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
@pytest.mark.parametrize("m", [1, 64, 256])
@pytest.mark.parametrize("n,k", [(1024, 1024), (2048, 2048)])
@pytest.mark.parametrize("e,topk", [(8, 2), (64, 4)])
def test_fused_marlin_moe_non_gated(m: int, n: int, k: int, e: int, topk: int):
    """Test Marlin MoE with non-gated activation (relu2_no_mul).

    Non-gated activations like relu2 don't have the gate-up projection pattern,
    so w1 has shape (e, n, k) instead of (e, 2*n, k).
    """
    # Seed CUDA RNG so the flaky-rerun decorator retries on identical data.
    torch.cuda.manual_seed(42)
    group_size = 16  # NVFP4 group size
    is_k_full = True
    quant_type = scalar_types.float4_e2m1f
    dtype = torch.bfloat16
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    # Non-gated: w1 shape is (e, n, k) not (e, 2*n, k)
    w1 = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    # Quantize and repack both expert weight tensors into Marlin NVFP4 layout
    # (act_order disabled, so g_idx/sort_indices are trivial).
    w1_data = MarlinMoEWeightData.make(
        w=w1,
        quant_type=quant_type,
        group_size=group_size,
        act_order=False,
    )
    w2_data = MarlinMoEWeightData.make(
        w=w2,
        quant_type=quant_type,
        group_size=group_size,
        act_order=False,
    )
    score = torch.randn((m, e), device="cuda", dtype=dtype)
    topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
    with set_current_vllm_config(vllm_config):
        # Reference path: dequantized weights through the eager torch MoE
        # implementation with the plain (non-"_no_mul") relu2 activation.
        torch_output = torch_moe(
            a,
            w1_data.w_ref,
            w2_data.w_ref,
            score,
            topk,
            activation="relu2",
        )
        # Kernel under test: same routing, quantized Marlin weights, and the
        # "_no_mul" activation variant that skips the gate multiplication.
        marlin_output = fused_marlin_moe(
            a,
            w1_data.qweight,
            w2_data.qweight,
            None,  # bias1
            None,  # bias2
            w1_data.scales,
            w2_data.scales,
            score,
            topk_weights,
            topk_ids,
            global_num_experts=e,
            expert_map=None,
            global_scale1=w1_data.global_scale,
            global_scale2=w2_data.global_scale,
            g_idx1=w1_data.g_idx,
            g_idx2=w2_data.g_idx,
            sort_indices1=w1_data.sort_indices,
            sort_indices2=w2_data.sort_indices,
            w1_zeros=w1_data.zeros,
            w2_zeros=w2_data.zeros,
            quant_type_id=quant_type.id,
            is_k_full=is_k_full,
            activation="relu2_no_mul",
        )
    # Loose tolerance: FP4 quantization error dominates any kernel noise.
    torch.testing.assert_close(marlin_output, torch_output, atol=1e-1, rtol=0)
@pytest.mark.parametrize("ep_size", [1, 2])
def test_moe_align_block_size_opcheck(ep_size):
num_experts = 4

View File

@@ -1451,6 +1451,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - "flashinfer-cudnn": use flashinfer cudnn GEMM backend
# - "flashinfer-trtllm": use flashinfer trtllm GEMM backend
# - "flashinfer-cutlass": use flashinfer cutlass GEMM backend
# - "marlin": use marlin GEMM backend (for GPUs without native FP4 support)
# - <none>: automatically pick an available backend
"VLLM_NVFP4_GEMM_BACKEND": env_with_choices(
"VLLM_NVFP4_GEMM_BACKEND",
@@ -1460,6 +1461,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
"flashinfer-trtllm",
"flashinfer-cutlass",
"cutlass",
"marlin",
],
),
# Controls garbage collection during CUDA graph capture.

View File

@@ -2,8 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""CUTLASS based Fused MoE kernels."""
from collections.abc import Callable
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -21,7 +19,10 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache,
apply_moe_activation,
)
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
@@ -33,7 +34,7 @@ def run_cutlass_moe_fp8(
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation_callable: Callable,
activation: str,
global_num_experts: int,
expert_map: torch.Tensor | None,
w1_scale: torch.Tensor | None,
@@ -55,6 +56,7 @@ def run_cutlass_moe_fp8(
):
a1q = hidden_states
assert not activation.endswith("_no_mul"), "Only gated activation is supported"
assert w1_scale is not None
assert w2_scale is not None
assert w1.dtype == torch.float8_e4m3fn
@@ -198,7 +200,7 @@ def run_cutlass_moe_fp8(
per_out_ch,
)
activation_callable(act_out, mm1_out)
apply_moe_activation(activation, act_out, mm1_out)
a2q, a2q_scale = ops.scaled_fp8_quant(
act_out, a2_scale, use_per_token_if_dynamic=per_act_token, output=quant_out
@@ -288,8 +290,6 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
if expert_tokens_meta is not None:
expert_num_tokens = expert_tokens_meta.expert_num_tokens
activation_callable = lambda o, i: self.activation(activation, o, i)
use_batched_format = (
self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts
)
@@ -301,7 +301,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
w1,
w2,
topk_ids,
activation_callable,
activation,
global_num_experts,
expert_map,
self.w1_scale,
@@ -436,6 +436,7 @@ def run_cutlass_moe_fp4(
w2_alphas: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
m: int,
@@ -544,8 +545,7 @@ def run_cutlass_moe_fp4(
num_topk,
)
c1 = _resize_cache(workspace13, (m * topk, n * 2))
# Note: c2 workspace is no longer needed since SiLU is fused with quantization.
# c3 reuses workspace13 after c1 is consumed.
c2 = _resize_cache(workspace2, (m * topk, n))
c3 = _resize_cache(workspace13, (m * topk, k))
ops.cutlass_fp4_moe_mm(
c1,
@@ -559,10 +559,18 @@ def run_cutlass_moe_fp4(
blockscale_offsets[:-1],
)
del rep_a_fp4, rep_a_blockscale
# Fused SiLU+Mul+NVFP4 quantization
int_fp4, int_blockscale = ops.silu_and_mul_scaled_fp4_experts_quant(
c1, a2_gscale, expert_offsets, blockscale_offsets, num_topk
)
if activation == "silu":
# Fused SiLU+Mul+NVFP4 quantization
# Note: c2 workspace is no longer needed since SiLU is fused with quantization.
# c3 reuses workspace13 after c1 is consumed.
int_fp4, int_blockscale = ops.silu_and_mul_scaled_fp4_experts_quant(
c1, a2_gscale, expert_offsets, blockscale_offsets, num_topk
)
else:
apply_moe_activation(activation, c2, c1)
int_fp4, int_blockscale = ops.scaled_fp4_experts_quant(
c2, a2_gscale, expert_offsets, blockscale_offsets, num_topk
)
ops.cutlass_fp4_moe_mm(
c3,
@@ -693,6 +701,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
w2_alphas=self.g2_alphas,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
workspace13=workspace13,
workspace2=workspace2,
m=m,
@@ -711,7 +720,7 @@ def run_cutlass_moe_w4a8_fp8(
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation_callable: Callable,
activation: str,
global_num_experts: int,
expert_map: torch.Tensor | None,
w1_scale: torch.Tensor | None,
@@ -815,7 +824,7 @@ def run_cutlass_moe_w4a8_fp8(
s_strides1,
)
activation_callable(act_out, mm1_out)
apply_moe_activation(activation, act_out, mm1_out)
a2q, a2q_scale = ops.scaled_fp8_quant(
act_out, a2_scale, use_per_token_if_dynamic=per_act_token, output=quant_out
@@ -936,7 +945,6 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
assert self.w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
expert_num_tokens = None
activation_callable = lambda o, i: self.activation(activation, o, i)
use_batched_format = (
self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts
@@ -951,7 +959,7 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
w1,
w2,
topk_ids,
activation_callable,
activation,
global_num_experts,
expert_map,
self.w1_scale,

View File

@@ -17,7 +17,11 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache, disable_inplace
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache,
apply_moe_activation,
disable_inplace,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new,
marlin_moe_intermediate_size,
@@ -27,21 +31,6 @@ from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
def default_activation_func(
    activation: str, output: torch.Tensor, input: torch.Tensor
) -> None:
    """Apply a gated activation kernel in place.

    Dispatches *activation* to the matching fused CUDA op, which reads the
    gate/up halves from ``input`` and writes the activated product into
    ``output``. Raises ``ValueError`` for any unrecognized activation name.
    """
    # Guard-clause dispatch: each supported name returns immediately after
    # invoking its fused kernel, leaving the raise as the fall-through case.
    if activation == "silu":
        torch.ops._C.silu_and_mul(output, input)
        return
    if activation == "swigluoai":
        # alpha = 1.702, limit = 7.0
        torch.ops._C.swigluoai_and_mul(output, input)
        return
    raise ValueError(
        f"Unsupported activation: {activation}. "
        "Only silu and swigluoai activations are supported."
    )
def _fused_marlin_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
@@ -62,7 +51,7 @@ def _fused_marlin_moe(
activation: str = "silu",
activation_func: Callable[
[str, torch.Tensor, torch.Tensor], None
] = default_activation_func,
] = apply_moe_activation,
input_global_scale1: torch.Tensor | None = None,
input_global_scale2: torch.Tensor | None = None,
global_scale1: torch.Tensor | None = None,
@@ -83,13 +72,13 @@ def _fused_marlin_moe(
assert hidden_states.ndim == 2
M, K = hidden_states.size()
N = marlin_moe_intermediate_size(w1, w2)
w13_num_shards = 1 if "no_mul" in activation else 2
if workspace is None:
workspace = marlin_make_workspace_new(hidden_states.device, 4)
if intermediate_cache13 is None:
intermediate_cache13 = torch.empty(
(M * num_topk * max(2 * N, K),),
(M * num_topk * max(w13_num_shards * N, K),),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
@@ -101,7 +90,9 @@ def _fused_marlin_moe(
dtype=hidden_states.dtype,
)
intermediate_cache1 = _resize_cache(intermediate_cache13, (M * num_topk, 2 * N))
intermediate_cache1 = _resize_cache(
intermediate_cache13, (M * num_topk, w13_num_shards * N)
)
intermediate_cache3 = _resize_cache(intermediate_cache13, (M * num_topk, K))
@@ -137,16 +128,17 @@ def _fused_marlin_moe(
mul_topk_weights=apply_router_weight_on_input,
b_q_type=quant_type,
size_m=M,
size_n=2 * N,
size_n=w13_num_shards * N,
size_k=K,
is_k_full=is_k_full,
use_atomic_add=False,
use_fp32_reduce=True,
is_zp_float=False,
)
activation_func(
activation, intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
activation,
intermediate_cache2,
intermediate_cache1.view(-1, w13_num_shards * N),
)
if output is None:
@@ -216,7 +208,7 @@ def fused_marlin_moe(
activation: str = "silu",
activation_func: Callable[
[str, torch.Tensor, torch.Tensor], None
] = default_activation_func,
] = apply_moe_activation,
moe_sum: Callable[[torch.Tensor, torch.Tensor], None] | None = None,
expert_map: torch.Tensor | None = None,
input_global_scale1: torch.Tensor | None = None,

View File

@@ -619,30 +619,11 @@ class FusedMoE(CustomOp):
# for heuristic purposes, so it must be initialized first.
self.quant_method: FusedMoEMethodBase = _get_quant_method()
if not self.moe_config.is_act_and_mul:
# Avoid circular import
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptFp8MoEMethod,
ModelOptNvFp4FusedMoE,
if not self.moe_config.is_act_and_mul and not current_platform.is_cuda():
raise NotImplementedError(
"is_act_and_mul=False is supported only for CUDA for now"
)
if not isinstance(
self.quant_method,
(
UnquantizedFusedMoEMethod,
ModelOptFp8MoEMethod,
ModelOptNvFp4FusedMoE,
),
):
raise NotImplementedError(
"is_act_and_mul=False is supported only for unquantized "
", ModelOpt FP8, and ModelOpt NvFp4 checkpoints"
)
if not current_platform.is_cuda():
raise NotImplementedError(
"is_act_and_mul=False is supported only for CUDA for now"
)
if self.enable_eplb and not self.quant_method.supports_eplb:
# TODO: Add support for additional quantization methods.
# The implementation for other quantization methods does not

View File

@@ -52,7 +52,7 @@ def select_fp8_moe_backend(
block_quant: bool,
tp_size: int,
with_lora_support: bool,
is_act_and_mul: bool = True,
is_act_and_mul: bool,
allow_vllm_cutlass: bool = False,
) -> Fp8MoeBackend:
"""
@@ -128,7 +128,7 @@ def select_fp8_moe_backend(
scope="local",
)
if use_deep_gemm and moe_use_deep_gemm and block_quant:
if use_deep_gemm and moe_use_deep_gemm and block_quant and is_act_and_mul:
if not has_deep_gemm():
logger.warning_once(
"DeepGEMM backend requested but not available.", scope="local"
@@ -141,7 +141,12 @@ def select_fp8_moe_backend(
logger.info_once(_make_log_backend("ROCm AITER"), scope="local")
return Fp8MoeBackend.AITER
if allow_vllm_cutlass and not block_quant and cutlass_group_gemm_supported():
if (
allow_vllm_cutlass
and not block_quant
and cutlass_group_gemm_supported()
and is_act_and_mul
):
logger.info_once(_make_log_backend("vLLM CUTLASS"), scope="local")
return Fp8MoeBackend.VLLM_CUTLASS

View File

@@ -178,6 +178,7 @@ def convert_to_nvfp4_moe_kernel_format(
w2=w2,
w2_scale=w2_scale,
w2_scale_2=w2_scale_2,
is_act_and_mul=is_act_and_mul,
)
else:
raise ValueError(f"Unknown NvFp4 backend for MoE: {nvfp4_backend}")

View File

@@ -367,7 +367,8 @@ def apply_moe_activation(
elif activation == GELU_NO_MUL:
output.copy_(F.gelu(input))
elif activation == RELU2_NO_MUL:
torch.square(F.relu(input), out=output)
F.relu(input, inplace=True)
torch.square(input, out=output)
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")

View File

@@ -764,8 +764,6 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert layer.activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,

View File

@@ -370,12 +370,6 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
layer_name: str | None = None,
use_marlin: bool = False,
):
if not moe.is_act_and_mul:
raise ValueError(
"CompressedTensorsW4A4Nvfp4MoEMethod does not yet "
"support non gated MoE models."
)
super().__init__(moe)
self.group_size = 16
if use_marlin:
@@ -388,6 +382,16 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
)
else:
self.nvfp4_backend = select_nvfp4_moe_backend()
# TODO: move this type of check into the oracle.
if not self.moe.is_act_and_mul and self.nvfp4_backend not in [
NvFp4MoeBackend.FLASHINFER_CUTLASS,
NvFp4MoeBackend.MARLIN,
]:
raise NotImplementedError(
"Non-gated activations are only supported by FlashInfer "
f"CUTLASS and Marlin NvFP4 MoE backends, not {self.nvfp4_backend}."
)
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend
)
@@ -404,11 +408,12 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
):
layer.num_experts = num_experts
layer.params_dtype = params_dtype
w13_num_shards = 2 if self.moe.is_act_and_mul else 1
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
w13_num_shards * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension
hidden_size // 2,
requires_grad=False,
@@ -436,7 +441,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
w13_weight_scale = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
w13_num_shards * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension
hidden_size // self.group_size,
dtype=torch.float8_e4m3fn,
@@ -467,7 +472,8 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
# Weight Global Scales
w13_weight_scale_2 = torch.nn.Parameter(
torch.empty(num_experts, 2, dtype=torch.float32), requires_grad=False
torch.empty(num_experts, w13_num_shards, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_weight_global_scale", w13_weight_scale_2)
extra_weight_attrs.update(
@@ -486,7 +492,8 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
# Input Global Scales
w13_input_scale = torch.nn.Parameter(
torch.empty(num_experts, 2, dtype=torch.float32), requires_grad=False
torch.empty(num_experts, w13_num_shards, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_input_global_scale", w13_input_scale)
extra_weight_attrs.update(
@@ -640,6 +647,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
x=x,
router_logits=router_logits,
top_k=layer.top_k,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
@@ -666,6 +674,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
topk_ids=topk_ids,
topk_weights=topk_weights,
top_k=layer.top_k,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
)
else:
@@ -722,6 +731,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
block_quant=self.block_quant,
tp_size=moe.tp_size,
with_lora_support=moe.is_lora_enabled,
is_act_and_mul=moe.is_act_and_mul,
# TODO(rob): enable selecting this externally.
allow_vllm_cutlass=True,
)
@@ -760,6 +770,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.weight_block_size = None
params_dtype = torch.float8_e4m3fn
w13_num_shards = 2 if self.moe.is_act_and_mul else 1
if self.block_quant:
assert self.weight_block_size is not None
@@ -791,7 +802,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
w13_num_shards * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype,
),
@@ -814,10 +825,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# WEIGHT_SCALES
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
# Allocate 2 scales for w1 and w3 respectively.
# They are combined to a single scale after weight loading.
# For gated MoE, allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
# For non-gated MoE, allocate 1 scale for w13.
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
torch.ones(num_experts, w13_num_shards, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(
@@ -835,7 +848,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
2 * intermediate_size_per_partition,
w13_num_shards * intermediate_size_per_partition,
1,
dtype=torch.float32,
),
@@ -858,7 +871,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
w13_num_shards
* ((intermediate_size_per_partition + block_n - 1) // block_n),
(hidden_size + block_k - 1) // block_k,
dtype=torch.float32,
),
@@ -930,11 +944,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# Per-tensor kernels use a single scale, for W13, but on disk there
# is a separate scale for W1 and W3. Requantize with the max scale.
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
process_fp8_weight_tensor_strategy_moe(
w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
w13,
w13_scale,
shard_size=layer.intermediate_size_per_partition,
num_experts=layer.num_local_experts,
is_act_and_mul=self.moe.is_act_and_mul,
)
w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
@@ -1166,12 +1181,13 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
**extra_weight_attrs,
):
params_dtype = torch.int8
w13_num_shards = 2 if self.moe.is_act_and_mul else 1
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
w13_num_shards * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype,
),
@@ -1196,7 +1212,10 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
num_experts,
w13_num_shards * intermediate_size_per_partition,
1,
dtype=torch.float32,
),
requires_grad=False,
)
@@ -1296,6 +1315,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
**extra_weight_attrs,
):
intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full")
w13_num_shards = 2 if self.moe.is_act_and_mul else 1
# Will transpose the loaded weight along the
# intermediate and hidden dim sizes. Will
@@ -1307,7 +1327,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
torch.empty(
num_experts,
hidden_size // self.packed_factor,
2 * intermediate_size_per_partition,
w13_num_shards * intermediate_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
@@ -1352,7 +1372,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
torch.ones(
num_experts,
num_groups_w13,
2 * intermediate_size_per_partition,
w13_num_shards * intermediate_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
@@ -1600,10 +1620,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert layer.activation == "silu", (
f"{layer.activation} not supported for Marlin MoE."
)
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
@@ -1625,6 +1641,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
activation=layer.activation,
expert_map=layer.expert_map,
g_idx1=layer.w13_weight_g_idx,
g_idx2=layer.w2_weight_g_idx,
@@ -1675,11 +1692,12 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
extra_weight_attrs.update(
{"is_transposed": True, "quant_method": self.strategy}
)
w13_num_shards = 2 if self.moe.is_act_and_mul else 1
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size // self.packed_factor,
2 * intermediate_size_per_partition,
w13_num_shards * intermediate_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
@@ -1712,7 +1730,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
torch.ones(
num_experts,
num_groups_w13,
2 * intermediate_size_per_partition,
w13_num_shards * intermediate_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,

View File

@@ -637,6 +637,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
block_quant=self.block_quant,
tp_size=layer.moe_parallel_config.tp_size,
with_lora_support=self.moe.is_lora_enabled,
is_act_and_mul=self.moe.is_act_and_mul,
)
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:

View File

@@ -900,8 +900,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert layer.activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,

View File

@@ -733,6 +733,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
block_quant=False,
tp_size=moe_config.moe_parallel_config.tp_size,
with_lora_support=self.moe.is_lora_enabled,
is_act_and_mul=self.moe.is_act_and_mul,
)
self.kernel: mk.FusedMoEModularKernel | None = None
@@ -789,15 +790,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
)
weight_loader = extra_weight_attrs.get("weight_loader")
if self.moe.is_act_and_mul:
w13_up_dim = 2 * intermediate_size_per_partition
else:
w13_up_dim = intermediate_size_per_partition
w13_num_shards = 2 if self.moe.is_act_and_mul else 1
w13_weight = ModelWeightParameter(
data=torch.empty(
num_experts,
w13_up_dim,
w13_num_shards * intermediate_size_per_partition,
hidden_size,
dtype=weight_dtype,
),
@@ -826,7 +824,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
# For non-gated MoE, allocate 1 scale for w13.
w13_weight_scale = PerTensorScaleParameter(
data=torch.full(
(num_experts, 2 if self.moe.is_act_and_mul else 1),
(num_experts, w13_num_shards),
1.0,
dtype=torch.float32,
),
@@ -1132,6 +1130,9 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass":
self.backend = "cutlass"
assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"
elif envs.VLLM_NVFP4_GEMM_BACKEND == "marlin":
self.backend = "marlin"
assert is_fp4_marlin_supported(), f"Marlin is required for {self.backend}"
if self.backend == "none":
raise ValueError(
@@ -1337,13 +1338,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
self.quant_config = quant_config
self.nvfp4_backend = select_nvfp4_moe_backend()
# TODO: move this type of check into the oracle.
if (
not self.moe.is_act_and_mul
and not self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS
):
if not self.moe.is_act_and_mul and self.nvfp4_backend not in [
NvFp4MoeBackend.FLASHINFER_CUTLASS,
NvFp4MoeBackend.MARLIN,
]:
raise NotImplementedError(
"Non-gated activations are only supported by FlashInfer "
"CUTLASS NvFP4 MoE backend."
f"CUTLASS and Marlin NvFP4 MoE backends, not {self.nvfp4_backend}."
)
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
@@ -1409,11 +1410,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
weight_scale_dtype = torch.float8_e4m3fn
weight_loader = extra_weight_attrs.get("weight_loader")
global_num_experts = extra_weight_attrs.get("global_num_experts")
w13_num_shards = 2 if self.moe.is_act_and_mul else 1
# GEMM 1
w13_weight = ModelWeightParameter(
data=torch.empty(
num_experts,
(2 if self.moe.is_act_and_mul else 1) * intermediate_size_per_partition,
w13_num_shards * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension
hidden_size // 2,
dtype=weight_dtype,
@@ -1442,7 +1444,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
w13_weight_scale = ModelWeightParameter(
data=torch.empty(
num_experts,
(2 if self.moe.is_act_and_mul else 1) * intermediate_size_per_partition,
w13_num_shards * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension
hidden_size // self.quant_config.group_size,
dtype=weight_scale_dtype,
@@ -1472,9 +1474,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
)
w13_weight_scale_2 = PerTensorScaleParameter(
data=torch.empty(
num_experts, 2 if self.moe.is_act_and_mul else 1, dtype=torch.float32
),
data=torch.empty(num_experts, w13_num_shards, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
@@ -1495,7 +1495,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
w13_input_scale = PerTensorScaleParameter(
data=torch.empty(
global_sf_num_experts,
2 if self.moe.is_act_and_mul else 1,
w13_num_shards,
dtype=torch.float32,
),
weight_loader=weight_loader,
@@ -1616,6 +1616,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
x=x,
router_logits=router_logits,
top_k=layer.top_k,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
@@ -1642,6 +1643,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
topk_ids=topk_ids,
topk_weights=topk_weights,
top_k=layer.top_k,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
)
else:

View File

@@ -255,6 +255,7 @@ def flashinfer_trtllm_fp4_moe(
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
router_logits: torch.Tensor,
top_k: int,
activation: str,
global_num_experts: int,
num_expert_group: int | None,
topk_group: int | None,
@@ -269,6 +270,7 @@ def flashinfer_trtllm_fp4_moe(
x: Input tensor
router_logits: Router logits for expert selection
top_k: Number of experts to select per token
activation: Activation function to use
global_num_experts: Total number of experts across all ranks
num_expert_group: Number of expert groups (for grouped routing)
topk_group: Top-k within each group
@@ -282,6 +284,12 @@ def flashinfer_trtllm_fp4_moe(
from vllm.model_executor.models.llama4 import Llama4MoE
# https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2404
assert activation == "silu", (
"Only SiLU activation is supported for FlashInfer TRTLLM FP4 MoE. "
f"{activation} found instead."
)
# Quantize input to FP4
if isinstance(x, tuple):
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
@@ -352,6 +360,7 @@ def flashinfer_trtllm_fp4_routed_moe(
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
top_k: int,
activation: str,
global_num_experts: int,
) -> torch.Tensor:
"""
@@ -364,6 +373,7 @@ def flashinfer_trtllm_fp4_routed_moe(
x: Input tensor
topk_ids: Ids of selected experts
top_k: Number of experts to select per token
activation: Activation function to use
global_num_experts: Total number of experts across all ranks
Returns:
@@ -371,6 +381,12 @@ def flashinfer_trtllm_fp4_routed_moe(
"""
import flashinfer
# https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2535
assert activation == "silu", (
"Only SiLU activation is supported for FlashInfer TRTLLM FP4 Routed MoE. "
f"{activation} found instead."
)
# Pack top k ids and expert weights into a single int32 tensor, as
# required by TRT-LLM
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(

View File

@@ -233,8 +233,6 @@ def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
intermediate_size_per_partition = layer.intermediate_size_per_partition
# apply_router_weight_on_input is not supported for moe marlin
supports_router_weight = not layer.apply_router_weight_on_input
# moe marlin requires the activation to be silu
supports_activation = layer.activation == "silu"
# gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
# down: (n, k) = (hidden_size, intermediate_size_per_partition)
@@ -244,12 +242,7 @@ def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
and intermediate_size_per_partition % max(64, group_size) == 0
)
supports_group_size = group_size in [-1, 32, 64, 128]
return (
supports_shape
and supports_group_size
and supports_router_weight
and supports_activation
)
return supports_shape and supports_group_size and supports_router_weight
def marlin_moe_intermediate_size(w1_packed: torch.Tensor, w2_packed: torch.Tensor):

View File

@@ -235,6 +235,7 @@ def prepare_nvfp4_moe_layer_for_marlin(
w2: torch.Tensor,
w2_scale: torch.Tensor,
w2_scale_2: torch.Tensor,
is_act_and_mul: bool,
) -> tuple[
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
]:
@@ -266,8 +267,9 @@ def prepare_nvfp4_moe_layer_for_marlin(
# Repack weights to marlin format
def repack_weight(weight: torch.Tensor, name: str) -> torch.Tensor:
tensor_list = []
num_shards = 2 if is_act_and_mul else 1
if "w13" in name:
size_n, size_k = N * 2, K
size_n, size_k = N * num_shards, K
else:
size_n, size_k = K, N
@@ -300,8 +302,9 @@ def prepare_nvfp4_moe_layer_for_marlin(
g_scales = g_scales.to(param_dtype)
tensor_list = []
num_shards = 2 if is_act_and_mul else 1
if "w13" in name:
size_n, size_k = N * 2, K
size_n, size_k = N * num_shards, K
else:
size_n, size_k = K, N