[Feat] Support non-gated MoE with Marlin, NVFP4 CUTLASS, FP8, INT8, compressed-tensors (#32257)
Signed-off-by: Tomer Natan <tbarnatan@computelab-frontend-8.nvidia.com> Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: Tomer Natan <tbarnatan@computelab-frontend-8.nvidia.com> Co-authored-by: mgoin <mgoin64@gmail.com> Co-authored-by: Tomer Natan <tbarnatan@ipp1-1429.ipp1a1.colossus.nvidia.com>
This commit is contained in:
@@ -526,7 +526,7 @@ def test_run_cutlass_moe_fp8(
|
||||
c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64)
|
||||
|
||||
activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
|
||||
activation = "silu"
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
mt.a, mt.a_scale, torch.float8_e4m3fn, per_act_token
|
||||
)
|
||||
|
||||
@@ -1079,6 +1079,86 @@ def test_fused_marlin_moe_with_bias(m):
|
||||
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
|
||||
|
||||
|
||||
@pytest.mark.flaky(reruns=2)
|
||||
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
|
||||
@pytest.mark.parametrize("m", [1, 64, 256])
|
||||
@pytest.mark.parametrize("n,k", [(1024, 1024), (2048, 2048)])
|
||||
@pytest.mark.parametrize("e,topk", [(8, 2), (64, 4)])
|
||||
def test_fused_marlin_moe_non_gated(m: int, n: int, k: int, e: int, topk: int):
|
||||
"""Test Marlin MoE with non-gated activation (relu2_no_mul).
|
||||
|
||||
Non-gated activations like relu2 don't have the gate-up projection pattern,
|
||||
so w1 has shape (e, n, k) instead of (e, 2*n, k).
|
||||
"""
|
||||
torch.cuda.manual_seed(42)
|
||||
|
||||
group_size = 16 # NVFP4 group size
|
||||
is_k_full = True
|
||||
quant_type = scalar_types.float4_e2m1f
|
||||
dtype = torch.bfloat16
|
||||
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
# Non-gated: w1 shape is (e, n, k) not (e, 2*n, k)
|
||||
w1 = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
|
||||
w1_data = MarlinMoEWeightData.make(
|
||||
w=w1,
|
||||
quant_type=quant_type,
|
||||
group_size=group_size,
|
||||
act_order=False,
|
||||
)
|
||||
|
||||
w2_data = MarlinMoEWeightData.make(
|
||||
w=w2,
|
||||
quant_type=quant_type,
|
||||
group_size=group_size,
|
||||
act_order=False,
|
||||
)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
torch_output = torch_moe(
|
||||
a,
|
||||
w1_data.w_ref,
|
||||
w2_data.w_ref,
|
||||
score,
|
||||
topk,
|
||||
activation="relu2",
|
||||
)
|
||||
|
||||
marlin_output = fused_marlin_moe(
|
||||
a,
|
||||
w1_data.qweight,
|
||||
w2_data.qweight,
|
||||
None, # bias1
|
||||
None, # bias2
|
||||
w1_data.scales,
|
||||
w2_data.scales,
|
||||
score,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
global_scale1=w1_data.global_scale,
|
||||
global_scale2=w2_data.global_scale,
|
||||
g_idx1=w1_data.g_idx,
|
||||
g_idx2=w2_data.g_idx,
|
||||
sort_indices1=w1_data.sort_indices,
|
||||
sort_indices2=w2_data.sort_indices,
|
||||
w1_zeros=w1_data.zeros,
|
||||
w2_zeros=w2_data.zeros,
|
||||
quant_type_id=quant_type.id,
|
||||
is_k_full=is_k_full,
|
||||
activation="relu2_no_mul",
|
||||
)
|
||||
|
||||
torch.testing.assert_close(marlin_output, torch_output, atol=1e-1, rtol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ep_size", [1, 2])
|
||||
def test_moe_align_block_size_opcheck(ep_size):
|
||||
num_experts = 4
|
||||
|
||||
@@ -1451,6 +1451,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# - "flashinfer-cudnn": use flashinfer cudnn GEMM backend
|
||||
# - "flashinfer-trtllm": use flashinfer trtllm GEMM backend
|
||||
# - "flashinfer-cutlass": use flashinfer cutlass GEMM backend
|
||||
# - "marlin": use marlin GEMM backend (for GPUs without native FP4 support)
|
||||
# - <none>: automatically pick an available backend
|
||||
"VLLM_NVFP4_GEMM_BACKEND": env_with_choices(
|
||||
"VLLM_NVFP4_GEMM_BACKEND",
|
||||
@@ -1460,6 +1461,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"flashinfer-trtllm",
|
||||
"flashinfer-cutlass",
|
||||
"cutlass",
|
||||
"marlin",
|
||||
],
|
||||
),
|
||||
# Controls garbage collection during CUDA graph capture.
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""CUTLASS based Fused MoE kernels."""
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
@@ -21,7 +19,10 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate,
|
||||
TopKWeightAndReduceNoOP,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
_resize_cache,
|
||||
apply_moe_activation,
|
||||
)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -33,7 +34,7 @@ def run_cutlass_moe_fp8(
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation_callable: Callable,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
w1_scale: torch.Tensor | None,
|
||||
@@ -55,6 +56,7 @@ def run_cutlass_moe_fp8(
|
||||
):
|
||||
a1q = hidden_states
|
||||
|
||||
assert not activation.endswith("_no_mul"), "Only gated activation is supported"
|
||||
assert w1_scale is not None
|
||||
assert w2_scale is not None
|
||||
assert w1.dtype == torch.float8_e4m3fn
|
||||
@@ -198,7 +200,7 @@ def run_cutlass_moe_fp8(
|
||||
per_out_ch,
|
||||
)
|
||||
|
||||
activation_callable(act_out, mm1_out)
|
||||
apply_moe_activation(activation, act_out, mm1_out)
|
||||
|
||||
a2q, a2q_scale = ops.scaled_fp8_quant(
|
||||
act_out, a2_scale, use_per_token_if_dynamic=per_act_token, output=quant_out
|
||||
@@ -288,8 +290,6 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
if expert_tokens_meta is not None:
|
||||
expert_num_tokens = expert_tokens_meta.expert_num_tokens
|
||||
|
||||
activation_callable = lambda o, i: self.activation(activation, o, i)
|
||||
|
||||
use_batched_format = (
|
||||
self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts
|
||||
)
|
||||
@@ -301,7 +301,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
w1,
|
||||
w2,
|
||||
topk_ids,
|
||||
activation_callable,
|
||||
activation,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
self.w1_scale,
|
||||
@@ -436,6 +436,7 @@ def run_cutlass_moe_fp4(
|
||||
w2_alphas: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
m: int,
|
||||
@@ -544,8 +545,7 @@ def run_cutlass_moe_fp4(
|
||||
num_topk,
|
||||
)
|
||||
c1 = _resize_cache(workspace13, (m * topk, n * 2))
|
||||
# Note: c2 workspace is no longer needed since SiLU is fused with quantization.
|
||||
# c3 reuses workspace13 after c1 is consumed.
|
||||
c2 = _resize_cache(workspace2, (m * topk, n))
|
||||
c3 = _resize_cache(workspace13, (m * topk, k))
|
||||
ops.cutlass_fp4_moe_mm(
|
||||
c1,
|
||||
@@ -559,10 +559,18 @@ def run_cutlass_moe_fp4(
|
||||
blockscale_offsets[:-1],
|
||||
)
|
||||
del rep_a_fp4, rep_a_blockscale
|
||||
# Fused SiLU+Mul+NVFP4 quantization
|
||||
int_fp4, int_blockscale = ops.silu_and_mul_scaled_fp4_experts_quant(
|
||||
c1, a2_gscale, expert_offsets, blockscale_offsets, num_topk
|
||||
)
|
||||
if activation == "silu":
|
||||
# Fused SiLU+Mul+NVFP4 quantization
|
||||
# Note: c2 workspace is no longer needed since SiLU is fused with quantization.
|
||||
# c3 reuses workspace13 after c1 is consumed.
|
||||
int_fp4, int_blockscale = ops.silu_and_mul_scaled_fp4_experts_quant(
|
||||
c1, a2_gscale, expert_offsets, blockscale_offsets, num_topk
|
||||
)
|
||||
else:
|
||||
apply_moe_activation(activation, c2, c1)
|
||||
int_fp4, int_blockscale = ops.scaled_fp4_experts_quant(
|
||||
c2, a2_gscale, expert_offsets, blockscale_offsets, num_topk
|
||||
)
|
||||
|
||||
ops.cutlass_fp4_moe_mm(
|
||||
c3,
|
||||
@@ -693,6 +701,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
w2_alphas=self.g2_alphas,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
workspace13=workspace13,
|
||||
workspace2=workspace2,
|
||||
m=m,
|
||||
@@ -711,7 +720,7 @@ def run_cutlass_moe_w4a8_fp8(
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation_callable: Callable,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
w1_scale: torch.Tensor | None,
|
||||
@@ -815,7 +824,7 @@ def run_cutlass_moe_w4a8_fp8(
|
||||
s_strides1,
|
||||
)
|
||||
|
||||
activation_callable(act_out, mm1_out)
|
||||
apply_moe_activation(activation, act_out, mm1_out)
|
||||
|
||||
a2q, a2q_scale = ops.scaled_fp8_quant(
|
||||
act_out, a2_scale, use_per_token_if_dynamic=per_act_token, output=quant_out
|
||||
@@ -936,7 +945,6 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
assert self.w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
|
||||
|
||||
expert_num_tokens = None
|
||||
activation_callable = lambda o, i: self.activation(activation, o, i)
|
||||
|
||||
use_batched_format = (
|
||||
self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts
|
||||
@@ -951,7 +959,7 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
w1,
|
||||
w2,
|
||||
topk_ids,
|
||||
activation_callable,
|
||||
activation,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
self.w1_scale,
|
||||
|
||||
@@ -17,7 +17,11 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate,
|
||||
TopKWeightAndReduceNoOP,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache, disable_inplace
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
_resize_cache,
|
||||
apply_moe_activation,
|
||||
disable_inplace,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
marlin_make_workspace_new,
|
||||
marlin_moe_intermediate_size,
|
||||
@@ -27,21 +31,6 @@ from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
|
||||
def default_activation_func(
|
||||
activation: str, output: torch.Tensor, input: torch.Tensor
|
||||
) -> None:
|
||||
if activation == "silu":
|
||||
torch.ops._C.silu_and_mul(output, input)
|
||||
elif activation == "swigluoai":
|
||||
# alpha = 1.702, limit = 7.0
|
||||
torch.ops._C.swigluoai_and_mul(output, input)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported activation: {activation}. "
|
||||
"Only silu and swigluoai activations are supported."
|
||||
)
|
||||
|
||||
|
||||
def _fused_marlin_moe(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
@@ -62,7 +51,7 @@ def _fused_marlin_moe(
|
||||
activation: str = "silu",
|
||||
activation_func: Callable[
|
||||
[str, torch.Tensor, torch.Tensor], None
|
||||
] = default_activation_func,
|
||||
] = apply_moe_activation,
|
||||
input_global_scale1: torch.Tensor | None = None,
|
||||
input_global_scale2: torch.Tensor | None = None,
|
||||
global_scale1: torch.Tensor | None = None,
|
||||
@@ -83,13 +72,13 @@ def _fused_marlin_moe(
|
||||
assert hidden_states.ndim == 2
|
||||
M, K = hidden_states.size()
|
||||
N = marlin_moe_intermediate_size(w1, w2)
|
||||
|
||||
w13_num_shards = 1 if "no_mul" in activation else 2
|
||||
if workspace is None:
|
||||
workspace = marlin_make_workspace_new(hidden_states.device, 4)
|
||||
|
||||
if intermediate_cache13 is None:
|
||||
intermediate_cache13 = torch.empty(
|
||||
(M * num_topk * max(2 * N, K),),
|
||||
(M * num_topk * max(w13_num_shards * N, K),),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
@@ -101,7 +90,9 @@ def _fused_marlin_moe(
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
|
||||
intermediate_cache1 = _resize_cache(intermediate_cache13, (M * num_topk, 2 * N))
|
||||
intermediate_cache1 = _resize_cache(
|
||||
intermediate_cache13, (M * num_topk, w13_num_shards * N)
|
||||
)
|
||||
|
||||
intermediate_cache3 = _resize_cache(intermediate_cache13, (M * num_topk, K))
|
||||
|
||||
@@ -137,16 +128,17 @@ def _fused_marlin_moe(
|
||||
mul_topk_weights=apply_router_weight_on_input,
|
||||
b_q_type=quant_type,
|
||||
size_m=M,
|
||||
size_n=2 * N,
|
||||
size_n=w13_num_shards * N,
|
||||
size_k=K,
|
||||
is_k_full=is_k_full,
|
||||
use_atomic_add=False,
|
||||
use_fp32_reduce=True,
|
||||
is_zp_float=False,
|
||||
)
|
||||
|
||||
activation_func(
|
||||
activation, intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
|
||||
activation,
|
||||
intermediate_cache2,
|
||||
intermediate_cache1.view(-1, w13_num_shards * N),
|
||||
)
|
||||
|
||||
if output is None:
|
||||
@@ -216,7 +208,7 @@ def fused_marlin_moe(
|
||||
activation: str = "silu",
|
||||
activation_func: Callable[
|
||||
[str, torch.Tensor, torch.Tensor], None
|
||||
] = default_activation_func,
|
||||
] = apply_moe_activation,
|
||||
moe_sum: Callable[[torch.Tensor, torch.Tensor], None] | None = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
input_global_scale1: torch.Tensor | None = None,
|
||||
|
||||
@@ -619,30 +619,11 @@ class FusedMoE(CustomOp):
|
||||
# for heuristic purposes, so it must be initialized first.
|
||||
self.quant_method: FusedMoEMethodBase = _get_quant_method()
|
||||
|
||||
if not self.moe_config.is_act_and_mul:
|
||||
# Avoid circular import
|
||||
from vllm.model_executor.layers.quantization.modelopt import (
|
||||
ModelOptFp8MoEMethod,
|
||||
ModelOptNvFp4FusedMoE,
|
||||
if not self.moe_config.is_act_and_mul and not current_platform.is_cuda():
|
||||
raise NotImplementedError(
|
||||
"is_act_and_mul=False is supported only for CUDA for now"
|
||||
)
|
||||
|
||||
if not isinstance(
|
||||
self.quant_method,
|
||||
(
|
||||
UnquantizedFusedMoEMethod,
|
||||
ModelOptFp8MoEMethod,
|
||||
ModelOptNvFp4FusedMoE,
|
||||
),
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"is_act_and_mul=False is supported only for unquantized "
|
||||
", ModelOpt FP8, and ModelOpt NvFp4 checkpoints"
|
||||
)
|
||||
if not current_platform.is_cuda():
|
||||
raise NotImplementedError(
|
||||
"is_act_and_mul=False is supported only for CUDA for now"
|
||||
)
|
||||
|
||||
if self.enable_eplb and not self.quant_method.supports_eplb:
|
||||
# TODO: Add support for additional quantization methods.
|
||||
# The implementation for other quantization methods does not
|
||||
|
||||
@@ -52,7 +52,7 @@ def select_fp8_moe_backend(
|
||||
block_quant: bool,
|
||||
tp_size: int,
|
||||
with_lora_support: bool,
|
||||
is_act_and_mul: bool = True,
|
||||
is_act_and_mul: bool,
|
||||
allow_vllm_cutlass: bool = False,
|
||||
) -> Fp8MoeBackend:
|
||||
"""
|
||||
@@ -128,7 +128,7 @@ def select_fp8_moe_backend(
|
||||
scope="local",
|
||||
)
|
||||
|
||||
if use_deep_gemm and moe_use_deep_gemm and block_quant:
|
||||
if use_deep_gemm and moe_use_deep_gemm and block_quant and is_act_and_mul:
|
||||
if not has_deep_gemm():
|
||||
logger.warning_once(
|
||||
"DeepGEMM backend requested but not available.", scope="local"
|
||||
@@ -141,7 +141,12 @@ def select_fp8_moe_backend(
|
||||
logger.info_once(_make_log_backend("ROCm AITER"), scope="local")
|
||||
return Fp8MoeBackend.AITER
|
||||
|
||||
if allow_vllm_cutlass and not block_quant and cutlass_group_gemm_supported():
|
||||
if (
|
||||
allow_vllm_cutlass
|
||||
and not block_quant
|
||||
and cutlass_group_gemm_supported()
|
||||
and is_act_and_mul
|
||||
):
|
||||
logger.info_once(_make_log_backend("vLLM CUTLASS"), scope="local")
|
||||
return Fp8MoeBackend.VLLM_CUTLASS
|
||||
|
||||
|
||||
@@ -178,6 +178,7 @@ def convert_to_nvfp4_moe_kernel_format(
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
w2_scale_2=w2_scale_2,
|
||||
is_act_and_mul=is_act_and_mul,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown NvFp4 backend for MoE: {nvfp4_backend}")
|
||||
|
||||
@@ -367,7 +367,8 @@ def apply_moe_activation(
|
||||
elif activation == GELU_NO_MUL:
|
||||
output.copy_(F.gelu(input))
|
||||
elif activation == RELU2_NO_MUL:
|
||||
torch.square(F.relu(input), out=output)
|
||||
F.relu(input, inplace=True)
|
||||
torch.square(input, out=output)
|
||||
else:
|
||||
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
|
||||
|
||||
|
||||
@@ -764,8 +764,6 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert layer.activation == "silu", "Only SiLU activation is supported."
|
||||
|
||||
topk_weights, topk_ids = router.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
|
||||
@@ -370,12 +370,6 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer_name: str | None = None,
|
||||
use_marlin: bool = False,
|
||||
):
|
||||
if not moe.is_act_and_mul:
|
||||
raise ValueError(
|
||||
"CompressedTensorsW4A4Nvfp4MoEMethod does not yet "
|
||||
"support non gated MoE models."
|
||||
)
|
||||
|
||||
super().__init__(moe)
|
||||
self.group_size = 16
|
||||
if use_marlin:
|
||||
@@ -388,6 +382,16 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
)
|
||||
else:
|
||||
self.nvfp4_backend = select_nvfp4_moe_backend()
|
||||
|
||||
# TODO: move this type of check into the oracle.
|
||||
if not self.moe.is_act_and_mul and self.nvfp4_backend not in [
|
||||
NvFp4MoeBackend.FLASHINFER_CUTLASS,
|
||||
NvFp4MoeBackend.MARLIN,
|
||||
]:
|
||||
raise NotImplementedError(
|
||||
"Non-gated activations are only supported by FlashInfer "
|
||||
f"CUTLASS and Marlin NvFP4 MoE backends, not {self.nvfp4_backend}."
|
||||
)
|
||||
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
|
||||
self.nvfp4_backend
|
||||
)
|
||||
@@ -404,11 +408,12 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
):
|
||||
layer.num_experts = num_experts
|
||||
layer.params_dtype = params_dtype
|
||||
w13_num_shards = 2 if self.moe.is_act_and_mul else 1
|
||||
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
w13_num_shards * intermediate_size_per_partition,
|
||||
# 2 fp4 items are packed in the input dimension
|
||||
hidden_size // 2,
|
||||
requires_grad=False,
|
||||
@@ -436,7 +441,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
w13_num_shards * intermediate_size_per_partition,
|
||||
# 2 fp4 items are packed in the input dimension
|
||||
hidden_size // self.group_size,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
@@ -467,7 +472,8 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
# Weight Global Scales
|
||||
w13_weight_scale_2 = torch.nn.Parameter(
|
||||
torch.empty(num_experts, 2, dtype=torch.float32), requires_grad=False
|
||||
torch.empty(num_experts, w13_num_shards, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_global_scale", w13_weight_scale_2)
|
||||
extra_weight_attrs.update(
|
||||
@@ -486,7 +492,8 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
# Input Global Scales
|
||||
w13_input_scale = torch.nn.Parameter(
|
||||
torch.empty(num_experts, 2, dtype=torch.float32), requires_grad=False
|
||||
torch.empty(num_experts, w13_num_shards, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_input_global_scale", w13_input_scale)
|
||||
extra_weight_attrs.update(
|
||||
@@ -640,6 +647,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
x=x,
|
||||
router_logits=router_logits,
|
||||
top_k=layer.top_k,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
num_expert_group=layer.num_expert_group,
|
||||
topk_group=layer.topk_group,
|
||||
@@ -666,6 +674,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
topk_ids=topk_ids,
|
||||
topk_weights=topk_weights,
|
||||
top_k=layer.top_k,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
)
|
||||
else:
|
||||
@@ -722,6 +731,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
block_quant=self.block_quant,
|
||||
tp_size=moe.tp_size,
|
||||
with_lora_support=moe.is_lora_enabled,
|
||||
is_act_and_mul=moe.is_act_and_mul,
|
||||
# TODO(rob): enable selecting this externally.
|
||||
allow_vllm_cutlass=True,
|
||||
)
|
||||
@@ -760,6 +770,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer.weight_block_size = None
|
||||
|
||||
params_dtype = torch.float8_e4m3fn
|
||||
w13_num_shards = 2 if self.moe.is_act_and_mul else 1
|
||||
|
||||
if self.block_quant:
|
||||
assert self.weight_block_size is not None
|
||||
@@ -791,7 +802,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
w13_num_shards * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
@@ -814,10 +825,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
# WEIGHT_SCALES
|
||||
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
|
||||
# Allocate 2 scales for w1 and w3 respectively.
|
||||
# They are combined to a single scale after weight loading.
|
||||
# For gated MoE, allocate 2 scales for w1 and w3 respectively.
|
||||
# They will be combined to a single scale after weight loading.
|
||||
# For non-gated MoE, allocate 1 scale for w13.
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
|
||||
torch.ones(num_experts, w13_num_shards, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
@@ -835,7 +848,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
w13_num_shards * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
@@ -858,7 +871,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(
|
||||
num_experts,
|
||||
2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
|
||||
w13_num_shards
|
||||
* ((intermediate_size_per_partition + block_n - 1) // block_n),
|
||||
(hidden_size + block_k - 1) // block_k,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
@@ -930,11 +944,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
# Per-tensor kernels use a single scale, for W13, but on disk there
|
||||
# is a separate scale for W1 and W3. Requantize with the max scale.
|
||||
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
|
||||
process_fp8_weight_tensor_strategy_moe(
|
||||
w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
|
||||
w13,
|
||||
w13_scale,
|
||||
shard_size=layer.intermediate_size_per_partition,
|
||||
num_experts=layer.num_local_experts,
|
||||
is_act_and_mul=self.moe.is_act_and_mul,
|
||||
)
|
||||
|
||||
w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
|
||||
@@ -1166,12 +1181,13 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
params_dtype = torch.int8
|
||||
w13_num_shards = 2 if self.moe.is_act_and_mul else 1
|
||||
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
w13_num_shards * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
@@ -1196,7 +1212,10 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(
|
||||
num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
|
||||
num_experts,
|
||||
w13_num_shards * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
@@ -1296,6 +1315,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full")
|
||||
w13_num_shards = 2 if self.moe.is_act_and_mul else 1
|
||||
|
||||
# Will transpose the loaded weight along the
|
||||
# intermediate and hidden dim sizes. Will
|
||||
@@ -1307,7 +1327,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size // self.packed_factor,
|
||||
2 * intermediate_size_per_partition,
|
||||
w13_num_shards * intermediate_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
@@ -1352,7 +1372,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
torch.ones(
|
||||
num_experts,
|
||||
num_groups_w13,
|
||||
2 * intermediate_size_per_partition,
|
||||
w13_num_shards * intermediate_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
@@ -1600,10 +1620,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert layer.activation == "silu", (
|
||||
f"{layer.activation} not supported for Marlin MoE."
|
||||
)
|
||||
|
||||
topk_weights, topk_ids = router.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
@@ -1625,6 +1641,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
quant_type_id=self.quant_type.id,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
activation=layer.activation,
|
||||
expert_map=layer.expert_map,
|
||||
g_idx1=layer.w13_weight_g_idx,
|
||||
g_idx2=layer.w2_weight_g_idx,
|
||||
@@ -1675,11 +1692,12 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
extra_weight_attrs.update(
|
||||
{"is_transposed": True, "quant_method": self.strategy}
|
||||
)
|
||||
w13_num_shards = 2 if self.moe.is_act_and_mul else 1
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size // self.packed_factor,
|
||||
2 * intermediate_size_per_partition,
|
||||
w13_num_shards * intermediate_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
@@ -1712,7 +1730,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
torch.ones(
|
||||
num_experts,
|
||||
num_groups_w13,
|
||||
2 * intermediate_size_per_partition,
|
||||
w13_num_shards * intermediate_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
|
||||
@@ -637,6 +637,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
block_quant=self.block_quant,
|
||||
tp_size=layer.moe_parallel_config.tp_size,
|
||||
with_lora_support=self.moe.is_lora_enabled,
|
||||
is_act_and_mul=self.moe.is_act_and_mul,
|
||||
)
|
||||
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
|
||||
|
||||
@@ -900,8 +900,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert layer.activation == "silu", "Only SiLU activation is supported."
|
||||
|
||||
topk_weights, topk_ids = router.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
|
||||
@@ -733,6 +733,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
block_quant=False,
|
||||
tp_size=moe_config.moe_parallel_config.tp_size,
|
||||
with_lora_support=self.moe.is_lora_enabled,
|
||||
is_act_and_mul=self.moe.is_act_and_mul,
|
||||
)
|
||||
self.kernel: mk.FusedMoEModularKernel | None = None
|
||||
|
||||
@@ -789,15 +790,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
|
||||
if self.moe.is_act_and_mul:
|
||||
w13_up_dim = 2 * intermediate_size_per_partition
|
||||
else:
|
||||
w13_up_dim = intermediate_size_per_partition
|
||||
w13_num_shards = 2 if self.moe.is_act_and_mul else 1
|
||||
|
||||
w13_weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
num_experts,
|
||||
w13_up_dim,
|
||||
w13_num_shards * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=weight_dtype,
|
||||
),
|
||||
@@ -826,7 +824,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
# For non-gated MoE, allocate 1 scale for w13.
|
||||
w13_weight_scale = PerTensorScaleParameter(
|
||||
data=torch.full(
|
||||
(num_experts, 2 if self.moe.is_act_and_mul else 1),
|
||||
(num_experts, w13_num_shards),
|
||||
1.0,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
@@ -1132,6 +1130,9 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass":
|
||||
self.backend = "cutlass"
|
||||
assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"
|
||||
elif envs.VLLM_NVFP4_GEMM_BACKEND == "marlin":
|
||||
self.backend = "marlin"
|
||||
assert is_fp4_marlin_supported(), f"Marlin is required for {self.backend}"
|
||||
|
||||
if self.backend == "none":
|
||||
raise ValueError(
|
||||
@@ -1337,13 +1338,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
self.quant_config = quant_config
|
||||
self.nvfp4_backend = select_nvfp4_moe_backend()
|
||||
# TODO: move this type of check into the oracle.
|
||||
if (
|
||||
not self.moe.is_act_and_mul
|
||||
and not self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS
|
||||
):
|
||||
if not self.moe.is_act_and_mul and self.nvfp4_backend not in [
|
||||
NvFp4MoeBackend.FLASHINFER_CUTLASS,
|
||||
NvFp4MoeBackend.MARLIN,
|
||||
]:
|
||||
raise NotImplementedError(
|
||||
"Non-gated activations are only supported by FlashInfer "
|
||||
"CUTLASS NvFP4 MoE backend."
|
||||
f"CUTLASS and Marlin NvFP4 MoE backends, not {self.nvfp4_backend}."
|
||||
)
|
||||
|
||||
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
|
||||
@@ -1409,11 +1410,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
weight_scale_dtype = torch.float8_e4m3fn
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
global_num_experts = extra_weight_attrs.get("global_num_experts")
|
||||
w13_num_shards = 2 if self.moe.is_act_and_mul else 1
|
||||
# GEMM 1
|
||||
w13_weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
num_experts,
|
||||
(2 if self.moe.is_act_and_mul else 1) * intermediate_size_per_partition,
|
||||
w13_num_shards * intermediate_size_per_partition,
|
||||
# 2 fp4 items are packed in the input dimension
|
||||
hidden_size // 2,
|
||||
dtype=weight_dtype,
|
||||
@@ -1442,7 +1444,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
w13_weight_scale = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
num_experts,
|
||||
(2 if self.moe.is_act_and_mul else 1) * intermediate_size_per_partition,
|
||||
w13_num_shards * intermediate_size_per_partition,
|
||||
# 2 fp4 items are packed in the input dimension
|
||||
hidden_size // self.quant_config.group_size,
|
||||
dtype=weight_scale_dtype,
|
||||
@@ -1472,9 +1474,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
)
|
||||
|
||||
w13_weight_scale_2 = PerTensorScaleParameter(
|
||||
data=torch.empty(
|
||||
num_experts, 2 if self.moe.is_act_and_mul else 1, dtype=torch.float32
|
||||
),
|
||||
data=torch.empty(num_experts, w13_num_shards, dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
|
||||
@@ -1495,7 +1495,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
w13_input_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(
|
||||
global_sf_num_experts,
|
||||
2 if self.moe.is_act_and_mul else 1,
|
||||
w13_num_shards,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
weight_loader=weight_loader,
|
||||
@@ -1616,6 +1616,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
x=x,
|
||||
router_logits=router_logits,
|
||||
top_k=layer.top_k,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
num_expert_group=layer.num_expert_group,
|
||||
topk_group=layer.topk_group,
|
||||
@@ -1642,6 +1643,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
topk_ids=topk_ids,
|
||||
topk_weights=topk_weights,
|
||||
top_k=layer.top_k,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -255,6 +255,7 @@ def flashinfer_trtllm_fp4_moe(
|
||||
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
num_expert_group: int | None,
|
||||
topk_group: int | None,
|
||||
@@ -269,6 +270,7 @@ def flashinfer_trtllm_fp4_moe(
|
||||
x: Input tensor
|
||||
router_logits: Router logits for expert selection
|
||||
top_k: Number of experts to select per token
|
||||
activation: Activation function to use
|
||||
global_num_experts: Total number of experts across all ranks
|
||||
num_expert_group: Number of expert groups (for grouped routing)
|
||||
topk_group: Top-k within each group
|
||||
@@ -282,6 +284,12 @@ def flashinfer_trtllm_fp4_moe(
|
||||
|
||||
from vllm.model_executor.models.llama4 import Llama4MoE
|
||||
|
||||
# https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2404
|
||||
assert activation == "silu", (
|
||||
"Only SiLU activation is supported for FlashInfer TRTLLM FP4 MoE. "
|
||||
f"{activation} found instead."
|
||||
)
|
||||
|
||||
# Quantize input to FP4
|
||||
if isinstance(x, tuple):
|
||||
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
|
||||
@@ -352,6 +360,7 @@ def flashinfer_trtllm_fp4_routed_moe(
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
top_k: int,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -364,6 +373,7 @@ def flashinfer_trtllm_fp4_routed_moe(
|
||||
x: Input tensor
|
||||
topk_ids: Ids of selected experts
|
||||
top_k: Number of experts to select per token
|
||||
activation: Activation function to use
|
||||
global_num_experts: Total number of experts across all ranks
|
||||
|
||||
Returns:
|
||||
@@ -371,6 +381,12 @@ def flashinfer_trtllm_fp4_routed_moe(
|
||||
"""
|
||||
import flashinfer
|
||||
|
||||
# https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2535
|
||||
assert activation == "silu", (
|
||||
"Only SiLU activation is supported for FlashInfer TRTLLM FP4 Routed MoE. "
|
||||
f"{activation} found instead."
|
||||
)
|
||||
|
||||
# Pack top k ids and expert weights into a single int32 tensor, as
|
||||
# required by TRT-LLM
|
||||
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
|
||||
|
||||
@@ -233,8 +233,6 @@ def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
|
||||
intermediate_size_per_partition = layer.intermediate_size_per_partition
|
||||
# apply_router_weight_on_input is not supported for moe marlin
|
||||
supports_router_weight = not layer.apply_router_weight_on_input
|
||||
# moe marlin requires the activation to be silu
|
||||
supports_activation = layer.activation == "silu"
|
||||
|
||||
# gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
|
||||
# down: (n, k) = (hidden_size, intermediate_size_per_partition)
|
||||
@@ -244,12 +242,7 @@ def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
|
||||
and intermediate_size_per_partition % max(64, group_size) == 0
|
||||
)
|
||||
supports_group_size = group_size in [-1, 32, 64, 128]
|
||||
return (
|
||||
supports_shape
|
||||
and supports_group_size
|
||||
and supports_router_weight
|
||||
and supports_activation
|
||||
)
|
||||
return supports_shape and supports_group_size and supports_router_weight
|
||||
|
||||
|
||||
def marlin_moe_intermediate_size(w1_packed: torch.Tensor, w2_packed: torch.Tensor):
|
||||
|
||||
@@ -235,6 +235,7 @@ def prepare_nvfp4_moe_layer_for_marlin(
|
||||
w2: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
w2_scale_2: torch.Tensor,
|
||||
is_act_and_mul: bool,
|
||||
) -> tuple[
|
||||
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
|
||||
]:
|
||||
@@ -266,8 +267,9 @@ def prepare_nvfp4_moe_layer_for_marlin(
|
||||
# Repack weights to marlin format
|
||||
def repack_weight(weight: torch.Tensor, name: str) -> torch.Tensor:
|
||||
tensor_list = []
|
||||
num_shards = 2 if is_act_and_mul else 1
|
||||
if "w13" in name:
|
||||
size_n, size_k = N * 2, K
|
||||
size_n, size_k = N * num_shards, K
|
||||
else:
|
||||
size_n, size_k = K, N
|
||||
|
||||
@@ -300,8 +302,9 @@ def prepare_nvfp4_moe_layer_for_marlin(
|
||||
g_scales = g_scales.to(param_dtype)
|
||||
|
||||
tensor_list = []
|
||||
num_shards = 2 if is_act_and_mul else 1
|
||||
if "w13" in name:
|
||||
size_n, size_k = N * 2, K
|
||||
size_n, size_k = N * num_shards, K
|
||||
else:
|
||||
size_n, size_k = K, N
|
||||
|
||||
|
||||
Reference in New Issue
Block a user