[MoE Refactor] Mxfp4 oracle rebased (#37128)
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -88,8 +88,8 @@ To be used with a particular `FusedMoEPrepareAndFinalizeModular` subclass, MoE k
|
||||
| flashinfer | standard | nvfp4,</br>fp8 | T | <sup>5</sup> | N | Y | [`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
|
||||
| gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
|
||||
| marlin | standard,</br>batched | <sup>3</sup> / N/A | <sup>3</sup> / N/A | silu,</br>swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],</br>[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],</br>[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
|
||||
| trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
|
||||
| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] |
|
||||
| trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmMxfp4ExpertsMonolithic`][vllm.model_executor.layers.fused_moe.experts.trtllm_mxfp4_moe.TrtLlmMxfp4ExpertsMonolithic],</br>[`TrtLlmMxfp4ExpertsModular`][vllm.model_executor.layers.fused_moe.experts.trtllm_mxfp4_moe.TrtLlmMxfp4ExpertsModular],</br>[`TrtLlmNvFp4ExpertsMonolithic`][vllm.model_executor.layers.fused_moe.experts.trtllm_nvfp4_moe.TrtLlmNvFp4ExpertsMonolithic],</br>[`TrtLlmNvfp4ExpertsModular`][vllm.model_executor.layers.fused_moe.experts.trtllm_nvfp4_moe.TrtLlmNvFp4ExpertsModular] |
|
||||
| rocm aiter moe | standard | mxfp4,</br>fp8 | G(32),G(128),A,T | silu, gelu,</br>swigluoai | Y | N | `rocm_aiter_fused_experts`,</br>`AiterExperts` |
|
||||
| cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] |
|
||||
| naive batched<sup>4</sup> | batched | int8,</br>fp8 | G,A,T | silu, gelu | <sup>6</sup> | Y | [`NaiveBatchedExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.NaiveBatchedExperts] |
|
||||
|
||||
|
||||
@@ -84,7 +84,10 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
|
||||
# TODO: remove this after finishing migration from envs to model kwargs
|
||||
if model_name == "openai/gpt-oss-20b":
|
||||
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
|
||||
from .common import is_blackwell
|
||||
|
||||
if is_blackwell():
|
||||
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
|
||||
|
||||
# Disable, compile cache to make sure custom passes run.
|
||||
# Otherwise, we can't verify fusion happened through the logs.
|
||||
|
||||
@@ -6,6 +6,7 @@ import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.import_utils import has_triton_kernels
|
||||
|
||||
if not has_triton_kernels():
|
||||
@@ -14,6 +15,7 @@ if not has_triton_kernels():
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
|
||||
import triton_kernels.swiglu
|
||||
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
|
||||
from triton_kernels.numerics import InFlexData
|
||||
@@ -303,6 +305,12 @@ def test_equiv(num_token, a_dtype, w_dtype, tp, workspace_init):
|
||||
pc2,
|
||||
) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8)
|
||||
|
||||
if current_platform.is_device_capability_family(100):
|
||||
constraints = {
|
||||
"is_persistent": True,
|
||||
}
|
||||
opt_flags.update_opt_flags_constraints(constraints)
|
||||
|
||||
if a_dtype == "bf16" and w_dtype == "mx4":
|
||||
quant_config = mxfp4_w4a16_moe_quant_config(
|
||||
w1_scale=pc1,
|
||||
|
||||
@@ -82,7 +82,7 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
|
||||
model_case.model_id,
|
||||
tensor_parallel_size=model_case.tp,
|
||||
load_format="dummy",
|
||||
cudagraph_capture_sizes=[16],
|
||||
compilation_config={"cudagraph_capture_sizes": [16]},
|
||||
) as llm:
|
||||
# Disabled as check_model is broken: https://github.com/vllm-project/vllm/pull/18465#issuecomment-3329880562
|
||||
# def check_model(model):
|
||||
|
||||
@@ -17,89 +17,6 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.mxfp4 import (
|
||||
Mxfp4Backend,
|
||||
Mxfp4MoEMethod,
|
||||
)
|
||||
|
||||
|
||||
def _make_mock_moe_config(ep_size: int = 1) -> MagicMock:
|
||||
"""Create a mock FusedMoEConfig with the given EP size."""
|
||||
parallel_config = MagicMock()
|
||||
parallel_config.ep_size = ep_size
|
||||
|
||||
moe_config = MagicMock()
|
||||
moe_config.ep_size = ep_size
|
||||
moe_config.is_lora_enabled = False
|
||||
moe_config.moe_parallel_config = parallel_config
|
||||
return moe_config
|
||||
|
||||
|
||||
class TestMxfp4TritonIsMonolithic:
|
||||
"""Verify that is_monolithic is always True for the TRITON backend,
|
||||
regardless of EP size, since triton_kernel_moe_forward now handles
|
||||
expert_map remapping internally."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"backend,ep_size,expected_monolithic",
|
||||
[
|
||||
# TRITON is always monolithic (handles EP via expert_map remapping)
|
||||
(Mxfp4Backend.TRITON, 1, True),
|
||||
(Mxfp4Backend.TRITON, 2, True),
|
||||
(Mxfp4Backend.TRITON, 4, True),
|
||||
# SM100 backends are always monolithic
|
||||
(Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM, 1, True),
|
||||
(Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM, 2, True),
|
||||
(Mxfp4Backend.SM100_FI_MXFP4_BF16, 1, True),
|
||||
(Mxfp4Backend.SM100_FI_MXFP4_BF16, 2, True),
|
||||
# MARLIN is never monolithic
|
||||
(Mxfp4Backend.MARLIN, 1, False),
|
||||
(Mxfp4Backend.MARLIN, 2, False),
|
||||
],
|
||||
ids=[
|
||||
"triton-no-ep",
|
||||
"triton-ep2",
|
||||
"triton-ep4",
|
||||
"sm100-trtllm-no-ep",
|
||||
"sm100-trtllm-ep2",
|
||||
"sm100-bf16-no-ep",
|
||||
"sm100-bf16-ep2",
|
||||
"marlin-no-ep",
|
||||
"marlin-ep2",
|
||||
],
|
||||
)
|
||||
@patch(
|
||||
"vllm.model_executor.layers.quantization.mxfp4.get_mxfp4_backend",
|
||||
)
|
||||
@patch(
|
||||
"vllm.model_executor.layers.quantization.mxfp4.get_current_vllm_config",
|
||||
)
|
||||
def test_is_monolithic(
|
||||
self,
|
||||
mock_get_config,
|
||||
mock_get_backend,
|
||||
backend,
|
||||
ep_size,
|
||||
expected_monolithic,
|
||||
):
|
||||
"""is_monolithic should be True for TRITON regardless of EP size."""
|
||||
mock_get_backend.return_value = backend
|
||||
|
||||
mock_compilation_config = MagicMock()
|
||||
mock_compilation_config.max_cudagraph_capture_size = 1024
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.compilation_config = mock_compilation_config
|
||||
mock_get_config.return_value = mock_vllm_config
|
||||
|
||||
moe_config = _make_mock_moe_config(ep_size=ep_size)
|
||||
method = Mxfp4MoEMethod(moe_config)
|
||||
|
||||
assert method.is_monolithic == expected_monolithic, (
|
||||
f"Expected is_monolithic={expected_monolithic} for "
|
||||
f"backend={backend.name}, ep_size={ep_size}, "
|
||||
f"but got {method.is_monolithic}."
|
||||
)
|
||||
|
||||
|
||||
class TestTritonMoeForwardExpertMap:
|
||||
"""Test that triton_kernel_moe_forward applies expert_map remapping
|
||||
|
||||
352
vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py
Normal file
352
vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py
Normal file
@@ -0,0 +1,352 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig,
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceNoOP,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kMxfp4Static,
|
||||
kMxfp8Dynamic,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer
|
||||
|
||||
|
||||
class TrtLlmMxfp4ExpertsBase:
|
||||
"""
|
||||
MXFP4 TRTLLM-Gen MoE kernels. Shared base for modular and monolithic.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
):
|
||||
# NOTE: FusedMoEExperts.__init__ is called by the concrete subclass
|
||||
# (Monolithic/Modular) via MRO, not here, to avoid mypy issues with
|
||||
# multiple inheritance. This matches the NvFP4 expert pattern.
|
||||
self.moe_config = moe_config
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.routing_method_type = moe_config.routing_method
|
||||
self.topk = moe_config.experts_per_token
|
||||
self.intermediate_size_per_partition = (
|
||||
moe_config.intermediate_size_per_partition
|
||||
)
|
||||
self.hidden_dim = moe_config.hidden_dim
|
||||
self.local_num_experts = moe_config.num_local_experts
|
||||
self.ep_rank = moe_config.moe_parallel_config.ep_rank
|
||||
|
||||
# MXFP4-specific TRTLLM parameters
|
||||
device = torch.accelerator.current_device_index()
|
||||
self.gemm1_alpha = torch.tensor(
|
||||
[1.702] * self.local_num_experts,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
self.gemm1_beta = torch.tensor(
|
||||
[1.0] * self.local_num_experts,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
self.gemm1_clamp_limit = torch.tensor(
|
||||
[7.0] * self.local_num_experts,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
self.max_capture_size = (
|
||||
get_current_vllm_config().compilation_config.max_cudagraph_capture_size
|
||||
)
|
||||
|
||||
# P1-5 fix: use public quant_dtype property instead of private _a1
|
||||
self.use_mxfp8_input = quant_config.quant_dtype == "mxfp8"
|
||||
|
||||
@staticmethod
|
||||
def _supports_current_device() -> bool:
|
||||
p = current_platform
|
||||
return p.is_cuda() and p.is_device_capability_family(100) and has_flashinfer()
|
||||
|
||||
@staticmethod
|
||||
def _supports_no_act_and_mul() -> bool:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _supports_quant_scheme(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
SUPPORTED_W_A = [
|
||||
(kMxfp4Static, None),
|
||||
(kMxfp4Static, kMxfp8Dynamic),
|
||||
]
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
return activation == MoEActivation.SWIGLUOAI
|
||||
|
||||
@staticmethod
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return False
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def expects_unquantized_inputs(self) -> bool:
|
||||
# Expert handles MXFP8 quantization internally if needed
|
||||
return True
|
||||
|
||||
|
||||
class TrtLlmMxfp4ExpertsMonolithic(
|
||||
TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsMonolithic
|
||||
):
|
||||
"""
|
||||
Monolithic version of the MXFP4 TRTLLM kernel (router + experts).
|
||||
Wraps flashinfer.trtllm_fp4_block_scale_moe().
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(
|
||||
moe_parallel_config: FusedMoEParallelConfig,
|
||||
) -> bool:
|
||||
return (
|
||||
not moe_parallel_config.use_all2all_kernels
|
||||
and not moe_parallel_config.enable_eplb
|
||||
and moe_parallel_config.dp_size <= 1
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _supports_routing_method(
|
||||
routing_method: RoutingMethodType,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
return routing_method in [
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _supports_router_logits_dtype(
|
||||
router_logits_dtype: torch.dtype | None,
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
# Kernel converts to bfloat16 internally
|
||||
return True
|
||||
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
# grouped topk + fused topk bias parameters
|
||||
num_expert_group: int | None = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
routed_scaling_factor: float | None = None,
|
||||
topk_group: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
from flashinfer import trtllm_fp4_block_scale_moe
|
||||
|
||||
# Handle input quantization
|
||||
if self.use_mxfp8_input:
|
||||
from flashinfer import mxfp8_quantize
|
||||
|
||||
x_quant, x_scale = mxfp8_quantize(
|
||||
hidden_states,
|
||||
is_sf_swizzled_layout=False,
|
||||
alignment=256,
|
||||
)
|
||||
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(
|
||||
*hidden_states.shape[:-1], -1
|
||||
)
|
||||
else:
|
||||
assert hidden_states.dtype == torch.bfloat16
|
||||
x_quant = hidden_states
|
||||
x_scale = None
|
||||
|
||||
output = torch.empty_like(hidden_states)
|
||||
|
||||
return trtllm_fp4_block_scale_moe(
|
||||
routing_logits=router_logits.to(torch.bfloat16),
|
||||
routing_bias=None,
|
||||
hidden_states=x_quant,
|
||||
hidden_states_scale=x_scale,
|
||||
gemm1_weights=w1,
|
||||
gemm1_weights_scale=self.w1_scale,
|
||||
gemm1_bias=self.w1_bias,
|
||||
gemm1_alpha=self.gemm1_alpha,
|
||||
gemm1_beta=self.gemm1_beta,
|
||||
gemm1_clamp_limit=self.gemm1_clamp_limit,
|
||||
gemm2_weights=w2,
|
||||
gemm2_weights_scale=self.w2_scale,
|
||||
gemm2_bias=self.w2_bias,
|
||||
output1_scale_scalar=None,
|
||||
output1_scale_gate_scalar=None,
|
||||
output2_scale_scalar=None,
|
||||
num_experts=global_num_experts,
|
||||
top_k=self.topk,
|
||||
n_group=None,
|
||||
topk_group=None,
|
||||
intermediate_size=self.intermediate_size_per_partition,
|
||||
local_expert_offset=self.ep_rank * self.local_num_experts,
|
||||
local_num_experts=self.local_num_experts,
|
||||
routed_scaling_factor=None,
|
||||
routing_method_type=self.routing_method_type,
|
||||
do_finalize=True,
|
||||
tune_max_num_tokens=max(self.max_capture_size, 1),
|
||||
output=output,
|
||||
)[0]
|
||||
|
||||
|
||||
class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModular):
|
||||
"""
|
||||
Modular version of the MXFP4 TRTLLM kernel (just the experts).
|
||||
Wraps flashinfer.trtllm_fp4_block_scale_routed_moe().
|
||||
Moved from trtllm_moe.py.
|
||||
"""
|
||||
|
||||
@property
|
||||
def expects_unquantized_inputs(self) -> bool:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(
|
||||
moe_parallel_config: FusedMoEParallelConfig,
|
||||
) -> bool:
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: MoEActivation,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# The workspaces for this implementation are managed by flashinfer.
|
||||
workspace1 = (0,)
|
||||
workspace2 = (0,)
|
||||
output = (M, K)
|
||||
return (workspace1, workspace2, output)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
a2_scale: torch.Tensor | None,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
topk = topk_ids.size(-1)
|
||||
local_num_experts = w1.size(0)
|
||||
intermediate_size = w2.size(1)
|
||||
local_expert_offset = self.moe_config.ep_rank * local_num_experts
|
||||
|
||||
# Handle input quantization
|
||||
if self.use_mxfp8_input:
|
||||
from flashinfer import mxfp8_quantize
|
||||
|
||||
x_quant, x_scale = mxfp8_quantize(
|
||||
hidden_states,
|
||||
is_sf_swizzled_layout=False,
|
||||
alignment=256,
|
||||
)
|
||||
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(
|
||||
*hidden_states.shape[:-1], -1
|
||||
)
|
||||
else:
|
||||
assert hidden_states.dtype == torch.bfloat16
|
||||
x_quant = hidden_states
|
||||
x_scale = None
|
||||
|
||||
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
|
||||
torch.bfloat16
|
||||
).view(torch.int16)
|
||||
|
||||
assert self.w1_scale is not None
|
||||
assert self.w2_scale is not None
|
||||
kwargs = {
|
||||
"topk_ids": packed_tensor,
|
||||
"routing_bias": None,
|
||||
"hidden_states": x_quant,
|
||||
"hidden_states_scale": x_scale,
|
||||
"gemm1_weights": w1,
|
||||
"gemm1_weights_scale": self.w1_scale,
|
||||
"gemm1_bias": self.w1_bias,
|
||||
"gemm1_alpha": self.gemm1_alpha,
|
||||
"gemm1_beta": self.gemm1_beta,
|
||||
"gemm1_clamp_limit": self.gemm1_clamp_limit,
|
||||
"gemm2_weights": w2,
|
||||
"gemm2_weights_scale": self.w2_scale,
|
||||
"gemm2_bias": self.w2_bias,
|
||||
"output1_scale_scalar": None,
|
||||
"output1_scale_gate_scalar": None,
|
||||
"output2_scale_scalar": None,
|
||||
"num_experts": global_num_experts,
|
||||
"top_k": topk,
|
||||
"n_group": None,
|
||||
"topk_group": None,
|
||||
"intermediate_size": intermediate_size,
|
||||
"local_expert_offset": local_expert_offset,
|
||||
"local_num_experts": local_num_experts,
|
||||
"routed_scaling_factor": None,
|
||||
"routing_method_type": self.routing_method_type,
|
||||
"do_finalize": True,
|
||||
"output": output,
|
||||
"tune_max_num_tokens": max(self.max_capture_size, 1),
|
||||
}
|
||||
|
||||
from flashinfer import trtllm_fp4_block_scale_routed_moe
|
||||
|
||||
from vllm.utils.flashinfer import autotune
|
||||
|
||||
with autotune(False):
|
||||
# Enable autotune when,
|
||||
# https://github.com/flashinfer-ai/flashinfer/issues/2023 is
|
||||
# resolved.
|
||||
trtllm_fp4_block_scale_routed_moe(**kwargs)
|
||||
|
||||
return output
|
||||
@@ -40,6 +40,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8Static128BlockSym,
|
||||
kFp8StaticChannelSym,
|
||||
kFp8StaticTensorSym,
|
||||
kMxfp4Static,
|
||||
kNvfp4Static,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
@@ -574,12 +575,13 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular):
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
# TODO(rob): add int4, mxfp4, int8 as integrations
|
||||
# TODO(rob): add int4, int8 as integrations
|
||||
# are migrated to use the oracle one-by-one.
|
||||
SUPPORTED_W = [
|
||||
kFp8Static128BlockSym,
|
||||
kFp8StaticChannelSym,
|
||||
kFp8StaticTensorSym,
|
||||
kMxfp4Static,
|
||||
kNvfp4Static,
|
||||
]
|
||||
return weight_key in SUPPORTED_W
|
||||
|
||||
@@ -11,8 +11,10 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG,
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig,
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceNoOP,
|
||||
@@ -20,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kMxfp4Static,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
@@ -537,43 +540,43 @@ def make_routing_data(
|
||||
|
||||
|
||||
class BaseOAITritonExperts(mk.FusedMoEExpertsModular):
|
||||
@property
|
||||
def expects_unquantized_inputs(self) -> bool:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _supports_current_device() -> bool:
|
||||
raise NotImplementedError(
|
||||
"OAITritonExperts is not yet used by an Oracle. "
|
||||
"This method should not be called."
|
||||
)
|
||||
p = current_platform
|
||||
if not p.is_cuda_alike():
|
||||
return False
|
||||
cap = p.get_device_capability()
|
||||
if cap is None:
|
||||
return False
|
||||
# (9,0) <= cap < (11,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
|
||||
# and ROCm gfx942/gfx950 (which map to 9.4/9.5).
|
||||
return (9, 0) <= (cap.major, cap.minor) < (11, 0)
|
||||
|
||||
@staticmethod
|
||||
def _supports_no_act_and_mul() -> bool:
|
||||
raise NotImplementedError(
|
||||
"OAITritonExperts is not yet used by an Oracle. "
|
||||
"This method should not be called."
|
||||
)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _supports_quant_scheme(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
raise NotImplementedError(
|
||||
"OAITritonExperts is not yet used by an Oracle. "
|
||||
"This method should not be called."
|
||||
)
|
||||
SUPPORTED_W_A = [
|
||||
(kMxfp4Static, None),
|
||||
]
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
raise NotImplementedError(
|
||||
"OAITritonExperts is not yet used by an Oracle. "
|
||||
"This method should not be called."
|
||||
)
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
raise NotImplementedError(
|
||||
"OAITritonExperts is not yet used by an Oracle. "
|
||||
"This method should not be called."
|
||||
)
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
@@ -630,6 +633,10 @@ class BaseOAITritonExperts(mk.FusedMoEExpertsModular):
|
||||
class OAITritonExperts(BaseOAITritonExperts):
|
||||
"""OAI Triton-based fused MoE expert implementation."""
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
return activation == MoEActivation.SWIGLUOAI
|
||||
|
||||
@staticmethod
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
@@ -714,6 +721,15 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
|
||||
One use case for it is to inject LoRA modules on the activation and moe_sum.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
return activation in [
|
||||
MoEActivation.SILU,
|
||||
MoEActivation.GELU,
|
||||
MoEActivation.SWIGLUOAI,
|
||||
MoEActivation.SWIGLUSTEP,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
@@ -839,3 +855,118 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
|
||||
)
|
||||
|
||||
self.moe_sum(intermediate_cache3.view(-1, topk, K), output)
|
||||
|
||||
|
||||
class OAITritonMxfp4ExpertsMonolithic(mk.FusedMoEExpertsMonolithic):
|
||||
"""Monolithic Triton MXFP4 expert. Wraps triton_kernel_moe_forward()."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
):
|
||||
super().__init__(moe_config, quant_config)
|
||||
self.topk = moe_config.experts_per_token
|
||||
self.renormalize = moe_config.routing_method in (
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
@staticmethod
|
||||
def _supports_current_device() -> bool:
|
||||
p = current_platform
|
||||
if not p.is_cuda_alike():
|
||||
return False
|
||||
cap = p.get_device_capability()
|
||||
if cap is None:
|
||||
return False
|
||||
# (9,0) <= cap < (11,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
|
||||
# and ROCm gfx942/gfx950 (which map to 9.4/9.5).
|
||||
return (9, 0) <= (cap.major, cap.minor) < (11, 0)
|
||||
|
||||
@staticmethod
|
||||
def _supports_no_act_and_mul() -> bool:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _supports_quant_scheme(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
SUPPORTED_W_A = [
|
||||
(kMxfp4Static, None),
|
||||
]
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
return activation == MoEActivation.SWIGLUOAI
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(
|
||||
moe_parallel_config: FusedMoEParallelConfig,
|
||||
) -> bool:
|
||||
return (
|
||||
not moe_parallel_config.use_all2all_kernels
|
||||
and not moe_parallel_config.enable_eplb
|
||||
and moe_parallel_config.dp_size <= 1
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _supports_routing_method(
|
||||
routing_method: RoutingMethodType,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
return routing_method in [
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _supports_router_logits_dtype(
|
||||
router_logits_dtype: torch.dtype | None,
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def expects_unquantized_inputs(self) -> bool:
|
||||
return True
|
||||
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
# grouped topk + fused topk bias parameters
|
||||
num_expert_group: int | None = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
routed_scaling_factor: float | None = None,
|
||||
topk_group: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
return triton_kernel_moe_forward(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
gating_output=router_logits,
|
||||
topk=self.topk,
|
||||
renormalize=self.renormalize,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
quant_config=self.quant_config,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
@@ -52,7 +52,6 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -218,7 +217,6 @@ def maybe_roundup_hidden_size(
|
||||
moe_parallel_config: FusedMoEParallelConfig,
|
||||
is_lora_enabled: bool,
|
||||
model_type: str | None,
|
||||
is_mxfp4_quant: bool,
|
||||
) -> int:
|
||||
"""
|
||||
Given layer hidden size and MoE configurations, round up hidden_size
|
||||
@@ -232,7 +230,6 @@ def maybe_roundup_hidden_size(
|
||||
is used in the case of mxfp4 quantization in selecting the
|
||||
MxFP4Backend.
|
||||
model_type: for checking if gpt-oss
|
||||
is_mxfp4_quant: whether the layer is quantized with mxfp4
|
||||
|
||||
Return:
|
||||
Rounded up hidden_size if rounding up is required based on the configs.
|
||||
@@ -246,28 +243,6 @@ def maybe_roundup_hidden_size(
|
||||
hidden_size, act_dtype, moe_parallel_config
|
||||
)
|
||||
|
||||
# we are padding globally so EP buffer allocation works
|
||||
if model_type == "gpt_oss" and is_mxfp4_quant:
|
||||
from vllm.model_executor.layers.quantization.mxfp4 import (
|
||||
Mxfp4Backend,
|
||||
get_mxfp4_backend,
|
||||
)
|
||||
|
||||
current_mxfp4_backend = get_mxfp4_backend(is_lora_enabled)
|
||||
|
||||
if (
|
||||
current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
|
||||
or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
|
||||
):
|
||||
hidden_size = round_up(hidden_size, 128)
|
||||
elif (
|
||||
current_platform.is_rocm()
|
||||
or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
|
||||
or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
|
||||
or current_mxfp4_backend == Mxfp4Backend.MARLIN
|
||||
):
|
||||
hidden_size = round_up(hidden_size, 256)
|
||||
|
||||
return hidden_size
|
||||
|
||||
|
||||
@@ -540,9 +515,6 @@ class FusedMoE(CustomOp):
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
is_lora_enabled=vllm_config.lora_config is not None,
|
||||
model_type=self.model_type,
|
||||
is_mxfp4_quant=(
|
||||
quant_config is not None and quant_config.is_mxfp4_quant(prefix, self)
|
||||
),
|
||||
)
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
|
||||
847
vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
Normal file
847
vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
Normal file
@@ -0,0 +1,847 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from enum import Enum
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoEConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
maybe_make_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
mxfp4_mxfp8_moe_quant_config,
|
||||
mxfp4_w4a16_moe_quant_config,
|
||||
ocp_mx_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
_swizzle_mxfp4,
|
||||
get_padding_alignment,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kMxfp4Static,
|
||||
kMxfp8Dynamic,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.import_utils import has_triton_kernels
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
logger = init_logger(__name__)

# Eagerly probe the optional `triton_kernels` package at import time.
# A failure is logged (not raised) so that non-Triton MXFP4 backends can
# still be selected on installs with an incompatible triton version.
if has_triton_kernels():
    try:
        from triton_kernels.matmul_ogs import PrecisionConfig
    except (ImportError, AttributeError) as e:
        logger.error(
            "Failed to import Triton kernels. Please make sure your triton "
            "version is compatible. Error: %s",
            e,
        )
|
||||
|
||||
|
||||
class Mxfp4MoeBackend(Enum):
    """Kernel backends capable of executing MXFP4-quantized fused MoE.

    The string values are stable identifiers used in log messages; the
    MXFP8 variants quantize activations to MXFP8, the BF16 variants keep
    activations in bf16.
    """

    # No backend selected (e.g. platform without MXFP4 support).
    NONE = "None"
    # FlashInfer TRTLLM backends
    FLASHINFER_TRTLLM_MXFP4_MXFP8 = "FLASHINFER_TRTLLM_MXFP4_MXFP8"
    FLASHINFER_TRTLLM_MXFP4_BF16 = "FLASHINFER_TRTLLM_MXFP4_BF16"
    # FlashInfer CUTLASS backends
    FLASHINFER_CUTLASS_MXFP4_MXFP8 = "FLASHINFER_CUTLASS_MXFP4_MXFP8"
    FLASHINFER_CUTLASS_MXFP4_BF16 = "FLASHINFER_CUTLASS_MXFP4_BF16"
    # Marlin
    BATCHED_MARLIN = "BATCHED_MARLIN"
    MARLIN = "MARLIN"
    # ROCm AITER (CK)
    CK = "CK"
    # Triton
    TRITON = "TRITON"
    TRITON_UNFUSED = "TRITON_UNFUSED"
    # XPU
    XPU = "XPU"
|
||||
|
||||
|
||||
# Backends that share the same TRTLLM weight format
# (both are handled identically during weight conversion and rounding).
TRTLLM_BACKENDS = (
    Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
    Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
)

# Backends that share the Triton-kernels weight format (swizzled mxfp4
# plus PrecisionConfig scale containers).
TRITON_BACKENDS = (
    Mxfp4MoeBackend.TRITON,
    Mxfp4MoeBackend.TRITON_UNFUSED,
)
|
||||
|
||||
|
||||
def backend_to_kernel_cls(
    backend: Mxfp4MoeBackend,
) -> list[type[mk.FusedMoEExperts]]:
    """Return the candidate expert-kernel classes for ``backend``.

    Imports are deferred into each branch so that optional dependencies
    (flashinfer, triton_kernels, aiter, marlin utilities) are only
    required when the corresponding backend is actually selected.

    When several classes are returned they are ordered by preference;
    callers pick the first one whose ``is_supported_config`` passes.

    Raises:
        NotImplementedError: for the XPU backend, which bypasses the
            modular-kernel machinery entirely.
        ValueError: for an unrecognized backend.
    """
    if backend in (
        Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
        Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
    ):
        from vllm.model_executor.layers.fused_moe.experts.trtllm_mxfp4_moe import (
            TrtLlmMxfp4ExpertsModular,
            TrtLlmMxfp4ExpertsMonolithic,
        )

        # NOTE: prefer Monolithic > Modular, so return Monolithic first.
        return [TrtLlmMxfp4ExpertsMonolithic, TrtLlmMxfp4ExpertsModular]

    elif backend in (
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
    ):
        from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
            FlashInferExperts,
        )

        return [FlashInferExperts]

    elif backend == Mxfp4MoeBackend.TRITON:
        from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
            OAITritonExperts,
            OAITritonMxfp4ExpertsMonolithic,
        )

        # NOTE: prefer Monolithic > Modular, so return Monolithic first.
        return [OAITritonMxfp4ExpertsMonolithic, OAITritonExperts]

    elif backend == Mxfp4MoeBackend.TRITON_UNFUSED:
        from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
            UnfusedOAITritonExperts,
        )

        return [UnfusedOAITritonExperts]

    elif backend == Mxfp4MoeBackend.MARLIN:
        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
            MarlinExperts,
        )

        return [MarlinExperts]

    elif backend == Mxfp4MoeBackend.BATCHED_MARLIN:
        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
            BatchedMarlinExperts,
        )

        return [BatchedMarlinExperts]

    elif backend == Mxfp4MoeBackend.CK:
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            AiterExperts,
        )

        return [AiterExperts]

    elif backend == Mxfp4MoeBackend.XPU:
        raise NotImplementedError("XPU backend uses XpuMxfp4MoEMethod directly.")
    else:
        raise ValueError(f"Unknown MXFP4 MoE backend: {backend.value}")
|
||||
|
||||
|
||||
def map_mxfp4_backend(runner_backend: str) -> Mxfp4MoeBackend:
    """Map user's moe_backend string to Mxfp4MoeBackend.

    Args:
        runner_backend: backend name from the user configuration
            (e.g. "triton", "marlin", "flashinfer_trtllm").

    Returns:
        The corresponding ``Mxfp4MoeBackend`` member.

    Raises:
        ValueError: if ``runner_backend`` is not a recognized MXFP4 backend.
    """
    mapping: dict[str, Mxfp4MoeBackend] = {
        "flashinfer_trtllm": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
        "flashinfer_trtllm_afp8": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
        "flashinfer_cutlass": Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
        "flashinfer_cutlass_afp8": Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
        "triton": Mxfp4MoeBackend.TRITON,
        "marlin": Mxfp4MoeBackend.MARLIN,
        "ck": Mxfp4MoeBackend.CK,
    }
    # Explicit None check instead of truthiness: relying on implicit bool
    # of the walrus result would silently misbehave if a falsy enum member
    # were ever added to the mapping.
    backend = mapping.get(runner_backend)
    if backend is not None:
        return backend
    raise ValueError(
        f"moe_backend='{runner_backend}' is not supported for MXFP4 MoE. "
        f"Expected one of {list(mapping.keys())}."
    )
|
||||
|
||||
|
||||
def _get_priority_backends() -> list[Mxfp4MoeBackend]:
    """
    Get available backends in priority order based on platform and config.

    Only includes BF16 backends. MXFP8 backends are selected via env vars.
    """
    # A fresh list is built on every call so callers are free to mutate
    # it (e.g. remove entries that an env var disables).
    return [
        Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
        Mxfp4MoeBackend.CK,
        Mxfp4MoeBackend.TRITON,
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
        Mxfp4MoeBackend.TRITON_UNFUSED,
        Mxfp4MoeBackend.MARLIN,
        Mxfp4MoeBackend.BATCHED_MARLIN,
    ]
|
||||
|
||||
|
||||
def _backend_activation_key(backend: Mxfp4MoeBackend) -> QuantKey | None:
    """Map backend to its activation key (MXFP8 or None for BF16)."""
    mxfp8_activation_backends = {
        Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
    }
    # Only the *_MXFP8 backends quantize activations; everything else
    # runs activations in bf16 and therefore has no activation key.
    return kMxfp8Dynamic if backend in mxfp8_activation_backends else None
|
||||
|
||||
|
||||
def select_mxfp4_moe_backend(
    config: FusedMoEConfig,
) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEExperts] | None]:
    """
    Select the primary MXFP4 MoE backend.

    Selection order:
      1. LoRA-enabled deployments use a dedicated Triton/Marlin path.
      2. An explicit (non-"auto") ``config.moe_backend`` is honored or raises.
      3. Explicit VLLM_USE_FLASHINFER_* / VLLM_MXFP4_USE_MARLIN env vars.
      4. Otherwise, the first backend in priority order whose kernel class
         accepts the deployment configuration.

    Note: Shape-specific fallbacks may still occur at runtime.
    """
    # triton_kernels path requires device capability in [SM90, SM110).
    triton_kernels_supported = has_triton_kernels() and (
        9,
        0,
    ) <= current_platform.get_device_capability() < (11, 0)

    # LoRA: separate experts backend path
    if config.is_lora_enabled:
        if not current_platform.is_cuda():
            raise NotImplementedError("Mxfp4 LoRA only supported on CUDA Platform.")
        # `is False` (not falsy) so that an unset env var still allows Triton.
        if envs.VLLM_MXFP4_USE_MARLIN is False and triton_kernels_supported:
            logger.info_once("Using Triton backend for mxfp4 lora")
            return Mxfp4MoeBackend.TRITON_UNFUSED, backend_to_kernel_cls(
                Mxfp4MoeBackend.TRITON_UNFUSED
            )[0]
        logger.info_once("Using Marlin backend for mxfp4 lora")
        return Mxfp4MoeBackend.MARLIN, backend_to_kernel_cls(Mxfp4MoeBackend.MARLIN)[0]

    activation_format = (
        mk.FusedMoEActivationFormat.BatchedExperts
        if config.moe_parallel_config.use_batched_activation_format
        else mk.FusedMoEActivationFormat.Standard
    )

    def _make_log_backend(backend: Mxfp4MoeBackend):
        # Message logged once when a backend is chosen.
        return f"Using '{backend.value}' Mxfp4 MoE backend."

    def _make_log_unsupported(backend: Mxfp4MoeBackend, reason: str | None) -> str:
        # Message used both for debug logging and the raised ValueError.
        if reason:
            return (
                f"Mxfp4 MoE backend '{backend.value}' does not support the "
                f"deployment configuration since {reason}."
            )
        return (
            f"Mxfp4 MoE backend '{backend.value}' does not support the "
            "deployment configuration."
        )

    def _return_or_raise(
        backend: Mxfp4MoeBackend,
        config: FusedMoEConfig,
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
        activation_format: mk.FusedMoEActivationFormat,
    ) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEExperts]]:
        # Try each candidate kernel class in preference order; raise if the
        # explicitly requested backend supports none of them.
        reason: str | None = None
        for k_cls in backend_to_kernel_cls(backend):
            supported, reason = k_cls.is_supported_config(
                k_cls, config, weight_key, activation_key, activation_format
            )
            if supported:
                logger.info_once(_make_log_backend(backend), scope="local")
                return backend, k_cls
        raise ValueError(_make_log_unsupported(backend, reason))

    runner_backend = config.moe_backend
    if runner_backend != "auto":
        requested_backend = map_mxfp4_backend(runner_backend)
        # Batched activation format silently upgrades MARLIN to its
        # batched variant.
        if (
            activation_format == mk.FusedMoEActivationFormat.BatchedExperts
            and requested_backend == Mxfp4MoeBackend.MARLIN
        ):
            requested_backend = Mxfp4MoeBackend.BATCHED_MARLIN
        return _return_or_raise(
            requested_backend,
            config,
            kMxfp4Static,
            _backend_activation_key(requested_backend),
            activation_format,
        )

    # Select kernels in order of backend.
    AVAILABLE_BACKENDS = _get_priority_backends()

    # Handle explicit FlashInfer MXFP4 BF16 configuration.
    if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"):
        if not envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16:
            # Explicitly disabled: drop both FlashInfer BF16 backends from
            # the auto-selection list.
            AVAILABLE_BACKENDS.remove(Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16)
            AVAILABLE_BACKENDS.remove(Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16)
        else:
            # Explicitly enabled: pick CUTLASS on SM90, TRTLLM on SM100-family.
            if current_platform.is_device_capability(90):
                return _return_or_raise(
                    Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
                    config,
                    kMxfp4Static,
                    None,
                    activation_format,
                )
            if current_platform.is_device_capability_family(100):
                return _return_or_raise(
                    Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
                    config,
                    kMxfp4Static,
                    None,
                    activation_format,
                )
            raise ValueError(
                "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16=1 is set but the "
                "current device capability is not supported. "
                "Only SM90 (CUTLASS) and SM100+ (TRTLLM) are supported."
            )

    # Handle explicit FlashInfer MXFP4 MXFP8 TRTLLM configuration.
    if (
        envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")
        and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
    ):
        return _return_or_raise(
            Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
            config,
            kMxfp4Static,
            kMxfp8Dynamic,
            activation_format,
        )

    # Handle explicit FlashInfer MXFP4 MXFP8 CUTLASS configuration.
    if (
        envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS")
        and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS
    ):
        return _return_or_raise(
            Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
            config,
            kMxfp4Static,
            kMxfp8Dynamic,
            activation_format,
        )

    # Handle explicit Marlin MXFP4 configuration.
    if envs.is_set("VLLM_MXFP4_USE_MARLIN") and envs.VLLM_MXFP4_USE_MARLIN:
        return _return_or_raise(
            Mxfp4MoeBackend.MARLIN,
            config,
            kMxfp4Static,
            None,
            activation_format,
        )

    # Auto-selection: first (backend, kernel class) pair that supports the
    # deployment configuration wins.
    for backend in AVAILABLE_BACKENDS:
        activation_key = _backend_activation_key(backend)
        for k_cls in backend_to_kernel_cls(backend):
            supported, reason = k_cls.is_supported_config(
                k_cls, config, kMxfp4Static, activation_key, activation_format
            )
            if supported:
                logger.info_once(_make_log_backend(backend), scope="local")
                return backend, k_cls
            else:
                logger.debug_once(_make_log_unsupported(backend, reason), scope="local")

    # XPU has no modular kernel class; the quant method handles it directly.
    if current_platform.is_xpu():
        backend = Mxfp4MoeBackend.XPU
        logger.info_once(_make_log_backend(backend))
        return backend, None

    if current_platform.is_cuda() or current_platform.is_rocm():
        raise NotImplementedError(
            "No MXFP4 MoE backend supports the deployment configuration."
        )

    return Mxfp4MoeBackend.NONE, None
|
||||
|
||||
|
||||
def mxfp4_round_up_hidden_size_and_intermediate_size(
    backend: Mxfp4MoeBackend, hidden_size: int, intermediate_size: int
) -> tuple[int, int]:
    """Round up hidden_size and intermediate_size based on backend requirements."""
    # Resolve per-backend alignments first, then apply them once at the end.
    # A hidden alignment of None means the hidden size is left untouched.
    hidden_align: int | None
    if backend in (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN):
        inter_align = 128
        hidden_align = 128 if current_platform.is_xpu() else 256
    elif backend in TRTLLM_BACKENDS:
        inter_align, hidden_align = 256, 256
    elif backend in (
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
    ):
        inter_align, hidden_align = 128, 128
    elif current_platform.is_rocm():
        inter_align = hidden_align = get_padding_alignment()
    else:
        # Default path (e.g. Triton kernels): only intermediate size padded.
        inter_align, hidden_align = 64, None

    intermediate_size = round_up(intermediate_size, inter_align)
    if hidden_align is not None:
        hidden_size = round_up(hidden_size, hidden_align)
    return hidden_size, intermediate_size
|
||||
|
||||
|
||||
def convert_to_mxfp4_moe_kernel_format(
    mxfp4_backend: Mxfp4MoeBackend,
    layer: torch.nn.Module,
    w13_weight: torch.Tensor,
    w2_weight: torch.Tensor,
    w13_weight_scale: torch.Tensor,
    w2_weight_scale: torch.Tensor,
    w13_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
    _cache_permute_indices: dict[torch.Size, torch.Tensor] | None = None,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    Union[torch.Tensor, "PrecisionConfig"],
    Union[torch.Tensor, "PrecisionConfig"],
    torch.Tensor | None,
    torch.Tensor | None,
]:
    """Convert loaded weights into backend-specific kernel format.

    Returns the (possibly reshuffled/repacked) tuple
    ``(w13_weight, w2_weight, w13_scale, w2_scale, w13_bias, w2_bias)``;
    for Triton backends the scale slots carry ``PrecisionConfig`` objects
    instead of tensors.

    Raises:
        ValueError: if ``mxfp4_backend`` is not a recognized backend.
    """

    num_experts = w13_weight.shape[0]
    # w13 stacks gate+up rows (hence // 2); the last dim packs two fp4
    # values per byte (hence * 2).  assumes the standard loader layout —
    # TODO confirm against the weight-loading path.
    intermediate_size = w13_weight.shape[1] // 2
    hidden_size = w13_weight.shape[2] * 2

    sf_block_size = 32  # mxfp4 block size

    if mxfp4_backend in (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN):
        from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
            prepare_moe_mxfp4_layer_for_marlin,
        )

        # Marlin has its own dedicated repacking helper.
        return prepare_moe_mxfp4_layer_for_marlin(
            layer,
            w13_weight,
            w2_weight,
            w13_weight_scale,
            w2_weight_scale,
            w13_bias,
            w2_bias,
        )

    elif mxfp4_backend in TRTLLM_BACKENDS:
        assert _cache_permute_indices is not None
        from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
        from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache

        # gemm1_alpha/beta/clamp_limit are created by the expert class
        # (TrtLlmMxfp4ExpertsBase), not on the layer.

        # Detach from any Parameter wrappers before in-place reshuffling.
        w13_weight = w13_weight.data
        w2_weight = w2_weight.data
        w13_weight_scale = w13_weight_scale.data
        w2_weight_scale = w2_weight_scale.data
        assert w13_bias is not None and w2_bias is not None
        w13_bias = w13_bias.data.to(torch.float32)
        w2_bias = w2_bias.data.to(torch.float32)

        # Swap w1 and w3 as the definition of swiglu is different in trtllm-gen
        def swap_every_two_rows(x, axis=-1):
            # Pairs adjacent rows along `axis` and flips each pair,
            # effectively exchanging interleaved w1/w3 entries.
            shape = x.shape
            if axis < 0:
                axis = len(shape) + axis
            new_shape = list(shape)
            new_shape[axis] = shape[axis] // 2
            new_shape.insert(axis + 1, 2)
            x = x.reshape(*new_shape)
            x = x.flip(axis + 1)
            new_shape = list(shape)
            return x.reshape(*new_shape)

        w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
        w13_weight = swap_every_two_rows(w13_weight, -2)
        w13_bias = swap_every_two_rows(w13_bias, -1)

        # Shuffle weights and scaling factors for transposed mma output
        gemm1_weights_shuffled = []
        gemm1_scales_shuffled = []
        gemm2_weights_shuffled = []
        gemm2_scales_shuffled = []
        gemm1_bias_shuffled = []
        gemm2_bias_shuffled = []
        epilogue_tile_m = 128
        for i in range(num_experts):
            # w13 weight
            permute_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w13_weight[i].view(torch.uint8),
                epilogue_tile_m,
            )
            gemm1_weights_shuffled.append(
                w13_weight[i]
                .view(torch.uint8)[permute_indices.to(w13_weight.device)]
                .contiguous()
            )
            # w13 scale
            permute_sf_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w13_weight_scale[i].view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=16,
            )
            gemm1_scales_shuffled.append(
                nvfp4_block_scale_interleave(
                    w13_weight_scale[i]
                    .view(torch.uint8)[permute_sf_indices.to(w13_weight_scale.device)]
                    .contiguous()
                )
            )
            # w13 bias
            permute_bias_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w13_bias[i].clone().reshape(-1, 1),
                epilogue_tile_m,
            )
            gemm1_bias_shuffled.append(
                w13_bias[i]
                .clone()
                .reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)]
                .contiguous()
            )
            # w2 weight
            permute_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w2_weight[i].view(torch.uint8),
                epilogue_tile_m,
            )
            gemm2_weights_shuffled.append(
                w2_weight[i]
                .view(torch.uint8)[permute_indices.to(w2_weight.device)]
                .contiguous()
            )
            # w2 scale
            permute_sf_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w2_weight_scale[i].view(torch.uint8),
                epilogue_tile_m,
                num_elts_per_sf=16,
            )
            gemm2_scales_shuffled.append(
                nvfp4_block_scale_interleave(
                    w2_weight_scale[i]
                    .view(torch.uint8)[permute_sf_indices.to(w2_weight_scale.device)]
                    .contiguous()
                )
            )
            # w2 bias
            permute_indices = get_w2_permute_indices_with_cache(
                _cache_permute_indices,
                w2_bias[i].clone().reshape(-1, 1),
                epilogue_tile_m,
            )
            gemm2_bias_shuffled.append(
                w2_bias[i]
                .clone()
                .reshape(-1, 1)[permute_indices.to(w2_bias.device)]
                .contiguous()
            )

        # Re-stack per-expert results and restore the expected shapes.
        w13_weight = torch.stack(gemm1_weights_shuffled)
        w13_weight_scale = (
            torch.stack(gemm1_scales_shuffled)
            .reshape(num_experts, 2 * intermediate_size, hidden_size // sf_block_size)
            .view(torch.float8_e4m3fn)
        )
        w2_weight = torch.stack(gemm2_weights_shuffled)
        w2_weight_scale = (
            torch.stack(gemm2_scales_shuffled)
            .reshape(num_experts, hidden_size, intermediate_size // sf_block_size)
            .view(torch.float8_e4m3fn)
        )
        w13_bias = torch.stack(gemm1_bias_shuffled).reshape(num_experts, -1)
        w2_bias = torch.stack(gemm2_bias_shuffled).reshape(num_experts, -1)

        return (
            w13_weight,
            w2_weight,
            w13_weight_scale,
            w2_weight_scale,
            w13_bias,
            w2_bias,
        )

    elif mxfp4_backend in (
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
    ):
        # De-interleave and swap for w13 weight, bias, and scales
        # (loader stores gate/up rows interleaved; CUTLASS wants [w3; w1]).
        w13_w = w13_weight.data
        gate_w, up_w = w13_w[:, ::2, :], w13_w[:, 1::2, :]
        deinterleaved_w13_w = torch.cat([gate_w, up_w], dim=1)
        w1_w, w3_w = torch.chunk(deinterleaved_w13_w, 2, dim=1)
        w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)

        assert w13_bias is not None and w2_bias is not None
        w13_b = w13_bias.data.to(torch.float32)
        gate_b, up_b = w13_b[:, ::2], w13_b[:, 1::2]
        deinterleaved_w13_b = torch.cat([gate_b, up_b], dim=1)
        b1, b3 = torch.chunk(deinterleaved_w13_b, 2, dim=-1)
        w13_bias_swapped = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)

        w13_s = w13_weight_scale.data
        gate_s, up_s = w13_s[:, ::2, :], w13_s[:, 1::2, :]
        deinterleaved_w13_s = torch.cat([gate_s, up_s], dim=1)
        s1, s3 = torch.chunk(deinterleaved_w13_s, 2, dim=1)
        w13_scale_swapped = torch.cat([s3, s1], dim=1)

        if mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8:
            # SM100 path: scales are block-interleaved via flashinfer.
            from flashinfer import block_scale_interleave

            orig_shape = w13_scale_swapped.shape
            w13_scale_interleaved = block_scale_interleave(
                w13_scale_swapped.view(torch.uint8)
            ).reshape(orig_shape)

            w2_s = w2_weight_scale.data
            orig_shape = w2_s.shape
            w2_scale_interleaved = block_scale_interleave(
                w2_s.view(torch.uint8)
            ).reshape(orig_shape)

            return (
                w13_weight_swapped,
                w2_weight,
                w13_scale_interleaved,
                w2_scale_interleaved,
                w13_bias_swapped,
                w2_bias,
            )

        else:
            assert mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16

            def _interleave_mxfp4_cutlass_sm90(w):
                # Regroup scale columns into groups of 4 for the SM90
                # CUTLASS kernel layout.
                w_shape = w.shape
                w_interleaved = w.reshape(w_shape[0], w_shape[1], (w_shape[2] // 4), 4)
                w_interleaved = w_interleaved.permute(0, 2, 1, 3)
                w_interleaved = w_interleaved.reshape(
                    w_shape[0], w_shape[2] // 4, w_shape[1] * 4
                )
                return w_interleaved

            w31_scales = w13_scale_swapped.to(torch.uint8)
            w31_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w31_scales)

            w2_scale = w2_weight_scale.data.to(torch.uint8)
            w2_scale_interleaved = _interleave_mxfp4_cutlass_sm90(w2_scale)

            return (
                w13_weight_swapped,
                w2_weight,
                w31_scales_interleaved,
                w2_scale_interleaved,
                w13_bias_swapped,
                w2_bias,
            )

    elif mxfp4_backend == Mxfp4MoeBackend.CK:
        from vllm._aiter_ops import rocm_aiter_ops

        if w13_bias is not None:
            w13_bias = w13_bias.data.to(torch.float32)
        if w2_bias is not None:
            w2_bias = w2_bias.data.to(torch.float32)

        e, n, k = w13_weight.shape

        # De-interleave w13 rows: gate/up pairs -> contiguous gate, up blocks
        w13_weight.view(torch.uint8).copy_(
            w13_weight.data.view(torch.uint8)
            .view(e, n // 2, 2, k)
            .permute(0, 2, 1, 3)
            .contiguous()
            .view(e, n, k)
        )
        w13_weight_scale.data = (
            w13_weight_scale.data.view(e, n // 2, 2, -1)
            .permute(0, 2, 1, 3)
            .contiguous()
            .view(e, n, -1)
        )

        # View as native FP4 dtype for AITER shuffle
        w13_weight.data = w13_weight.data.view(torch.float4_e2m1fn_x2)
        w2_weight.data = w2_weight.data.view(torch.float4_e2m1fn_x2)

        # Shuffle weights and scales for AITER CK kernel layout
        w13_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(w13_weight, 16, True)
        shuffled_w13_scale = rocm_aiter_ops.shuffle_scale_a16w4(
            w13_weight_scale.view(-1, w13_weight_scale.shape[-1]),
            num_experts,
            True,
        )

        w2_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(w2_weight, 16, False)
        shuffled_w2_scale = rocm_aiter_ops.shuffle_scale_a16w4(
            w2_weight_scale.view(-1, w2_weight_scale.shape[-1]),
            num_experts,
            False,
        )

        # Permute bias to match de-interleaved weight layout
        if w13_bias is not None:
            w13_bias = (
                w13_bias.data.view(-1, n // 2, 2)
                .permute(0, 2, 1)
                .contiguous()
                .view(-1, n)
            )

        return (
            w13_weight,
            w2_weight,
            shuffled_w13_scale,
            shuffled_w2_scale,
            w13_bias,
            w2_bias,
        )

    elif mxfp4_backend in TRITON_BACKENDS:
        from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig

        assert w13_bias is not None and w2_bias is not None
        w13_bias = w13_bias.to(torch.float32)
        w2_bias = w2_bias.to(torch.float32)

        # Swizzle weights/scales into the triton_kernels layout; scales are
        # carried inside PrecisionConfig objects rather than plain tensors.
        w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
            w13_weight,
            w13_weight_scale,
        )
        w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
            w2_weight,
            w2_weight_scale,
        )

        w13_precision_config = PrecisionConfig(
            weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
        )
        w2_precision_config = PrecisionConfig(
            weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
        )

        # Free the original layer parameters; the swizzled copies replace them.
        del layer.w13_weight
        del layer.w2_weight

        return (
            w13_weight,
            w2_weight,
            w13_precision_config,
            w2_precision_config,
            w13_bias,
            w2_bias,
        )
    else:
        raise ValueError(
            f"Unsupported mxfp4_backend: {mxfp4_backend}: "
            f"should be one of: {list(Mxfp4MoeBackend)}."
        )
|
||||
|
||||
|
||||
def make_mxfp4_moe_quant_config(
    mxfp4_backend: Mxfp4MoeBackend,
    w1_scale: Union[torch.Tensor, "PrecisionConfig"],
    w2_scale: Union[torch.Tensor, "PrecisionConfig"],
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
) -> FusedMoEQuantConfig | None:
    """Create a FusedMoEQuantConfig for the given MXFP4 backend."""
    # All three config factories take the same scale/bias keyword args.
    common_kwargs = dict(
        w1_bias=w1_bias,
        w2_bias=w2_bias,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
    )

    mxfp8_activation_backends = (
        Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
    )
    w4a16_backends = (
        Mxfp4MoeBackend.MARLIN,
        Mxfp4MoeBackend.BATCHED_MARLIN,
        Mxfp4MoeBackend.TRITON,
        Mxfp4MoeBackend.TRITON_UNFUSED,
        Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
        Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
        Mxfp4MoeBackend.CK,
    )

    if mxfp4_backend in mxfp8_activation_backends:
        # MXFP4 weights with MXFP8-quantized activations.
        return mxfp4_mxfp8_moe_quant_config(**common_kwargs)
    if mxfp4_backend in w4a16_backends:
        # MXFP4 weights with bf16 activations.
        return mxfp4_w4a16_moe_quant_config(**common_kwargs)
    # Remaining backends fall back to the generic OCP-MX config.
    return ocp_mx_moe_quant_config(quant_dtype="mxfp4", **common_kwargs)
|
||||
|
||||
|
||||
def make_mxfp4_moe_kernel(
    moe_quant_config: FusedMoEQuantConfig,
    moe_config: FusedMoEConfig,
    experts_cls: type[mk.FusedMoEExperts],
    mxfp4_backend: Mxfp4MoeBackend,
    routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    shared_experts: torch.nn.Module | None = None,
) -> mk.FusedMoEKernel:
    """Create a FusedMoEKernel for the given MXFP4 backend.

    Builds the prepare/finalize stage, instantiates ``experts_cls``
    (with batching parameters when the prepare/finalize stage batches
    activations), and wires both into a ``FusedMoEKernel``.
    """
    # Monolithic experts fuse prepare/finalize into one call path.
    is_monolithic = issubclass(experts_cls, mk.FusedMoEExpertsMonolithic)

    # Create Prepare/Finalize.
    prepare_finalize = maybe_make_prepare_finalize(
        moe=moe_config,
        quant_config=moe_quant_config,
        routing_tables=routing_tables,
        allow_new_interface=True,
        use_monolithic=is_monolithic,
    )
    assert prepare_finalize is not None

    logger.info_once("Using %s", prepare_finalize.__class__.__name__, scope="local")

    # Create Experts.
    if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
        # Batched activation format requires per-rank token and dispatcher
        # counts at construction time.
        max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
        assert max_num_tokens is not None
        experts = experts_cls(
            moe_config=moe_config,
            quant_config=moe_quant_config,
            max_num_tokens=max_num_tokens,
            num_dispatchers=prepare_finalize.num_dispatchers(),
        )
    else:
        experts = experts_cls(
            moe_config=moe_config,
            quant_config=moe_quant_config,
        )

    kernel = mk.FusedMoEKernel(
        prepare_finalize,
        experts,
        # Shared experts are only fused into the kernel on the DeepEP
        # low-latency path; otherwise they run outside the kernel.
        shared_experts=(
            shared_experts
            if moe_config.moe_parallel_config.use_deepep_ll_kernels
            else None
        ),
        moe_parallel_config=moe_config.moe_parallel_config,
        # TRTLLM backends cannot write results in place.
        inplace=(
            not moe_config.disable_inplace and mxfp4_backend not in TRTLLM_BACKENDS
        ),
    )

    return kernel
|
||||
@@ -14,7 +14,6 @@ from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
mxfp4_w4a16_moe_quant_config,
|
||||
nvfp4_moe_quant_config,
|
||||
nvfp4_w4a16_moe_quant_config,
|
||||
)
|
||||
@@ -347,16 +346,6 @@ def convert_to_nvfp4_moe_kernel_format(
|
||||
)
|
||||
|
||||
|
||||
def make_mxfp4_moe_quant_config(
|
||||
w13_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
) -> FusedMoEQuantConfig:
|
||||
return mxfp4_w4a16_moe_quant_config(
|
||||
w1_scale=w13_scale,
|
||||
w2_scale=w2_scale,
|
||||
)
|
||||
|
||||
|
||||
def make_nvfp4_moe_quant_config(
|
||||
backend: NvFp4MoeBackend,
|
||||
w13_scale: torch.Tensor,
|
||||
|
||||
@@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8Static128BlockSym,
|
||||
kFp8StaticChannelSym,
|
||||
kFp8StaticTensorSym,
|
||||
kMxfp4Static,
|
||||
)
|
||||
|
||||
|
||||
@@ -201,6 +202,8 @@ def rocm_aiter_fused_experts(
|
||||
activation_method = ActivationMethod.SILU
|
||||
elif activation == MoEActivation.GELU:
|
||||
activation_method = ActivationMethod.GELU
|
||||
elif activation == MoEActivation.SWIGLUOAI:
|
||||
activation_method = rocm_aiter_ops.get_aiter_activation_type("swiglu")
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation: {activation}")
|
||||
|
||||
@@ -247,8 +250,8 @@ def rocm_aiter_fused_experts(
|
||||
|
||||
else:
|
||||
quant_method = QuantMethod.NO.value
|
||||
# quark moe for mxfp4 w_dtype mxfp4 a_dtype
|
||||
if quant_config.use_mxfp4_w4a4:
|
||||
# mxfp4: both w4a4 (quark) and w4a16 (oracle CK) use BLOCK_1X32
|
||||
if quant_config.use_mxfp4_w4a4 or quant_config.use_mxfp4_w4a16:
|
||||
quant_method = QuantMethod.BLOCK_1X32.value
|
||||
# w8a8 block-scaled
|
||||
if quant_config.block_shape is not None and quant_config.use_fp8_w8a8:
|
||||
@@ -289,6 +292,8 @@ def rocm_aiter_fused_experts(
|
||||
doweight_stage1=apply_router_weight_on_input,
|
||||
num_local_tokens=num_local_tokens,
|
||||
output_dtype=output_dtype,
|
||||
bias1=quant_config.w1_bias if quant_config.use_mxfp4_w4a16 else None,
|
||||
bias2=quant_config.w2_bias if quant_config.use_mxfp4_w4a16 else None,
|
||||
)
|
||||
|
||||
|
||||
@@ -319,21 +324,23 @@ class AiterExperts(mk.FusedMoEExpertsModular):
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
# TODO(rob): AITER also supports MXFP4, which is not
|
||||
# yet supported via an Oracle. Once it is, we will add
|
||||
# MXFP4 to this list.
|
||||
SUPPORTED_W_A = [
|
||||
(None, None),
|
||||
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
|
||||
(kFp8StaticTensorSym, kFp8StaticTensorSym),
|
||||
(kFp8StaticTensorSym, kFp8DynamicTensorSym),
|
||||
(kFp8StaticChannelSym, kFp8DynamicTokenSym),
|
||||
(kMxfp4Static, None),
|
||||
]
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
return activation in [MoEActivation.SILU, MoEActivation.GELU]
|
||||
return activation in [
|
||||
MoEActivation.SILU,
|
||||
MoEActivation.GELU,
|
||||
MoEActivation.SWIGLUOAI,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
|
||||
@@ -1,184 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceNoOP,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
)
|
||||
|
||||
|
||||
class TrtLlmGenExperts(mk.FusedMoEExpertsModular):
    """TensorRT-LLM-based fused MoE expert implementation.

    Delegates the actual expert GEMMs to flashinfer's
    ``trtllm_fp4_block_scale_routed_moe`` kernel; this class only prepares
    the routing/scale tensors and the kernel's keyword arguments.
    """

    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
        max_capture_size: int,
    ):
        """Pre-build per-expert activation constants on the current device.

        Args:
            moe_config: fused-MoE layer configuration; ``num_local_experts``
                and ``ep_rank`` are read from it.
            quant_config: quantization configuration forwarded to the base
                class.
            max_capture_size: upper bound passed to the kernel as
                ``tune_max_num_tokens`` (clamped to at least 1 in apply()).
        """
        super().__init__(moe_config, quant_config)
        self.device = torch.accelerator.current_device_index()
        self.num_experts = moe_config.num_local_experts
        # Per-expert activation constants consumed by the flashinfer kernel.
        # 1.702 is the sigmoid-approximation coefficient used by
        # swiglu-style activations; beta/clamp values are replicated per
        # local expert. NOTE(review): constants appear tied to the
        # SWIGLUOAI/gpt-oss activation — confirm for other activations.
        self.gemm1_alpha = torch.tensor(
            [1.702] * self.num_experts, dtype=torch.float32, device=self.device
        )
        self.gemm1_beta = torch.tensor(
            [1.0] * self.num_experts, dtype=torch.float32, device=self.device
        )
        self.gemm1_clamp_limit = torch.tensor(
            [7.0] * self.num_experts, dtype=torch.float32, device=self.device
        )
        self.max_capture_size = max_capture_size

    @staticmethod
    def activation_format() -> mk.FusedMoEActivationFormat:
        # Operates on standard (non-batched) activation layout.
        return mk.FusedMoEActivationFormat.Standard

    @staticmethod
    def _supports_current_device() -> bool:
        # Oracle capability hooks are intentionally unimplemented: this
        # experts class is not selected via an Oracle yet.
        raise NotImplementedError(
            "TrtLlmGenExperts is not yet used by an Oracle. "
            "This method should not be called."
        )

    @staticmethod
    def _supports_no_act_and_mul() -> bool:
        # See _supports_current_device: not Oracle-selected yet.
        raise NotImplementedError(
            "TrtLlmGenExperts is not yet used by an Oracle. "
            "This method should not be called."
        )

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        # See _supports_current_device: not Oracle-selected yet.
        raise NotImplementedError(
            "TrtLlmGenExperts is not yet used by an Oracle. "
            "This method should not be called."
        )

    @staticmethod
    def _supports_activation(activation: MoEActivation) -> bool:
        # See _supports_current_device: not Oracle-selected yet.
        raise NotImplementedError(
            "TrtLlmGenExperts is not yet used by an Oracle. "
            "This method should not be called."
        )

    @staticmethod
    def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
        # See _supports_current_device: not Oracle-selected yet.
        raise NotImplementedError(
            "TrtLlmGenExperts is not yet used by an Oracle. "
            "This method should not be called."
        )

    def supports_expert_map(self) -> bool:
        # Expert maps are handled via local_expert_offset in apply().
        return True

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        # The kernel finalizes (weights + reduces) internally
        # (do_finalize=True in apply()), so no extra reduction is needed.
        return TopKWeightAndReduceNoOP()

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        activation: MoEActivation,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        # The workspaces for this implementation are managed by flashinfer.
        # Only the (M, K) output buffer is allocated by the framework.
        workspace1 = (0,)
        workspace2 = (0,)
        output = (M, K)
        return (workspace1, workspace2, output)

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        """Run the fused experts via flashinfer's routed-MoE kernel.

        Writes the result into ``output`` (the kernel receives it as its
        ``output`` argument) and also returns it.
        """
        topk = topk_ids.size(-1)
        local_num_experts = w1.size(0)
        intermediate_size = w2.size(1)
        # This rank owns experts [offset, offset + local_num_experts).
        local_expert_offset = self.moe_config.ep_rank * local_num_experts

        x_quant = hidden_states
        x_scale = a1q_scale
        if x_scale is not None:
            # Reinterpret the packed activation scales as fp8 and lay them
            # out per token: (*token_dims, scales_per_token).
            x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x_quant.shape[:-1], -1)

        # Pack routing info into one int32 per (token, expert) slot:
        # expert id in the high 16 bits, the bf16 bit pattern of the
        # routing weight in the low 16 bits.
        packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
            torch.bfloat16
        ).view(torch.int16)

        assert self.w1_scale is not None
        assert self.w2_scale is not None
        kwargs = {
            "topk_ids": packed_tensor,
            "routing_bias": None,
            "hidden_states": x_quant,
            "hidden_states_scale": x_scale,
            "gemm1_weights": w1,
            "gemm1_weights_scale": self.w1_scale,
            "gemm1_bias": self.w1_bias,
            "gemm1_alpha": self.gemm1_alpha,
            "gemm1_beta": self.gemm1_beta,
            "gemm1_clamp_limit": self.gemm1_clamp_limit,
            "gemm2_weights": w2,
            "gemm2_weights_scale": self.w2_scale,
            "gemm2_bias": self.w2_bias,
            "output1_scale_scalar": None,
            "output1_scale_gate_scalar": None,
            "output2_scale_scalar": None,
            "num_experts": global_num_experts,
            "top_k": topk,
            "n_group": None,
            "topk_group": None,
            "intermediate_size": intermediate_size,
            "local_expert_offset": local_expert_offset,
            "local_num_experts": local_num_experts,
            "routed_scaling_factor": None,
            # NOTE(review): routing_method_type=1 selects a flashinfer
            # routing mode — confirm against the flashinfer API docs.
            "routing_method_type": 1,
            "do_finalize": True,
            "output": output,
            "tune_max_num_tokens": max(self.max_capture_size, 1),
        }

        from flashinfer import trtllm_fp4_block_scale_routed_moe

        from vllm.utils.flashinfer import autotune

        with autotune(False):
            # Enable autotune when,
            # https://github.com/flashinfer-ai/flashinfer/issues/2023 is
            # resolved.
            trtllm_fp4_block_scale_routed_moe(**kwargs)

        return output
|
||||
@@ -45,11 +45,14 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
|
||||
make_fp8_moe_quant_config,
|
||||
select_fp8_moe_backend,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
|
||||
Mxfp4MoeBackend,
|
||||
make_mxfp4_moe_kernel,
|
||||
make_mxfp4_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
|
||||
NvFp4MoeBackend,
|
||||
convert_to_nvfp4_moe_kernel_format,
|
||||
is_global_sf_supported_for_nvfp4_backend,
|
||||
make_mxfp4_moe_quant_config,
|
||||
make_nvfp4_moe_kernel,
|
||||
make_nvfp4_moe_quant_config,
|
||||
select_nvfp4_moe_backend,
|
||||
@@ -235,7 +238,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
def __init__(self, moe):
|
||||
super().__init__(moe)
|
||||
self.group_size = 32
|
||||
self.mxfp4_backend = NvFp4MoeBackend.MARLIN
|
||||
self.mxfp4_backend = Mxfp4MoeBackend.MARLIN
|
||||
self.experts_cls = MarlinExperts
|
||||
|
||||
def create_weights(
|
||||
@@ -310,7 +313,9 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
return make_mxfp4_moe_quant_config(
|
||||
w13_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale
|
||||
mxfp4_backend=self.mxfp4_backend,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: FusedMoE) -> None:
|
||||
@@ -334,10 +339,11 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
if self.moe_quant_config is not None:
|
||||
self.moe_kernel = make_nvfp4_moe_kernel(
|
||||
self.moe_kernel = make_mxfp4_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
experts_cls=self.experts_cls,
|
||||
mxfp4_backend=self.mxfp4_backend,
|
||||
shared_experts=layer.shared_experts,
|
||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -25,9 +25,9 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
ocp_mx_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
|
||||
from vllm.model_executor.layers.quantization.mxfp4 import (
|
||||
Mxfp4Backend,
|
||||
get_mxfp4_backend,
|
||||
from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
|
||||
Mxfp4MoeBackend,
|
||||
select_mxfp4_moe_backend,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
prepare_fp8_moe_layer_for_marlin,
|
||||
@@ -699,9 +699,9 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
f"Please check that the combination is supported in OCP_MX_Scheme."
|
||||
)
|
||||
|
||||
self.mxfp4_backend: Mxfp4Backend | None = None
|
||||
self.mxfp4_backend: Mxfp4MoeBackend | None = None
|
||||
if self.ocp_mx_scheme == "w_mxfp4":
|
||||
self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
|
||||
self.mxfp4_backend, _ = select_mxfp4_moe_backend(moe)
|
||||
|
||||
if self.input_quant is not None:
|
||||
self.static_input_scales = not self.input_quant.get("is_dynamic")
|
||||
|
||||
@@ -389,9 +389,9 @@ def prepare_moe_fp4_layer_for_marlin(
|
||||
|
||||
group_size = 16 if is_nvfp4 else 32
|
||||
|
||||
e = layer.num_experts
|
||||
k = layer.hidden_size
|
||||
n = layer.intermediate_size_per_partition
|
||||
e = layer.moe_config.num_experts
|
||||
k = layer.moe_config.hidden_dim
|
||||
n = layer.moe_config.intermediate_size_per_partition
|
||||
|
||||
# WORKSPACE
|
||||
device = layer.w13_weight.device
|
||||
@@ -500,6 +500,120 @@ def prepare_moe_fp4_layer_for_marlin(
|
||||
setattr(layer, name, bias)
|
||||
|
||||
|
||||
def prepare_moe_mxfp4_layer_for_marlin(
    layer: torch.nn.Module,
    w13: torch.Tensor,
    w2: torch.Tensor,
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w13_bias: torch.Tensor | None,
    w2_bias: torch.Tensor | None,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor | None,
    torch.Tensor | None,
]:
    """Pure-function version of prepare_moe_fp4_layer_for_marlin for MXFP4.

    Takes weight tensors as inputs and returns transformed tensors.
    Does NOT modify the layer in-place.

    Args:
        layer: the MoE layer; only ``params_dtype`` is read from it.
        w13: packed gate/up weights of shape (E, 2*N, K//2).
        w2: packed down weights of shape (E, K, N//2).
        w13_scale / w2_scale: MXFP4 block scales (e8m0 bit pattern).
        w13_bias / w2_bias: optional per-expert biases.

    Returns:
        (w13, w2, w13_scale, w2_scale, w13_bias, w2_bias) converted to
        Marlin's packed/permuted layouts.

    Raises:
        RuntimeError: if a 1-byte activation dtype other than fp8-e4m3fn
            is configured (MXFP4 weight + INT8 activation is unsupported).
    """
    input_dtype = get_marlin_input_dtype()
    if (
        input_dtype is not None
        and input_dtype.itemsize == 1
        and input_dtype != torch.float8_e4m3fn
    ):
        raise RuntimeError("MXFP4 weight + INT8 activation is not supported.")

    group_size = 32  # MXFP4 block size

    # Derive dimensions from actual weight shapes to handle rounded/padded
    # sizes correctly (e.g., Mxfp4MoEMethod rounds up hidden_dim).
    # w13 shape: (E, 2*N, K//2)
    e = w13.shape[0]
    n = w13.shape[1] // 2  # intermediate_size_per_partition
    k = w13.shape[2] * 2  # hidden_size

    device = w13.device
    param_dtype = layer.params_dtype
    # 1-byte activation dtype (fp8) changes the Marlin repack/permute path.
    is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
    # Empty perm => no act-order permutation during repack.
    perm = torch.empty(0, dtype=torch.int, device=device)

    # WEIGHT: Repack weights to marlin format
    def repack_weight(weight: torch.Tensor, name: str) -> torch.Tensor:
        # Repack each expert's 4-bit weight matrix individually, then stack
        # back into an (E, ...) tensor.
        tensor_list = []
        if "w13" in name:
            size_n, size_k = n * 2, k
        else:
            # w2: output dim is hidden_size (k), reduction dim is n.
            size_n, size_k = k, n

        assert weight.shape == (e, size_n, size_k // 2)

        for i in range(e):
            # Marlin expects int32-packed, K-major input.
            qweight = weight[i].view(torch.int32).T.contiguous()
            marlin_qweight = ops.gptq_marlin_repack(
                b_q_weight=qweight,
                perm=perm,
                size_k=size_k,
                size_n=size_n,
                num_bits=4,
                is_a_8bit=is_a_8bit,
            )
            tensor_list.append(marlin_qweight)
        return torch.cat([x.unsqueeze(0) for x in tensor_list], 0)

    w13 = repack_weight(w13, "w13")
    w2 = repack_weight(w2, "w2")

    # WEIGHT SCALES: Permute scales
    def permute_scales(scales: torch.Tensor, name: str) -> torch.Tensor:
        # Reinterpret the stored byte pattern as e8m0 block scales, then
        # widen to the layer's parameter dtype before permuting.
        scales = scales.view(torch.float8_e8m0fnu)
        scales = scales.to(param_dtype)

        tensor_list = []
        if "w13" in name:
            size_n, size_k = n * 2, k
        else:
            size_n, size_k = k, n

        for i in range(e):
            scale = scales[i].T
            marlin_scales = marlin_permute_scales(
                s=scale,
                size_k=size_k,
                size_n=size_n,
                group_size=group_size,
                is_a_8bit=is_a_8bit,
            )
            # Extra MXFP4-specific post-processing of the permuted scales.
            marlin_scales = mxfp4_marlin_process_scales(
                marlin_scales, input_dtype=input_dtype
            )
            tensor_list.append(marlin_scales)
        return torch.cat([x.unsqueeze(0) for x in tensor_list], 0)

    w13_scale = permute_scales(w13_scale, "w13")
    w2_scale = permute_scales(w2_scale, "w2")

    # BIAS: Permute bias
    def permute_bias(bias: torch.Tensor | None) -> torch.Tensor | None:
        # Biases are optional; permute per expert when present.
        if bias is None:
            return None
        bias = bias.to(param_dtype)
        tensor_list = []
        for i in range(e):
            tensor_list.append(marlin_permute_bias(bias[i]))
        return torch.cat([x.unsqueeze(0) for x in tensor_list], 0)

    w13_bias = permute_bias(w13_bias)
    w2_bias = permute_bias(w2_bias)

    return w13, w2, w13_scale, w2_scale, w13_bias, w2_bias
|
||||
|
||||
|
||||
def rand_marlin_weight_nvfp4_like(weight, group_size, input_dtype=None):
|
||||
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
|
||||
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils.import_utils import has_triton_kernels
|
||||
@@ -22,7 +20,7 @@ logger = init_logger(__name__)
|
||||
CK_MXFP4_MOE_DIM_ALIGNMENT = 256
|
||||
|
||||
|
||||
def _swizzle_mxfp4(quant_tensor, scale, num_warps):
|
||||
def _swizzle_mxfp4(quant_tensor, scale, num_warps=8):
|
||||
"""weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel"""
|
||||
assert has_triton_kernels()
|
||||
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
|
||||
@@ -87,35 +85,6 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
|
||||
return quant_tensor, InFlexData(), scale
|
||||
|
||||
|
||||
def _can_support_mxfp4(
    use_grouped_topk: bool = False,
    topk_group: int | None = None,
    num_expert_group: int | None = None,
    expert_map: torch.Tensor | None = None,
    custom_routing_function: Callable | None = None,
    e_score_correction_bias: torch.Tensor | None = None,
    apply_router_weight_on_input: bool = False,
    scoring_func: str = "softmax",
    activation: MoEActivation = MoEActivation.SWIGLUOAI,
    expert_load_view: torch.Tensor | None = None,
    logical_to_physical_map: torch.Tensor | None = None,
    logical_replica_count: torch.Tensor | None = None,
):
    """Return True if the MXFP4 path supports the requested MoE setup.

    Only plain softmax top-k routing with the SWIGLUOAI activation is
    supported; any grouped/custom routing option, correction bias, EPLB
    bookkeeping tensor, or router-weight-on-input fusion rules it out.

    NOTE(review): ``expert_map`` is accepted but never inspected here —
    confirm whether a non-None expert map should also return False.
    """
    # Tensor-valued arguments must be compared against None explicitly:
    # evaluating a multi-element torch.Tensor in a boolean context raises
    # "Boolean value of Tensor with more than one element is ambiguous",
    # so the previous bare `or tensor` checks crashed whenever a tensor
    # was actually passed (e.g. e_score_correction_bias under EPLB).
    return not (
        use_grouped_topk
        # 0 and None both mean "grouped top-k not configured".
        or topk_group
        or num_expert_group
        or custom_routing_function is not None
        or e_score_correction_bias is not None
        or apply_router_weight_on_input
        or scoring_func != "softmax"
        or activation != MoEActivation.SWIGLUOAI
        or expert_load_view is not None
        or logical_to_physical_map is not None
        or logical_replica_count is not None
    )
|
||||
|
||||
|
||||
def get_padding_alignment():
|
||||
return (
|
||||
256
|
||||
|
||||
Reference in New Issue
Block a user