diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 53f638d50..b5b919c17 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -634,6 +634,46 @@ steps:
- pip install helion
- pytest -v -s kernels/helion/
+
+- label: Kernels FP8 MoE Test (1 H100)
+ timeout_in_minutes: 90
+ gpu: h100
+ num_gpus: 1
+ optional: true
+ commands:
+ - pytest -v -s kernels/moe/test_cutlass_moe.py
+ - pytest -v -s kernels/moe/test_flashinfer.py
+ - pytest -v -s kernels/moe/test_gpt_oss_triton_kernels.py
+ - pytest -v -s kernels/moe/test_modular_oai_triton_moe.py
+ - pytest -v -s kernels/moe/test_moe.py
+ # - pytest -v -s kernels/moe/test_block_fp8.py - failing on main
+ - pytest -v -s kernels/moe/test_block_int8.py
+ - pytest -v -s kernels/moe/test_triton_moe_no_act_mul.py
+ - pytest -v -s kernels/moe/test_triton_moe_ptpc_fp8.py
+
+- label: Kernels FP8 MoE Test (2 H100s)
+ timeout_in_minutes: 90
+ gpu: h100
+ num_gpus: 2
+ optional: true
+ commands:
+ - pytest -v -s kernels/moe/test_deepep_deepgemm_moe.py
+ - pytest -v -s kernels/moe/test_deepep_moe.py
+ - pytest -v -s kernels/moe/test_pplx_cutlass_moe.py
+ # - pytest -v -s kernels/moe/test_pplx_moe.py - failing on main
+
+- label: Kernels Fp4 MoE Test (B200)
+ timeout_in_minutes: 60
+ gpu: b200
+ num_gpus: 1
+ optional: true
+ commands:
+ - pytest -v -s kernels/moe/test_cutedsl_moe.py
+ - pytest -v -s kernels/moe/test_flashinfer_moe.py
+ - pytest -v -s kernels/moe/test_nvfp4_moe.py
+ - pytest -v -s kernels/moe/test_ocp_mx_moe.py
+
+
- label: Model Executor Test # 23min
timeout_in_minutes: 35
torch_nightly: true
diff --git a/benchmarks/kernels/benchmark_cutlass_moe_fp8.py b/benchmarks/kernels/benchmark_cutlass_moe_fp8.py
index 9c6edee7b..f1234d821 100644
--- a/benchmarks/kernels/benchmark_cutlass_moe_fp8.py
+++ b/benchmarks/kernels/benchmark_cutlass_moe_fp8.py
@@ -9,6 +9,7 @@ but use different quantization strategies and backends.
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from tests.kernels.moe.utils import make_dummy_moe_config
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
@@ -138,12 +139,13 @@ def bench_run(
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
- out_dtype=a.dtype,
- e=num_experts,
- n=n,
- k=k,
+ moe_config=make_dummy_moe_config(
+ num_experts=num_experts,
+ hidden_dim=k,
+ intermediate_size_per_partition=n,
+ in_dtype=a.dtype,
+ ),
quant_config=quant_config,
- device=w1.device,
),
)
diff --git a/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py b/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py
index 10a3e3eab..4894d37c4 100644
--- a/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py
+++ b/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py
@@ -12,6 +12,7 @@ import torch
import torch.utils.benchmark as benchmark
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from tests.kernels.moe.utils import make_dummy_moe_config
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
@@ -198,8 +199,7 @@ def bench_run(
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
CutlassExpertsFp4(
- out_dtype=dtype,
- max_experts_per_worker=e,
+ make_dummy_moe_config(),
quant_config=quant_config,
),
)
@@ -244,8 +244,7 @@ def bench_run(
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
CutlassExpertsFp4(
- out_dtype=dtype,
- max_experts_per_worker=e,
+ make_dummy_moe_config(),
quant_config=quant_config,
),
)
diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
index b30a12638..7b5daa62e 100644
--- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
+++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
@@ -6,6 +6,7 @@ import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES_MOE
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from tests.kernels.moe.utils import make_dummy_moe_config
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
@@ -134,13 +135,13 @@ def bench_run(
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
- out_dtype=a.dtype,
- # NOTE(rob): w2 is shaped as [E, hidden, intermediate]
- e=w2.shape[0],
- n=w2.shape[2],
- k=w2.shape[1],
+ moe_config=make_dummy_moe_config(
+ num_experts=w2.shape[0],
+ hidden_dim=w2.shape[1],
+ intermediate_size_per_partition=w2.shape[2],
+ in_dtype=a.dtype,
+ ),
quant_config=quant_config,
- device=w1.device,
),
)
@@ -166,13 +167,13 @@ def bench_run(
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
- out_dtype=a.dtype,
- # NOTE(rob): w2 is shaped as [E, hidden, intermediate]
- e=w2.shape[0],
- n=w2.shape[2],
- k=w2.shape[1],
+ moe_config=make_dummy_moe_config(
+ num_experts=w2.shape[0],
+ hidden_dim=w2.shape[1],
+ intermediate_size_per_partition=w2.shape[2],
+ in_dtype=a.dtype,
+ ),
quant_config=quant_config,
- device=w1.device,
),
)
diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py
index 27a8d9973..90ddee9be 100644
--- a/benchmarks/kernels/benchmark_moe.py
+++ b/benchmarks/kernels/benchmark_moe.py
@@ -16,10 +16,16 @@ import torch
from ray.experimental.tqdm_ray import tqdm
from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
+ FusedMoEParallelConfig,
FusedMoEQuantConfig,
+ RoutingMethodType,
_get_config_dtype_str,
)
from vllm.model_executor.layers.fused_moe.fused_moe import *
+from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
+ TritonOrDeepGemmExperts,
+)
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.triton_utils import triton
@@ -194,10 +200,33 @@ def benchmark_config(
block_shape=block_quant_shape,
)
+ deep_gemm_experts = mk.FusedMoEModularKernel(
+ prepare_finalize=MoEPrepareAndFinalizeNoEP(),
+ fused_experts=TritonOrDeepGemmExperts(
+ moe_config=FusedMoEConfig(
+ num_experts=num_experts,
+ experts_per_token=topk,
+ hidden_dim=hidden_size,
+ intermediate_size_per_partition=shard_intermediate_size,
+ num_local_experts=num_experts,
+ activation="silu",
+ parallel_config=FusedMoEParallelConfig.make_no_parallel(),
+ in_dtype=init_dtype,
+ routing_method=RoutingMethodType.TopK,
+ ),
+ quant_config=quant_config,
+ ),
+ )
+
with override_config(config):
topk_weights, topk_ids, token_expert_indices = fused_topk(
x, input_gating, topk, renormalize=not use_deep_gemm
)
+
+ if use_deep_gemm:
+ return deep_gemm_experts(
+ x, w1, w2, topk_weights, topk_ids, inplace=True
+ )
return fused_experts(
x,
w1,
@@ -206,7 +235,6 @@ def benchmark_config(
topk_ids,
inplace=True,
quant_config=quant_config,
- allow_deep_gemm=use_deep_gemm,
)
# JIT compilation & warmup
diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md
index 448816c28..022c4f2e8 100644
--- a/docs/design/moe_kernel_features.md
+++ b/docs/design/moe_kernel_features.md
@@ -85,10 +85,10 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels
|--------|-------------------|--------------|---------------|---------------------|-----------------------|---------|--------|
| triton | standard | all1 | G,A,T | silu, gelu,swigluoai,silu_no_mul,gelu_no_mul | Y | Y | [`fused_experts`][vllm.model_executor.layers.fused_moe.fused_moe.fused_experts],[`TritonExperts`][vllm.model_executor.layers.fused_moe.fused_moe.TritonExperts] |
| triton (batched) | batched | all1 | G,A,T | silu, gelu | 6 | Y | [`BatchedTritonExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedTritonExperts] |
-| deep gemm | standard,batched | fp8 | G(128),A,T | silu, gelu | 6 | Y | [`deep_gemm_moe_fp8`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.deep_gemm_moe_fp8],[`DeepGemmExperts`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.DeepGemmExperts],[`BatchedDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe.BatchedDeepGemmExperts] |
+| deep gemm | standard,batched | fp8 | G(128),A,T | silu, gelu | 6 | Y | [`DeepGemmExperts`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.DeepGemmExperts],[`BatchedDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe.BatchedDeepGemmExperts] |
| cutlass_fp4 | standard,batched | nvfp4 | A,T | silu | Y | Y | [`CutlassExpertsFp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp4] |
| cutlass_fp8 | standard,batched | fp8 | A,T | silu, gelu | Y | Y | [`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],[`CutlasBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] |
-| flashinfer | standard | nvfp4,fp8 | T | 5 | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
+| flashinfer | standard | nvfp4,fp8 | T | 5 | N | Y | [`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
| gpt oss triton | standard | N/A | N/A | 5 | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
| marlin | standard,batched | 3 / N/A | 3 / N/A | silu,swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
| trtllm | standard | mxfp4,nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py
index 1f1dafcca..c6f245669 100644
--- a/tests/compile/test_fusion_attn.py
+++ b/tests/compile/test_fusion_attn.py
@@ -43,7 +43,7 @@ from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
- kNvfp4Quant,
+ kNvfp4Dynamic,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
@@ -215,7 +215,7 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
"""Test model for AttentionNvfp4QuantPattern fusion."""
- quant_key = kNvfp4Quant
+ quant_key = kNvfp4Dynamic
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -468,7 +468,7 @@ def test_attention_quant_pattern(
# Note: for fp8, fully_replaced=False because query quant ops remain in graph.
# Only output quant ops are fused into attention.
- test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Quant)
+ test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Dynamic)
# access the underlying `AttnFusionPass` on the `LazyInitPass`
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py
index b71487275..f8ae83ca3 100644
--- a/tests/compile/test_silu_mul_quant_fusion.py
+++ b/tests/compile/test_silu_mul_quant_fusion.py
@@ -44,7 +44,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
kFp8StaticTensorSym,
- kNvfp4Quant,
+ kNvfp4Dynamic,
)
from vllm.platforms import current_platform
@@ -134,11 +134,11 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
def ops_in_model_before(self):
return [
SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
- QUANT_OPS[kNvfp4Quant],
+ QUANT_OPS[kNvfp4Dynamic],
]
def ops_in_model_after(self):
- return [FUSED_OPS[kNvfp4Quant]]
+ return [FUSED_OPS[kNvfp4Dynamic]]
class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Llama-4-Scout-Fp8-ModelOpt-triton.yaml b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Llama-4-Scout-Fp8-ModelOpt-triton.yaml
index 9d62c542a..9e13797bb 100644
--- a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Llama-4-Scout-Fp8-ModelOpt-triton.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Llama-4-Scout-Fp8-ModelOpt-triton.yaml
@@ -3,3 +3,6 @@ accuracy_threshold: 0.92
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel"
+env:
+ VLLM_USE_FLASHINFER_MOE_FP8: "0"
+ VLLM_USE_DEEP_GEMM: "0"
diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ll.yaml b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ll.yaml
index 54e6ab7b3..5f10684d2 100644
--- a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ll.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ll.yaml
@@ -6,4 +6,3 @@ server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enab
env:
VLLM_USE_DEEP_GEMM: "1"
VLLM_USE_DEEP_GEMM_MOE: "1"
- VLLM_USE_DEEP_GEMM_E8M0: "0"
diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-CT-Block-deepgemm-deepep-ll.yaml b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-CT-Block-deepgemm-deepep-ll.yaml
index 1d4cbfe96..37e6039e9 100644
--- a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-CT-Block-deepgemm-deepep-ll.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-CT-Block-deepgemm-deepep-ll.yaml
@@ -6,4 +6,3 @@ server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enab
env:
VLLM_USE_DEEP_GEMM: "1"
VLLM_USE_DEEP_GEMM_MOE: "1"
- VLLM_USE_DEEP_GEMM_E8M0: "0"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-Fp8-ModelOpt-triton.yaml b/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-Fp8-ModelOpt-triton.yaml
index 80e279edc..ae6bf6755 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-Fp8-ModelOpt-triton.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-Fp8-ModelOpt-triton.yaml
@@ -3,3 +3,5 @@ accuracy_threshold: 0.92
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
+env:
+ VLLM_USE_FLASHINFER_MOE_FP8: "0"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-fi-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-fi-cutlass.yaml
index 080c8d338..74820cd28 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-fi-cutlass.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-fi-cutlass.yaml
@@ -4,7 +4,5 @@ num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
env:
- VLLM_USE_DEEP_GEMM: "0"
- VLLM_USE_DEEP_GEMM_MOE: "0"
VLLM_USE_FLASHINFER_MOE_FP8: "1"
VLLM_FLASHINFER_MOE_BACKEND: "throughput"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-fi-trtllm.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-fi-trtllm.yaml
index a656cc7c3..d745c9b5b 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-fi-trtllm.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-fi-trtllm.yaml
@@ -4,7 +4,5 @@ num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
env:
- VLLM_USE_DEEP_GEMM: "0"
- VLLM_USE_DEEP_GEMM_MOE: "0"
VLLM_USE_FLASHINFER_MOE_FP8: "1"
VLLM_FLASHINFER_MOE_BACKEND: "latency"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-marlin.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-marlin.yaml
index f2273bf2c..c3d86e6bf 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-marlin.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-marlin.yaml
@@ -4,6 +4,4 @@ num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
env:
- VLLM_USE_DEEP_GEMM: "0"
- VLLM_USE_DEEP_GEMM_MOE: "0"
VLLM_TEST_FORCE_FP8_MARLIN: "1"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-triton.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-triton.yaml
index ed61e9b89..1b2d72160 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-triton.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-triton.yaml
@@ -4,5 +4,5 @@ num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
env:
+ VLLM_USE_FLASHINFER_MOE_FP8: "0"
VLLM_USE_DEEP_GEMM: "0"
- VLLM_USE_DEEP_GEMM_MOE: "0"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-fi-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-fi-cutlass.yaml
index db18dd01b..48ab58c46 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-fi-cutlass.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-fi-cutlass.yaml
@@ -4,7 +4,5 @@ num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
env:
- VLLM_USE_DEEP_GEMM: "0"
- VLLM_USE_DEEP_GEMM_MOE: "0"
VLLM_USE_FLASHINFER_MOE_FP8: "1"
VLLM_FLASHINFER_MOE_BACKEND: "throughput"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-marlin.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-marlin.yaml
index 3d82d2e22..46eee7421 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-marlin.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-marlin.yaml
@@ -4,6 +4,4 @@ num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
env:
- VLLM_USE_DEEP_GEMM: "0"
- VLLM_USE_DEEP_GEMM_MOE: "0"
VLLM_TEST_FORCE_FP8_MARLIN: "1"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-triton.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-triton.yaml
index 5621217de..3e30d4d15 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-triton.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-triton.yaml
@@ -4,5 +4,5 @@ num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
env:
+ VLLM_USE_FLASHINFER_MOE_FP8: "0"
VLLM_USE_DEEP_GEMM: "0"
- VLLM_USE_DEEP_GEMM_MOE: "0"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-CT-vllm-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-CT-vllm-cutlass.yaml
index b1ccadedd..0d7884928 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-CT-vllm-cutlass.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-CT-vllm-cutlass.yaml
@@ -3,3 +3,5 @@ accuracy_threshold: 0.88
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
+env:
+ VLLM_USE_FLASHINFER_MOE_FP4: "0"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-ModelOpt-vllm-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-ModelOpt-vllm-cutlass.yaml
index 49a1589fc..a340b6fda 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-ModelOpt-vllm-cutlass.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-ModelOpt-vllm-cutlass.yaml
@@ -3,3 +3,5 @@ accuracy_threshold: 0.88
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
+env:
+ VLLM_USE_FLASHINFER_MOE_FP4: "0"
diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py
index ac6d54b71..350e1be2f 100644
--- a/tests/kernels/moe/modular_kernel_tools/common.py
+++ b/tests/kernels/moe/modular_kernel_tools/common.py
@@ -26,6 +26,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
+ RoutingMethodType,
)
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx
@@ -574,10 +575,14 @@ def make_modular_kernel(
num_experts=config.E,
experts_per_token=config.topk,
hidden_dim=config.K,
+ intermediate_size_per_partition=config.N,
num_local_experts=config.num_local_experts,
moe_parallel_config=moe_parallel_config,
in_dtype=config.dtype,
max_num_tokens=next_power_of_2(config.M),
+ activation="silu",
+ device=vllm_config.device_config.device,
+ routing_method=RoutingMethodType.DeepSeekV3,
)
# make modular kernel
diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py
index 99b168dc7..04a654e89 100644
--- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py
+++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py
@@ -425,84 +425,26 @@ def make_fused_experts(
num_dispatchers: int,
N: int,
) -> mk.FusedMoEPermuteExpertsUnpermute:
- batch_kwargs = {
- "max_num_tokens": moe.max_num_tokens,
- "num_dispatchers": num_dispatchers,
- }
- quant_kwargs = {
- "quant_config": quant_config,
- }
- deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}
+ if (
+ fused_experts_type.activation_format()
+ == mk.FusedMoEActivationFormat.BatchedExperts
+ ):
+ kwargs = {
+ "moe_config": moe,
+ "quant_config": quant_config,
+ "max_num_tokens": moe.max_num_tokens,
+ "num_dispatchers": num_dispatchers,
+ }
+ else:
+ kwargs = {
+ "moe_config": moe,
+ "quant_config": quant_config,
+ }
torch.set_printoptions(threshold=0, edgeitems=0, linewidth=10000)
- if fused_experts_type == BatchedDeepGemmExperts:
- kwargs = batch_kwargs | quant_kwargs
- print(f"Making BatchedDeepGemmExperts {kwargs} ...")
- experts = BatchedDeepGemmExperts(**kwargs)
- elif fused_experts_type == BatchedTritonExperts:
- kwargs = batch_kwargs | quant_kwargs
- print(f"Making BatchedTritonExperts {kwargs} ...")
- experts = BatchedTritonExperts(**kwargs)
- elif fused_experts_type == DeepGemmExperts:
- print(f"Making DeepGemmExperts {quant_config} ...")
- experts = DeepGemmExperts(quant_config)
- elif fused_experts_type == TritonExperts:
- kwargs = quant_kwargs
- print(f"Making TritonExperts {kwargs} ...")
- experts = TritonExperts(**kwargs)
- elif fused_experts_type == TritonOrDeepGemmExperts:
- kwargs = quant_kwargs | deepgemm_kwargs
- print(f"Making TritonOrDeepGemmExperts {kwargs} ...")
- experts = TritonOrDeepGemmExperts(**kwargs)
- elif fused_experts_type == NaiveBatchedExperts:
- kwargs = batch_kwargs | quant_kwargs
- print(f"Making NaiveBatchedExperts {kwargs} ...")
- experts = NaiveBatchedExperts(**kwargs)
- elif fused_experts_type == CutlassExpertsFp8:
- strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim)
- kwargs = {
- "out_dtype": moe.in_dtype,
- "ab_strides1": strides[0],
- "ab_strides2": strides[1],
- "c_strides1": strides[2],
- "c_strides2": strides[3],
- } | quant_kwargs
- print(f"Making CutlassExpertsFp8 {kwargs} ...")
- experts = CutlassExpertsFp8(**kwargs)
- elif fused_experts_type == CutlassBatchedExpertsFp8:
- strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim)
- kwargs = {
- "max_experts_per_worker": moe.num_local_experts,
- "num_dispatchers": num_dispatchers,
- "out_dtype": moe.in_dtype,
- "ab_strides1": strides[0],
- "ab_strides2": strides[1],
- "c_strides1": strides[2],
- "c_strides2": strides[3],
- } | quant_kwargs
- print(f"Making CutlassBatchedExpertsFp8 {kwargs} ...")
- experts = CutlassBatchedExpertsFp8(**kwargs)
- elif fused_experts_type == CutlassExpertsFp4:
- kwargs = {
- "max_experts_per_worker": moe.num_local_experts,
- "num_dispatchers": num_dispatchers,
- "out_dtype": moe.in_dtype,
- } | quant_kwargs
- print(f"Making CutlassExpertsFp4 {kwargs} ...")
- experts = CutlassExpertsFp4(**kwargs)
- elif fused_experts_type == FlashInferExperts:
- kwargs = {
- "out_dtype": moe.in_dtype,
- "ep_rank": moe.ep_rank,
- "ep_size": moe.ep_size,
- "tp_rank": moe.tp_rank,
- "tp_size": moe.tp_size,
- } | quant_kwargs
- print(f"Making FlashInferExperts {kwargs} ...")
- experts = FlashInferExperts(**kwargs)
- else:
- raise RuntimeError(f"Unknown fused experts type: {fused_experts_type}")
+ print(f"Making {fused_experts_type.__class__.__name__} {kwargs} ...")
+ experts = fused_experts_type(**kwargs)
torch.set_printoptions(threshold=1000, edgeitems=5, linewidth=80)
diff --git a/tests/kernels/moe/test_batched_deepgemm.py b/tests/kernels/moe/test_batched_deepgemm.py
index 0ba3d8d4c..081a5fd0b 100644
--- a/tests/kernels/moe/test_batched_deepgemm.py
+++ b/tests/kernels/moe/test_batched_deepgemm.py
@@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularK
from vllm.utils.deep_gemm import calc_diff, is_deep_gemm_supported
from .test_deepgemm import make_block_quant_fp8_weights
+from .utils import make_dummy_moe_config
BLOCK_SIZE = [128, 128]
@@ -71,6 +72,7 @@ def test_batched_deepgemm_vs_triton(
max_num_tokens=max_num_tokens,
num_dispatchers=1,
quant_config=quant_config,
+ moe_config=make_dummy_moe_config(),
)
mk_triton = FusedMoEModularKernel(prep_finalize, triton_experts)
@@ -89,6 +91,7 @@ def test_batched_deepgemm_vs_triton(
max_num_tokens=max_num_tokens,
num_dispatchers=1,
quant_config=quant_config,
+ moe_config=make_dummy_moe_config(),
)
mk_deepgemm = FusedMoEModularKernel(prep_finalize, deepgemm_experts)
diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py
index 63fbbfeec..508df9e32 100644
--- a/tests/kernels/moe/test_block_fp8.py
+++ b/tests/kernels/moe/test_block_fp8.py
@@ -4,7 +4,12 @@
import pytest
import torch
-from tests.kernels.moe.utils import make_test_quant_config, make_test_weights
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from tests.kernels.moe.utils import (
+ make_dummy_moe_config,
+ make_test_quant_config,
+ make_test_weights,
+)
from tests.kernels.quant_utils import (
native_per_token_group_quant_fp8,
native_w8a8_block_matmul,
@@ -15,13 +20,21 @@ from vllm.model_executor.layers.fused_moe import (
fused_experts,
fused_topk,
)
+from vllm.model_executor.layers.fused_moe.config import (
+ fp8_w8a8_moe_quant_config,
+)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm_shape,
- deep_gemm_moe_fp8,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
modular_triton_fused_moe,
)
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+ MoEPrepareAndFinalizeNoEP,
+)
+from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
+ TritonOrDeepGemmExperts,
+)
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
get_mk_alignment_for_contiguous_layout,
@@ -161,7 +174,7 @@ def test_w8a8_block_fp8_fused_moe(
block_shape=block_size,
)
- m_fused_moe = modular_triton_fused_moe(quant_config)
+ m_fused_moe = modular_triton_fused_moe(make_dummy_moe_config(), quant_config)
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
@@ -236,6 +249,29 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
+ quant_config = fp8_w8a8_moe_quant_config(
+ w1_scale=w1_s,
+ w2_scale=w2_s,
+ block_shape=block_size,
+ )
+
+ deep_gemm_experts = mk.FusedMoEModularKernel(
+ prepare_finalize=MoEPrepareAndFinalizeNoEP(),
+ fused_experts=TritonOrDeepGemmExperts(
+ moe_config=make_dummy_moe_config(),
+ quant_config=quant_config,
+ ),
+ )
+
+ def deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids):
+ return deep_gemm_experts(
+ hidden_states=a,
+ w1=w1,
+ w2=w2,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ )
+
# Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config):
ref_out = torch_w8a8_block_fp8_moe(
diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py
index d5b1c2cf0..3a5a66a38 100644
--- a/tests/kernels/moe/test_cutlass_moe.py
+++ b/tests/kernels/moe/test_cutlass_moe.py
@@ -8,6 +8,7 @@ import pytest
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from tests.kernels.moe.utils import make_dummy_moe_config
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
@@ -193,16 +194,18 @@ def run_with_expert_maps(
out_tensor = torch.zeros_like(cutlass_moe_kwargs["hidden_states"])
for kwargs, new_quant_config in slice_experts():
+ w2 = kwargs["w2"]
+ a = kwargs["hidden_states"]
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
- out_dtype=kwargs["hidden_states"].dtype,
- # NOTE(rob): w2 is shaped as [E, hidden, intermediate]
- e=kwargs["w2"].shape[0], # type: ignore[union-attr]
- n=kwargs["w2"].shape[2], # type: ignore[union-attr]
- k=kwargs["w2"].shape[1], # type: ignore[union-attr]
+ moe_config=make_dummy_moe_config(
+ num_experts=w2.shape[0],
+ hidden_dim=w2.shape[1],
+ intermediate_size_per_partition=w2.shape[2],
+ in_dtype=a.dtype,
+ ),
quant_config=new_quant_config,
- device="cuda",
),
)
out_tensor = out_tensor + kernel(**kwargs)
@@ -249,19 +252,19 @@ def run_8_bit(
"topk_ids": topk_ids,
}
- num_experts = moe_tensors.w1.size(0)
+ num_experts = moe_tensors.w1.size(0) # type: ignore[attr-defined]
with_ep = num_local_experts is not None or num_local_experts == num_experts
if not with_ep:
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
- out_dtype=moe_tensors.a.dtype,
- # NOTE(rob): w2 is shaped as [E, hidden, intermediate]
- e=moe_tensors.w2_q.shape[0], # type: ignore[union-attr]
- n=moe_tensors.w2_q.shape[2], # type: ignore[union-attr]
- k=moe_tensors.w2_q.shape[1], # type: ignore[union-attr]
+ moe_config=make_dummy_moe_config(
+ num_experts=moe_tensors.w2_q.shape[0], # type: ignore[union-attr]
+ hidden_dim=moe_tensors.w2_q.shape[1], # type: ignore[union-attr]
+ intermediate_size_per_partition=moe_tensors.w2_q.shape[2], # type: ignore[union-attr]
+ in_dtype=moe_tensors.a.dtype,
+ ),
quant_config=quant_config,
- device="cuda",
),
)
return kernel(**kwargs)
diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py
index 8987b688a..1bf5ced2e 100644
--- a/tests/kernels/moe/test_deepep_deepgemm_moe.py
+++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py
@@ -33,7 +33,7 @@ from vllm.v1.worker.workspace import init_workspace_manager
from ...utils import multi_gpu_test
from .parallel_utils import ProcessGroupInfo, parallel_launch
-from .utils import make_test_weights
+from .utils import make_dummy_moe_config, make_test_weights
if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (
@@ -192,6 +192,7 @@ def make_ll_modular_kernel(
max_num_tokens=max_tokens_per_rank,
num_dispatchers=pgi.world_size // dp_size,
quant_config=quant_config,
+ moe_config=make_dummy_moe_config(),
)
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
return mk
@@ -219,7 +220,10 @@ def make_ht_modular_kernel(
block_shape=test_config.block_size,
)
- fused_experts = DeepGemmExperts(quant_config)
+ fused_experts = DeepGemmExperts(
+ moe_config=make_dummy_moe_config(),
+ quant_config=quant_config,
+ )
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
return mk
@@ -349,9 +353,6 @@ def triton_impl(
topk_ids=topk_ids,
inplace=False,
quant_config=quant_config,
- # Make sure this is set to False so we
- # don't end up comparing the same implementation.
- allow_deep_gemm=False,
)
diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py
index e57e0d720..f740f5bf9 100644
--- a/tests/kernels/moe/test_deepep_moe.py
+++ b/tests/kernels/moe/test_deepep_moe.py
@@ -10,11 +10,14 @@ import pytest
import torch.distributed
from torch.distributed import ProcessGroup
+from tests.kernels.moe.utils import make_dummy_moe_config
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import TritonExperts
-from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEQuantConfig,
+)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
@@ -160,15 +163,21 @@ def make_modular_kernel(
num_dispatchers = pgi.world_size // dp_size
+ moe_config = make_dummy_moe_config()
+
if low_latency_mode:
assert not quant_config.per_act_token_quant, "not supported in ll mode"
fused_experts = BatchedTritonExperts(
max_num_tokens=MAX_TOKENS_PER_RANK,
num_dispatchers=num_dispatchers,
+ moe_config=moe_config,
quant_config=quant_config,
)
else:
- fused_experts = TritonExperts(quant_config=quant_config)
+ fused_experts = TritonExperts(
+ moe_config=moe_config,
+ quant_config=quant_config,
+ )
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
return mk
diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py
index 442b561f8..729b54753 100644
--- a/tests/kernels/moe/test_deepgemm.py
+++ b/tests/kernels/moe/test_deepgemm.py
@@ -11,10 +11,19 @@ import math
import pytest
import torch
-from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
-
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from tests.kernels.moe.utils import make_dummy_moe_config
+from vllm.model_executor.layers.fused_moe.config import (
+ fp8_w8a8_moe_quant_config,
+)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+ MoEPrepareAndFinalizeNoEP,
+)
+from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
+ TritonOrDeepGemmExperts,
+)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
@@ -100,6 +109,14 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
block_shape=block_size,
)
+ deep_gemm_experts = mk.FusedMoEModularKernel(
+ prepare_finalize=MoEPrepareAndFinalizeNoEP(),
+ fused_experts=TritonOrDeepGemmExperts(
+ moe_config=make_dummy_moe_config(),
+ quant_config=quant_config,
+ ),
+ )
+
# triton reference
out_triton = fused_experts(
hidden_states=tokens_bf16,
@@ -109,19 +126,16 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
topk_ids=topk_ids,
inplace=False,
quant_config=quant_config,
- allow_deep_gemm=False,
)
# DeepGemm
- out_deepgemm = fused_experts(
+ out_deepgemm = deep_gemm_experts(
hidden_states=tokens_bf16,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
- quant_config=quant_config,
- allow_deep_gemm=True,
)
diff = calc_diff(out_deepgemm, out_triton)
assert diff < 0.001, f"Diff exceeded 1%: {diff}"
@@ -147,20 +161,19 @@ def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch, workspace_i
with monkeypatch.context() as mp:
mp.setenv("VLLM_USE_DEEP_GEMM", "1")
- _fused_moe_mod = importlib.import_module(
- "vllm.model_executor.layers.fused_moe.fused_moe"
- )
+ _DeepGemmExperts = importlib.import_module(
+ "vllm.model_executor.layers.fused_moe.deep_gemm_moe"
+ ).DeepGemmExperts
call_counter = {"cnt": 0}
- orig_fn = _fused_moe_mod.deep_gemm_moe_fp8
+ orig_fn = _DeepGemmExperts.apply
- def _spy_deep_gemm_moe_fp8(*args, **kwargs):
+ def _spy_apply(*args, **kwargs):
call_counter["cnt"] += 1
return orig_fn(*args, **kwargs)
- monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8", _spy_deep_gemm_moe_fp8)
-
+ monkeypatch.setattr(_DeepGemmExperts, "apply", _spy_apply)
if topk > num_experts:
pytest.skip(f"topk={topk} > num_experts={num_experts}")
diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py
index bb2f6b873..e5752e404 100644
--- a/tests/kernels/moe/test_flashinfer.py
+++ b/tests/kernels/moe/test_flashinfer.py
@@ -8,7 +8,10 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
+ FusedMoEParallelConfig,
FusedMoEQuantConfig,
+ RoutingMethodType,
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
@@ -116,18 +119,7 @@ class TestData:
layer.w13_weight_scale = w13_weight_scale
layer.w2_weight_scale = w2_weight_scale
# Setup dummy config.
- layer.moe_parallel_config = mk.FusedMoEParallelConfig(
- tp_size=1,
- pcp_size=1,
- dp_size=1,
- ep_size=1,
- tp_rank=0,
- pcp_rank=0,
- dp_rank=0,
- ep_rank=0,
- use_ep=False,
- all2all_backend="naive",
- )
+ layer.moe_parallel_config = mk.FusedMoEParallelConfig.make_no_parallel()
# flashinfer expects swapped rows for w13
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
@@ -238,6 +230,8 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
):
set_random_seed(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
+ assert activation in ["silu", "relu2_no_mul"]
+ is_act_and_mul = activation == "silu_and_mul"
with set_current_vllm_config(vllm_config):
td = TestData.make_moe_tensors_8bit(
m, k, n, e, is_trtllm=False, activation=activation
@@ -285,19 +279,30 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
td.layer.get_fused_moe_quant_config = get_fused_moe_quant_config
td.layer.quant_method = td.layer
+ moe_config = FusedMoEConfig(
+ num_experts=e,
+ experts_per_token=topk,
+ hidden_dim=k,
+ intermediate_size_per_partition=n,
+ num_local_experts=e,
+ activation=activation,
+ device="cuda",
+ moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
+ in_dtype=torch.bfloat16,
+ is_act_and_mul=is_act_and_mul,
+ routing_method=RoutingMethodType.TopK,
+ )
+
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(
- defer_input_quant=quant_config.is_block_quantized
+ defer_input_quant=FlashInferExperts.expects_unquantized_inputs(
+ moe_config=moe_config,
+ quant_config=quant_config,
+ )
),
FlashInferExperts(
- out_dtype=td.layer.orig_dtype,
+ moe_config=moe_config,
quant_config=quant_config,
- ep_rank=td.layer.moe_parallel_config.ep_rank,
- ep_size=td.layer.moe_parallel_config.ep_size,
- tp_rank=td.layer.moe_parallel_config.tp_rank,
- tp_size=td.layer.moe_parallel_config.tp_size,
- use_dp=False,
- use_deepseek_fp8_block_scale=False,
),
)
diff --git a/tests/kernels/moe/test_flashinfer_moe.py b/tests/kernels/moe/test_flashinfer_moe.py
index ee23e7e65..66ceb0c64 100644
--- a/tests/kernels/moe/test_flashinfer_moe.py
+++ b/tests/kernels/moe/test_flashinfer_moe.py
@@ -13,14 +13,19 @@ from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk
+from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
+ FusedMoEParallelConfig,
+ RoutingMethodType,
+)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
is_valid_flashinfer_cutlass_fused_moe,
)
-from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (
- create_flashinfer_prepare_finalize,
-)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+ MoEPrepareAndFinalizeNoEP,
+)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.torch_utils import set_random_seed
@@ -86,9 +91,28 @@ def test_flashinfer_fp4_moe_no_graph(
assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)
+ moe_config = FusedMoEConfig(
+ num_experts=e,
+ experts_per_token=topk,
+ hidden_dim=k,
+ intermediate_size_per_partition=n,
+ num_local_experts=e,
+ activation=activation,
+ device="cuda",
+ moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
+ in_dtype=dtype,
+ is_act_and_mul=is_gated_act,
+ routing_method=RoutingMethodType.TopK,
+ )
+
flashinfer_experts = FusedMoEModularKernel(
- create_flashinfer_prepare_finalize(use_dp=False, use_nvfp4=True),
- FlashInferExperts(out_dtype=dtype, quant_config=quant_config),
+ MoEPrepareAndFinalizeNoEP(
+ defer_input_quant=FlashInferExperts.expects_unquantized_inputs(
+ moe_config=moe_config,
+ quant_config=quant_config,
+ )
+ ),
+ FlashInferExperts(moe_config=moe_config, quant_config=quant_config),
)
fi_activation = {"silu_and_mul": "silu", "relu2": "relu2_no_mul"}[activation]
diff --git a/tests/kernels/moe/test_modular_oai_triton_moe.py b/tests/kernels/moe/test_modular_oai_triton_moe.py
index 8733ba4d8..38022e0e6 100644
--- a/tests/kernels/moe/test_modular_oai_triton_moe.py
+++ b/tests/kernels/moe/test_modular_oai_triton_moe.py
@@ -36,6 +36,8 @@ from vllm.model_executor.layers.utils import shuffle_weight
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
+from .utils import make_dummy_moe_config
+
MNK = [
(1, 512, 384),
(1, 2880, 2880),
@@ -174,9 +176,9 @@ def oai_triton_moe_impl(
)
if unfused:
- fused_experts = UnfusedOAITritonExperts(quant_config)
+ fused_experts = UnfusedOAITritonExperts(make_dummy_moe_config(), quant_config)
else:
- fused_experts = OAITritonExperts(quant_config)
+ fused_experts = OAITritonExperts(make_dummy_moe_config(), quant_config)
mk = FusedMoEModularKernel(MoEPrepareAndFinalizeNoEP(), fused_experts)
diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py
index 849f9e7f1..e34c78074 100644
--- a/tests/kernels/moe/test_moe.py
+++ b/tests/kernels/moe/test_moe.py
@@ -18,7 +18,7 @@ from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa
-from tests.kernels.moe.utils import fused_moe
+from tests.kernels.moe.utils import fused_moe, make_dummy_moe_config
from tests.kernels.utils import opcheck, stack_and_dev, torch_experts, torch_moe
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, set_current_vllm_config
@@ -332,7 +332,7 @@ def test_fused_moe(
#
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
- m_fused_moe_fn = modular_triton_fused_moe(quant_config)
+ m_fused_moe_fn = modular_triton_fused_moe(make_dummy_moe_config(), quant_config)
def m_fused_moe(
a: torch.Tensor,
@@ -437,7 +437,7 @@ def test_naive_block_assignment_moe(
#
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
- m_fused_moe_fn = modular_triton_fused_moe(quant_config)
+ m_fused_moe_fn = modular_triton_fused_moe(make_dummy_moe_config(), quant_config)
def m_fused_moe(
a: torch.Tensor,
diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py
index 4dd4223db..0011149e3 100644
--- a/tests/kernels/moe/test_nvfp4_moe.py
+++ b/tests/kernels/moe/test_nvfp4_moe.py
@@ -4,7 +4,7 @@ import pytest
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from tests.kernels.moe.utils import make_test_weights
+from tests.kernels.moe.utils import make_dummy_moe_config, make_test_weights
from tests.kernels.quantization.nvfp4_utils import (
FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
@@ -92,8 +92,7 @@ def test_cutlass_fp4_moe_no_graph(
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
CutlassExpertsFp4(
- out_dtype=dtype,
- max_experts_per_worker=e,
+ moe_config=make_dummy_moe_config(),
quant_config=quant_config,
),
)
diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py
index a2ab94e37..ef37c1c74 100644
--- a/tests/kernels/moe/test_pplx_cutlass_moe.py
+++ b/tests/kernels/moe/test_pplx_cutlass_moe.py
@@ -9,12 +9,18 @@ from tests.kernels.utils import torch_experts
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk
-from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
+from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
+ FusedMoEParallelConfig,
+ RoutingMethodType,
+ fp8_w8a8_moe_quant_config,
+)
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassBatchedExpertsFp8
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import set_random_seed
+from vllm.v1.worker.workspace import init_workspace_manager
from ...utils import multi_gpu_test
from .parallel_utils import ProcessGroupInfo, parallel_launch
@@ -79,6 +85,8 @@ def pplx_cutlass_moe(
PplxPrepareAndFinalize,
)
+ init_workspace_manager(torch.cuda.current_device())
+
assert torch.cuda.current_device() == pgi.local_rank
num_tokens, hidden_dim = a.shape
@@ -132,28 +140,23 @@ def pplx_cutlass_moe(
num_dispatchers=num_dispatchers,
)
- ab_strides1 = torch.full(
- (num_local_experts,), hidden_dim, device="cuda", dtype=torch.int64
- )
- ab_strides2 = torch.full(
- (num_local_experts,), intermediate_dim, device="cuda", dtype=torch.int64
- )
- c_strides1 = torch.full(
- (num_local_experts,), 2 * intermediate_dim, device="cuda", dtype=torch.int64
- )
- c_strides2 = torch.full(
- (num_local_experts,), hidden_dim, device="cuda", dtype=torch.int64
- )
+ def make_moe_config() -> FusedMoEConfig:
+ return FusedMoEConfig(
+ num_experts=num_experts,
+ experts_per_token=topk,
+ hidden_dim=hidden_dim,
+ intermediate_size_per_partition=intermediate_dim,
+ num_local_experts=num_local_experts,
+ moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
+ activation="silu",
+ in_dtype=torch.bfloat16,
+ device="cuda",
+ routing_method=RoutingMethodType.Llama4,
+ )
experts = CutlassBatchedExpertsFp8(
- num_local_experts,
- num_dispatchers,
- out_dtype,
- ab_strides1,
- ab_strides2,
- c_strides1,
- c_strides2,
- fp8_w8a8_moe_quant_config(
+ moe_config=make_moe_config(),
+ quant_config=fp8_w8a8_moe_quant_config(
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
w1_scale=chunk_by_rank(w1_scale, rank, world_size),
@@ -162,6 +165,8 @@ def pplx_cutlass_moe(
if per_act_token
else a1_scale[rank],
),
+ max_num_tokens=max_num_tokens,
+ num_dispatchers=num_dispatchers,
)
fused_cutlass_experts = FusedMoEModularKernel(
diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py
index c08a54f0e..08519087e 100644
--- a/tests/kernels/moe/test_pplx_moe.py
+++ b/tests/kernels/moe/test_pplx_moe.py
@@ -29,6 +29,7 @@ except ImportError:
from tests.kernels.moe.modular_kernel_tools.parallel_utils import _set_vllm_config
from tests.kernels.moe.utils import (
+ make_dummy_moe_config,
make_shared_experts,
make_test_weights,
naive_batched_moe,
@@ -584,6 +585,7 @@ def pplx_moe(
max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=quant_config,
+ moe_config=make_dummy_moe_config(),
)
fused_experts = FusedMoEModularKernel(
diff --git a/tests/kernels/moe/test_routing.py b/tests/kernels/moe/test_routing.py
index 93aa6aa5c..f623f943f 100644
--- a/tests/kernels/moe/test_routing.py
+++ b/tests/kernels/moe/test_routing.py
@@ -6,7 +6,6 @@ import pytest
import torch
from vllm.distributed.eplb.eplb_state import EplbLayerState
-from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.router_factory import (
create_fused_moe_router,
)
@@ -385,17 +384,11 @@ def test_grouped_topk(
global_num_experts,
)
- routing_method_type = None
- if scoring_func == "llama4":
- routing_method_type = RoutingMethodType.Llama4
- scoring_func = "sigmoid"
-
router = create_fused_moe_router(
use_grouped_topk=True,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
- routing_method_type=routing_method_type,
e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
top_k=top_k,
diff --git a/tests/kernels/moe/test_triton_moe_no_act_mul.py b/tests/kernels/moe/test_triton_moe_no_act_mul.py
index 12d5180f9..ab15f898b 100644
--- a/tests/kernels/moe/test_triton_moe_no_act_mul.py
+++ b/tests/kernels/moe/test_triton_moe_no_act_mul.py
@@ -10,6 +10,7 @@ equals N (not N // 2 like gated activations).
import pytest
import torch
+from tests.kernels.moe.utils import make_dummy_moe_config
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
)
@@ -78,7 +79,10 @@ def test_triton_experts_no_mul_activation(
m, n, k, NUM_EXPERTS, topk
)
- experts = TritonExperts(FUSED_MOE_UNQUANTIZED_CONFIG)
+ experts = TritonExperts(
+ moe_config=make_dummy_moe_config(),
+ quant_config=FUSED_MOE_UNQUANTIZED_CONFIG,
+ )
ws1_shape, ws2_shape, out_shape = experts.workspace_shapes(
M=m,
@@ -151,7 +155,10 @@ def test_workspace_shapes_no_mul_vs_gated():
M, N, K, topk = 64, 256, 128, 2
- experts = TritonExperts(FUSED_MOE_UNQUANTIZED_CONFIG)
+ experts = TritonExperts(
+ moe_config=make_dummy_moe_config(),
+ quant_config=FUSED_MOE_UNQUANTIZED_CONFIG,
+ )
ws1_no_mul, _, out_no_mul = experts.workspace_shapes(
M, N, K, topk, 8, 8, None, SILU_NO_MUL
@@ -187,7 +194,10 @@ def test_adjust_n_for_activation():
"""Test the adjust_N_for_activation method."""
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
- experts = TritonExperts(FUSED_MOE_UNQUANTIZED_CONFIG)
+ experts = TritonExperts(
+ moe_config=make_dummy_moe_config(),
+ quant_config=FUSED_MOE_UNQUANTIZED_CONFIG,
+ )
N = 256
diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py
index f0c8c8033..4883085cb 100644
--- a/tests/kernels/moe/utils.py
+++ b/tests/kernels/moe/utils.py
@@ -8,7 +8,12 @@ from tests.kernels.quant_utils import per_block_cast_to_int8
from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
-from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
+ FusedMoEParallelConfig,
+ FusedMoEQuantConfig,
+ RoutingMethodType,
+)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize,
BatchedTritonExperts,
@@ -20,6 +25,34 @@ from vllm.utils.deep_gemm import per_block_cast_to_fp8
from vllm.utils.math_utils import round_up
+def make_dummy_moe_config(
+ num_experts: int = 1,
+ experts_per_token: int = 1,
+ hidden_dim: int = 1,
+ intermediate_size_per_partition: int = 1,
+ in_dtype: torch.dtype = torch.bfloat16,
+) -> FusedMoEConfig:
+ """
+ This is a dummy config for the mk constructor interface
+ as most kernels like DeepGEMM, CUTLASSFp4, Triton, MARLIN
+ do not actually use this config.
+
+ CUTLASSFp8 needs to set some params for workshapes.
+ """
+ return FusedMoEConfig(
+ num_experts=num_experts,
+ experts_per_token=experts_per_token,
+ hidden_dim=hidden_dim,
+ intermediate_size_per_partition=intermediate_size_per_partition,
+ num_local_experts=num_experts,
+ moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
+ activation="silu",
+ in_dtype=in_dtype,
+ device="cuda",
+ routing_method=RoutingMethodType.TopK,
+ )
+
+
def triton_moe(
a: torch.Tensor,
w1: torch.Tensor,
@@ -81,6 +114,7 @@ def batched_moe(
max_num_tokens=max_num_tokens,
num_dispatchers=1,
quant_config=quant_config,
+ moe_config=make_dummy_moe_config(),
),
)
@@ -121,6 +155,7 @@ def naive_batched_moe(
max_num_tokens=max_num_tokens,
num_dispatchers=1,
quant_config=quant_config,
+ moe_config=make_dummy_moe_config(),
),
)
diff --git a/tools/vllm-rocm/pin_rocm_dependencies.py b/tools/vllm-rocm/pin_rocm_dependencies.py
index ba11fd934..b9387069d 100644
--- a/tools/vllm-rocm/pin_rocm_dependencies.py
+++ b/tools/vllm-rocm/pin_rocm_dependencies.py
@@ -11,10 +11,11 @@ This ensures that 'pip install vllm' automatically installs the correct custom w
instead of allowing pip to download different versions from PyPI.
"""
-import re
import sys
from pathlib import Path
+import regex as re
+
def extract_version_from_wheel(wheel_name: str) -> str:
"""
diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py
index f0ce5b3db..8530c0dad 100644
--- a/vllm/compilation/activation_quant_fusion.py
+++ b/vllm/compilation/activation_quant_fusion.py
@@ -18,7 +18,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
- kNvfp4Quant,
+ kNvfp4Dynamic,
)
from vllm.platforms import current_platform
@@ -41,7 +41,7 @@ silu_and_mul_nvfp4_quant_supported = current_platform.is_cuda() and hasattr(
torch.ops._C, "silu_and_mul_nvfp4_quant"
)
if silu_and_mul_nvfp4_quant_supported:
- FUSED_OPS[kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501
+ FUSED_OPS[kNvfp4Dynamic] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501
class ActivationQuantPattern(ABC):
@@ -129,7 +129,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
"""
def __init__(self) -> None:
- super().__init__(kNvfp4Quant)
+ super().__init__(kNvfp4Dynamic)
def get_inputs(self) -> list[torch.Tensor]:
result = self.empty_quant(5, 32)
diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py
index e3c6c2f20..667828cc6 100644
--- a/vllm/compilation/fusion.py
+++ b/vllm/compilation/fusion.py
@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
- kNvfp4Quant,
+ kNvfp4Dynamic,
kStaticTensorScale,
)
from vllm.platforms import current_platform
@@ -63,7 +63,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
}
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
- QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
+ QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default
if current_platform.is_cuda():
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py
index 57448aa0b..69dc2e3a6 100644
--- a/vllm/compilation/fusion_attn.py
+++ b/vllm/compilation/fusion_attn.py
@@ -16,7 +16,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
- kNvfp4Quant,
+ kNvfp4Dynamic,
kStaticTensorScale,
)
from vllm.platforms import current_platform
@@ -217,7 +217,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
"""
def __init__(self, layer: Attention, dtype: torch.dtype) -> None:
- super().__init__(layer, kNvfp4Quant, dtype)
+ super().__init__(layer, kNvfp4Dynamic, dtype)
def _register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py
index eda12180d..7bb98db5e 100644
--- a/vllm/compilation/matcher_utils.py
+++ b/vllm/compilation/matcher_utils.py
@@ -21,7 +21,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
- kNvfp4Quant,
+ kNvfp4Dynamic,
)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
@@ -38,7 +38,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
}
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
- QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
+ QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
if current_platform.is_cuda():
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
index e598ec3ac..ac37cff93 100644
--- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
@@ -7,11 +7,20 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.forward_context import get_forward_context, is_forward_context_available
from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
+ FusedMoEParallelConfig,
+ FusedMoEQuantConfig,
+)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+ kFp8Dynamic128Sym,
+ kFp8Static128BlockSym,
+)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import (
@@ -19,6 +28,7 @@ from vllm.utils.deep_gemm import (
fp8_m_grouped_gemm_nt_masked,
get_mk_alignment_for_contiguous_layout,
is_deep_gemm_e8m0_used,
+ is_deep_gemm_supported,
)
from vllm.utils.math_utils import cdiv, round_up
@@ -253,29 +263,52 @@ def persistent_masked_m_silu_mul_quant(
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
+ moe_config: FusedMoEConfig,
+ quant_config: FusedMoEQuantConfig,
max_num_tokens: int,
num_dispatchers: int,
- quant_config: FusedMoEQuantConfig,
):
"""
max_num_tokens: Maximum number of tokens from a DP Rank
num_dispatchers: The number of DP dispatchers.
quant_config: Quantization configuration
"""
- super().__init__(quant_config)
+ super().__init__(
+ moe_config=moe_config,
+ quant_config=quant_config,
+ max_num_tokens=max_num_tokens,
+ num_dispatchers=num_dispatchers,
+ )
assert self.block_shape == get_mk_alignment_for_contiguous_layout()
assert self.quant_config.use_fp8_w8a8
- self.max_num_tokens = max_num_tokens
- self.num_dispatchers = num_dispatchers
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.BatchedExperts,
- mk.FusedMoEActivationFormat.BatchedExperts,
- )
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.BatchedExperts
+
+ @staticmethod
+ def _supports_current_device() -> bool:
+ return is_deep_gemm_supported()
+
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ return False
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ SUPPORTED_W_A = [(kFp8Static128BlockSym, kFp8Dynamic128Sym)]
+ return (weight_key, activation_key) in SUPPORTED_W_A
+
+ @staticmethod
+ def _supports_activation(activation: str) -> bool:
+ return activation in ["silu"]
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ return True
def supports_chunking(self) -> bool:
return False
@@ -310,6 +343,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# FIXME (varun): We should be able to dispatch only from the leader
# DP ranks in the case of TP > 1. At the moment, all the Ranks
# end up sending their tokens. This needs to be fixed.
+ assert self.num_dispatchers is not None
+ assert self.max_num_tokens is not None
num_dispatchers = self.num_dispatchers
num_experts = local_num_experts
max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens
diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index c8baefbd5..12e9918e0 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -862,6 +862,7 @@ class FusedMoEParallelConfig:
use_ep: bool # whether to use EP or not
all2all_backend: str # all2all backend for MoE communication
+ enable_eplb: bool # whether to enable expert load balancing
@property
def use_all2all_kernels(self):
@@ -882,6 +883,16 @@ class FusedMoEParallelConfig:
def use_deepep_ll_kernels(self):
return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"
+ @property
+ def use_batched_activation_format(self):
+ return self.use_deepep_ll_kernels or self.use_pplx_kernels
+
+ @property
+ def use_naive_all2all_kernels(self):
+ return self.use_all2all_kernels and (
+ self.all2all_backend in ["naive", "allgather_reducescatter"]
+ )
+
@staticmethod
def flatten_tp_across_dp_and_pcp(
tp_size: int, dp_size: int, dp_rank: int, pcp_size: int, pcp_rank: int
@@ -999,6 +1010,7 @@ class FusedMoEParallelConfig:
ep_rank=0,
use_ep=False,
all2all_backend=vllm_parallel_config.all2all_backend,
+ enable_eplb=vllm_parallel_config.enable_eplb,
)
# DP + EP / TP + EP / DP + TP + EP
assert use_ep
@@ -1017,6 +1029,24 @@ class FusedMoEParallelConfig:
ep_rank=ep_rank,
use_ep=True,
all2all_backend=vllm_parallel_config.all2all_backend,
+ enable_eplb=vllm_parallel_config.enable_eplb,
+ )
+
+ @classmethod
+ def make_no_parallel(cls) -> "FusedMoEParallelConfig":
+ """For usage in CI/CD and testing."""
+ return FusedMoEParallelConfig(
+ tp_size=1,
+ tp_rank=0,
+ pcp_size=1,
+ pcp_rank=0,
+ dp_size=1,
+ dp_rank=0,
+ ep_size=1,
+ ep_rank=0,
+ use_ep=False,
+ all2all_backend="naive",
+ enable_eplb=False,
)
@@ -1026,8 +1056,11 @@ class FusedMoEConfig:
num_experts: int
experts_per_token: int
hidden_dim: int
-
+ intermediate_size_per_partition: int
num_local_experts: int
+ activation: str
+ device: torch.device | str
+ routing_method: RoutingMethodType
moe_parallel_config: FusedMoEParallelConfig
# The activation type.
diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
index acc12d0da..0d8690638 100644
--- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -7,7 +7,11 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
+ FusedMoEParallelConfig,
+ FusedMoEQuantConfig,
+)
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
moe_permute,
moe_unpermute,
@@ -23,6 +27,19 @@ from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache,
apply_moe_activation,
)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+ kFp8DynamicTensorSym,
+ kFp8DynamicTokenSym,
+ kFp8StaticChannelSym,
+ kFp8StaticTensorSym,
+ kNvfp4Dynamic,
+ kNvfp4Static,
+)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+ cutlass_group_gemm_supported,
+)
+from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
@@ -238,29 +255,57 @@ def run_cutlass_moe_fp8(
class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
- e: int,
- n: int,
- k: int,
- out_dtype: torch.dtype | None,
+ moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
- device: torch.dtype,
+ max_num_tokens: int | None = None,
+ num_dispatchers: int | None = None,
):
+ super().__init__(
+ moe_config=moe_config,
+ quant_config=quant_config,
+ max_num_tokens=max_num_tokens,
+ num_dispatchers=num_dispatchers,
+ )
assert quant_config.use_fp8_w8a8
- super().__init__(quant_config)
- # E: num_experts
- # N: intermediate size per partition
- # K: hidden dim
+ e = moe_config.num_local_experts
+ n = moe_config.intermediate_size_per_partition
+ k = moe_config.hidden_dim
+ device = moe_config.device
ab_strides1_c_strides2 = torch.full((e,), k, device=device, dtype=torch.int64)
ab_strides2 = torch.full((e,), n, device=device, dtype=torch.int64)
c_strides1 = torch.full((e,), 2 * n, device=device, dtype=torch.int64)
- self.out_dtype = out_dtype
+ self.out_dtype = moe_config.in_dtype
self.ab_strides1 = ab_strides1_c_strides2
self.ab_strides2 = ab_strides2
self.c_strides1 = c_strides1
self.c_strides2 = ab_strides1_c_strides2
+ @staticmethod
+ def _supports_current_device() -> bool:
+ return cutlass_group_gemm_supported()
+
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ return False
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ SUPPORTED_W_A = [
+ (kFp8StaticChannelSym, kFp8DynamicTokenSym),
+ (kFp8StaticTensorSym, kFp8DynamicTensorSym),
+ (kFp8StaticTensorSym, kFp8StaticTensorSym),
+ ]
+ return (weight_key, activation_key) in SUPPORTED_W_A
+
+ @staticmethod
+ def _supports_activation(activation: str) -> bool:
+ return activation in ["silu", "gelu", "swigluoai"]
+
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
@@ -291,7 +336,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
expert_num_tokens = expert_tokens_meta.expert_num_tokens
use_batched_format = (
- self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts
+ self.activation_format() == mk.FusedMoEActivationFormat.BatchedExperts
)
in_dtype = hidden_states.dtype
@@ -324,20 +369,23 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
class CutlassExpertsFp8(CutlassExpertsFp8Base):
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.Standard,
- mk.FusedMoEActivationFormat.Standard,
- )
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ # CutlassExpertsFp8 does not support expert map, which is
+ # needed for STANDARD activation format kernels in DP/EP mode.
+ # Note that the BATCHED activation format does not use
+ # the expert map for identifying experts.
+ return not moe_parallel_config.use_all2all_kernels
def supports_chunking(self) -> bool:
return True
def supports_expert_map(self) -> bool:
- return True
+ return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# topk weights and reduction are fused in moe_unpermute cuda kernel
@@ -365,26 +413,16 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
- def __init__(
- self,
- max_experts_per_worker: int,
- num_dispatchers: int,
- *args,
- **kwargs,
- ):
- super().__init__(*args, **kwargs)
- assert max_experts_per_worker > 0
- self.max_experts_per_worker = max_experts_per_worker
- self.num_dispatchers = num_dispatchers
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ # BATCHED activation format works with EP because
+ # expert_map is not used to identify experts (the
+ # info is encoded/managed by the P/F logic).
+ return True
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.BatchedExperts,
- mk.FusedMoEActivationFormat.BatchedExperts,
- )
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.BatchedExperts
def supports_chunking(self) -> bool:
return False
@@ -408,14 +446,15 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
num_dp = self.num_dispatchers
assert num_dp is not None
+ experts_per_worker = self.moe_config.num_local_experts
activation_out_dim = self.adjust_N_for_activation(N, activation)
- workspace1 = (self.max_experts_per_worker, M * num_dp, max(N, K))
+ workspace1 = (experts_per_worker, M * num_dp, max(N, K))
workspace2 = (
- self.max_experts_per_worker,
+ experts_per_worker,
M * num_dp,
max(activation_out_dim, K),
)
- output = (self.max_experts_per_worker, M, K)
+ output = (experts_per_worker, M, K)
return (workspace1, workspace2, output)
@@ -601,34 +640,41 @@ def run_cutlass_moe_fp4(
return
-# Split into batched and non-batched
class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
- def __init__(
- self,
- max_experts_per_worker: int,
- out_dtype: torch.dtype,
- quant_config: FusedMoEQuantConfig,
- use_batched_format: bool = False,
- ):
- super().__init__(quant_config)
- self.max_experts_per_worker = max_experts_per_worker
- self.out_dtype = out_dtype
- self.use_batched_format = use_batched_format
+ @staticmethod
+ def expects_unquantized_inputs(
+ moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig
+ ) -> bool:
+ return True
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- if self.use_batched_format:
- return (
- mk.FusedMoEActivationFormat.BatchedExperts,
- mk.FusedMoEActivationFormat.BatchedExperts,
- )
- else:
- return (
- mk.FusedMoEActivationFormat.Standard,
- mk.FusedMoEActivationFormat.Standard,
- )
+ @staticmethod
+ def _supports_current_device() -> bool:
+ return current_platform.has_device_capability((10, 0))
+
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ return False
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ return (weight_key, activation_key) == (kNvfp4Static, kNvfp4Dynamic)
+
+ @staticmethod
+ def _supports_activation(activation: str) -> bool:
+ return activation in ["silu", "gelu", "swigluoai"]
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ # CutlassExpertsFp4 does not support expert map, which is
+ # needed for STANDARD activation format kernels in EP mode.
+ return moe_parallel_config.ep_size == 1
+
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
def supports_expert_map(self) -> bool:
return False
@@ -640,7 +686,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
return TopKWeightAndReduceNoOP()
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
- return self.out_dtype if self.out_dtype is not None else act_dtype
+ return act_dtype
def workspace_shapes(
self,
@@ -653,18 +699,9 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
- activation_out_dim = self.adjust_N_for_activation(N, activation)
- workspace1: tuple[int, ...] = ()
- workspace2: tuple[int, ...] = ()
- output: tuple[int, ...] = ()
- if self.use_batched_format:
- workspace1 = (self.max_experts_per_worker, M, max(N, K))
- workspace2 = (self.max_experts_per_worker, M, activation_out_dim)
- output = (self.max_experts_per_worker, M, K)
- else:
- workspace1 = (M * topk, max(2 * N, K))
- workspace2 = (M * topk, N)
- output = (M, K)
+ workspace1 = (M * topk, max(2 * N, K))
+ workspace2 = (M * topk, N)
+ output = (M, K)
return (workspace1, workspace2, output)
def apply(
@@ -869,10 +906,11 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
c_strides2: torch.Tensor,
s_strides1: torch.Tensor,
s_strides2: torch.Tensor,
+ moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
group_size: int,
):
- super().__init__(quant_config)
+ super().__init__(moe_config=moe_config, quant_config=quant_config)
self.out_dtype = out_dtype
self.a_strides1 = a_strides1
self.a_strides2 = a_strides2
@@ -884,13 +922,46 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
self.s_strides2 = s_strides2
self.group_size = group_size
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.Standard,
- mk.FusedMoEActivationFormat.Standard,
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
+
+ @staticmethod
+ def _supports_current_device() -> bool:
+ raise NotImplementedError(
+ "CutlassExpertsW4A8Fp8 is not yet used by an Oracle. "
+ "This method should not be called."
+ )
+
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ raise NotImplementedError(
+ "CutlassExpertsW4A8Fp8 is not yet used by an Oracle. "
+ "This method should not be called."
+ )
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ raise NotImplementedError(
+ "CutlassExpertsW4A8Fp8 is not yet used by an Oracle. "
+ "This method should not be called."
+ )
+
+ @staticmethod
+ def _supports_activation(activation: str) -> bool:
+ raise NotImplementedError(
+ "CutlassExpertsW4A8Fp8 is not yet used by an Oracle. "
+ "This method should not be called."
+ )
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ raise NotImplementedError(
+ "CutlassExpertsW4A8Fp8 is not yet used by an Oracle. "
+ "This method should not be called."
)
def supports_chunking(self) -> bool:
@@ -947,7 +1018,7 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
expert_num_tokens = None
use_batched_format = (
- self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts
+ self.activation_format() == mk.FusedMoEActivationFormat.BatchedExperts
)
assert not use_batched_format, "batched format not supported"
@@ -1003,6 +1074,7 @@ def cutlass_moe_w4a8_fp8(
s_strides1: torch.Tensor,
s_strides2: torch.Tensor,
quant_config: FusedMoEQuantConfig,
+ moe_config: FusedMoEConfig,
activation: str = "silu",
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
@@ -1076,6 +1148,7 @@ def cutlass_moe_w4a8_fp8(
c_strides2=c_strides2,
s_strides1=s_strides1,
s_strides2=s_strides2,
+ moe_config=moe_config,
quant_config=quant_config,
group_size=group_size,
),
diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
index a2e5a07fb..1d5d039b6 100644
--- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
@@ -6,17 +6,15 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
+ FusedMoEParallelConfig,
FusedMoEQuantConfig,
- fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
compute_aligned_M,
deepgemm_moe_permute,
deepgemm_unpermute_and_reduce,
)
-from vllm.model_executor.layers.fused_moe.prepare_finalize import (
- MoEPrepareAndFinalizeNoEP,
-)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
@@ -26,9 +24,15 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8_packed_for_deepgemm,
silu_mul_per_token_group_quant_fp8_colmajor,
)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+ kFp8Dynamic128Sym,
+ kFp8Static128BlockSym,
+)
from vllm.utils.deep_gemm import (
DeepGemmQuantScaleFMT,
get_mk_alignment_for_contiguous_layout,
+ is_deep_gemm_supported,
m_grouped_fp8_gemm_nt_contiguous,
)
from vllm.utils.import_utils import has_deep_gemm
@@ -109,21 +113,42 @@ def _valid_deep_gemm(
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
- def __init__(self, quant_config: FusedMoEQuantConfig):
- super().__init__(quant_config)
+ def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig):
+ super().__init__(moe_config=moe_config, quant_config=quant_config)
assert quant_config.block_shape == get_mk_alignment_for_contiguous_layout()
assert quant_config.quant_dtype == torch.float8_e4m3fn
assert not quant_config.per_act_token_quant
assert not quant_config.per_out_ch_quant
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.Standard,
- mk.FusedMoEActivationFormat.Standard,
- )
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
+
+ @staticmethod
+ def _supports_current_device() -> bool:
+ return is_deep_gemm_supported()
+
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ return False
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ SUPPORTED_W_A = [
+ (kFp8Static128BlockSym, kFp8Dynamic128Sym),
+ ]
+ return (weight_key, activation_key) in SUPPORTED_W_A
+
+ @staticmethod
+ def _supports_activation(activation: str) -> bool:
+ return activation in ["silu"]
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ return True
def supports_chunking(self) -> bool:
return True
@@ -283,82 +308,3 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_map=expert_map,
output=output,
)
-
-
-def deep_gemm_moe_fp8(
- hidden_states: torch.Tensor,
- w1: torch.Tensor,
- w2: torch.Tensor,
- w1_scale: torch.Tensor,
- w2_scale: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- inplace: bool = False,
- activation: str = "silu",
- global_num_experts: int = -1,
- expert_map: torch.Tensor | None = None,
- a1_scale: torch.Tensor | None = None,
- a2_scale: torch.Tensor | None = None,
- apply_router_weight_on_input: bool = False,
-) -> torch.Tensor:
- """
- This function computes a a8w8-quantized Mixture of Experts (MoE) layer
- using two sets of quantized weights, w1_q and w2_q, and top-k gating
- mechanism. The matrix multiplications are implemented with DeepGemm
- grouped gemm.
-
- Parameters:
- - hidden_states (torch.Tensor): The input tensor to the MoE layer.
- Shape: [M, K]
- - w1 (torch.Tensor): The first set of fp8 quantized expert weights.
- Shape: [num_experts, K, 2N] (the weights are passed transposed)
- - w2 (torch.Tensor): The second set of fp8 quantized expert weights.
- Shape: [num_experts, N, K] (the weights are passed transposed)
- - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
- Shape: [num_experts] or [num_experts, 2N]
- - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
- Shape: [num_experts] or [num_experts, K]
- - topk_weights (torch.Tensor): The weights of each token->expert mapping.
- - topk_ids (torch.Tensor): The token->expert mapping for topk_weights.
- - inplace (bool): If True, perform the operation in-place.
- Defaults to False.
- - activation (str): The activation function to apply after the first
- MoE layer.
- - global_num_experts (int): The total number of experts in the global
- expert space.
- - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
- from the global expert space to the local expert space of the expert
- parallel shard.
- - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
- Shape: scalar or [M]
- - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
- quantize the intermediate result between the gemms.
- Shape: scalar or [M]
-
- Returns:
- - torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
- """
- quant_config = fp8_w8a8_moe_quant_config(
- w1_scale=w1_scale,
- w2_scale=w2_scale,
- a1_scale=a1_scale,
- a2_scale=a2_scale,
- block_shape=get_mk_alignment_for_contiguous_layout(),
- )
-
- fn = mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(),
- DeepGemmExperts(quant_config),
- )
- return fn(
- hidden_states,
- w1,
- w2,
- topk_weights,
- topk_ids,
- inplace=inplace,
- activation=activation,
- global_num_experts=global_num_experts,
- expert_map=expert_map,
- apply_router_weight_on_input=apply_router_weight_on_input,
- )
diff --git a/vllm/model_executor/layers/fused_moe/fallback.py b/vllm/model_executor/layers/fused_moe/fallback.py
index 455639214..07e5b8005 100644
--- a/vllm/model_executor/layers/fused_moe/fallback.py
+++ b/vllm/model_executor/layers/fused_moe/fallback.py
@@ -6,6 +6,8 @@ from abc import ABC, abstractmethod
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
+from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
@@ -16,18 +18,78 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
experts: mk.FusedMoEPermuteExpertsUnpermute,
fallback_experts: mk.FusedMoEPermuteExpertsUnpermute,
):
- super().__init__(experts.quant_config)
+ super().__init__(
+ moe_config=experts.moe_config, quant_config=experts.quant_config
+ )
self.fallback_experts = fallback_experts
self.experts = experts
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- assert (
- self.fallback_experts.activation_formats == self.experts.activation_formats
+ @staticmethod
+ def get_clses() -> tuple[
+ type[mk.FusedMoEPermuteExpertsUnpermute],
+ type[mk.FusedMoEPermuteExpertsUnpermute],
+ ]:
+ """
+ Get the cls for the experts and fallback experts.
+
+ Subclasses should implement this method, so that
+ we have a consistent way to call the _supports_*
+ class methods below.
+ """
+ raise NotImplementedError(
+ "Subclasses must return the cls for the experts and fallback experts."
)
- return self.fallback_experts.activation_formats
+
+ @classmethod
+ def activation_format(
+ cls: type["FallbackExperts"],
+ ) -> mk.FusedMoEActivationFormat:
+ experts_cls, fallback_cls = cls.get_clses()
+ assert experts_cls.activation_format() == fallback_cls.activation_format()
+ return experts_cls.activation_format()
+
+ @classmethod
+ def _supports_current_device(cls) -> bool:
+ experts_cls, fallback_cls = cls.get_clses()
+ return (
+ experts_cls._supports_current_device()
+ and fallback_cls._supports_current_device()
+ )
+
+ @classmethod
+ def _supports_no_act_and_mul(cls) -> bool:
+ experts_cls, fallback_cls = cls.get_clses()
+ return (
+ experts_cls._supports_no_act_and_mul()
+ and fallback_cls._supports_no_act_and_mul()
+ )
+
+ @classmethod
+ def _supports_quant_scheme(
+ cls,
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ experts_cls, fallback_cls = cls.get_clses()
+ return experts_cls._supports_quant_scheme(
+ weight_key, activation_key
+ ) and fallback_cls._supports_quant_scheme(weight_key, activation_key)
+
+ @classmethod
+ def _supports_activation(cls, activation: str) -> bool:
+ experts_cls, fallback_cls = cls.get_clses()
+ return experts_cls._supports_activation(
+ activation
+ ) and fallback_cls._supports_activation(activation)
+
+ @classmethod
+ def _supports_parallel_config(
+ cls, moe_parallel_config: FusedMoEParallelConfig
+ ) -> bool:
+ experts_cls, fallback_cls = cls.get_clses()
+ return experts_cls._supports_parallel_config(
+ moe_parallel_config
+ ) and fallback_cls._supports_parallel_config(moe_parallel_config)
def supports_chunking(self) -> bool:
assert (
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
index 1651f3530..036ee2a2e 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
@@ -6,13 +6,22 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
+ FusedMoEParallelConfig,
+ FusedMoEQuantConfig,
+)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+ kNvfp4Dynamic,
+ kNvfp4Static,
+)
+from vllm.platforms import current_platform
from vllm.utils.flashinfer import (
flashinfer_cutedsl_grouped_gemm_nt_masked,
- has_flashinfer_cutedsl_grouped_gemm_nt_masked,
scaled_fp4_grouped_quantize,
silu_and_mul_scaled_nvfp4_experts_quantize,
)
@@ -20,54 +29,54 @@ from vllm.utils.flashinfer import (
logger = init_logger(__name__)
-def is_valid_flashinfer_cutedsl_fused_moe(
- hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor
-) -> bool:
- """
- Check if the given problem size is supported by the FlashInfer CuteDSL MoE
- kernel.
- """
- if not has_flashinfer_cutedsl_grouped_gemm_nt_masked():
- logger.debug_once(
- "FlashInferCuteDSLExperts disabled: "
- "flashinfer_cutedsl_fused_moe not available."
- )
- return False
- # Data type checks
- if (
- w1.dtype != torch.uint8
- or w2.dtype != torch.uint8
- or hidden_states.dtype not in [torch.float32, torch.float16, torch.bfloat16]
- ):
- logger.debug_once(
- "FlashInferCuteDSLExperts disabled: w1/w2 must be torch.uint8 "
- f"(got w1={w1.dtype}, w2={w2.dtype}), hidden_states must be "
- f"float32, float16, or bfloat16 (got {hidden_states.dtype})."
- )
- return False
- return True
-
-
class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
- out_dtype: torch.dtype,
+ moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
+ max_num_tokens: int,
+ num_dispatchers: int,
):
- super().__init__(quant_config)
+ super().__init__(
+ moe_config=moe_config,
+ quant_config=quant_config,
+ max_num_tokens=max_num_tokens,
+ num_dispatchers=num_dispatchers,
+ )
assert quant_config.quant_dtype == "nvfp4", (
"Only nvfp4 quantization are currently supported."
)
- self.out_dtype = out_dtype
+ self.out_dtype = moe_config.in_dtype
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.BatchedExperts,
- mk.FusedMoEActivationFormat.BatchedExperts,
- )
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.BatchedExperts
+
+ @staticmethod
+ def _supports_current_device() -> bool:
+ return current_platform.is_device_capability_family(100)
+
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ return False
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ SUPPORTED_W_A = [
+ (kNvfp4Static, kNvfp4Dynamic),
+ ]
+ return (weight_key, activation_key) in SUPPORTED_W_A
+
+ @staticmethod
+ def _supports_activation(activation: str) -> bool:
+ return activation in ["silu"]
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ return True
def supports_expert_map(self) -> bool:
return False
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
index c08894b81..8d5985875 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
@@ -5,13 +5,22 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
- create_flashinfer_prepare_finalize,
+from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEParallelConfig,
+ FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+ kFp8Dynamic128Sym,
+ kFp8Static128BlockSym,
+ kFp8StaticTensorSym,
+ kNvfp4Dynamic,
+ kNvfp4Static,
+)
+from vllm.platforms import current_platform
from vllm.utils.flashinfer import (
flashinfer_cutlass_fused_moe,
has_flashinfer_cutlass_fused_moe,
@@ -50,40 +59,100 @@ def is_valid_flashinfer_cutlass_fused_moe(
class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
- out_dtype: torch.dtype,
+ moe_config: mk.FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
- ep_rank: int = 0,
- ep_size: int = 1,
- tp_rank: int = 0,
- tp_size: int = 1,
- use_dp: bool = False,
- use_deepseek_fp8_block_scale: bool = False,
):
- super().__init__(quant_config)
+ super().__init__(moe_config, quant_config)
assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), (
"Only nvfp4, fp8, bfloat16 and"
" float16 quantization are currently supported."
)
- self.ep_rank = ep_rank
- self.ep_size = ep_size
- self.tp_rank = tp_rank
- self.tp_size = tp_size
- self.out_dtype = out_dtype
- self.use_dp = use_dp
+ self.ep_rank = moe_config.moe_parallel_config.ep_rank
+ self.ep_size = moe_config.moe_parallel_config.ep_size
+ self.tp_rank = moe_config.moe_parallel_config.tp_rank
+ self.tp_size = moe_config.moe_parallel_config.tp_size
+ self.out_dtype = moe_config.in_dtype
+ self.use_dp = moe_config.moe_parallel_config.dp_size > 1
# Enables DeepSeek-style FP8 block-scale path:
# - pass per-block weight scales to the kernel
# - skip input activation quantization (kernel applies scaling)
- self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale
+ self.use_deepseek_fp8_block_scale = quant_config.is_block_quantized
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+ @staticmethod
+ def expects_unquantized_inputs(
+ moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig
+ ) -> bool:
+ # NVFP4 TP kernels and FP8 block-quantized kernels apply
+ # input quantization inside FusedMoEPermuteExpertsUnpermute.
return (
- mk.FusedMoEActivationFormat.Standard,
- mk.FusedMoEActivationFormat.Standard,
+ quant_config.use_nvfp4_w4a4
+ and not moe_config.moe_parallel_config.use_all2all_kernels
+ ) or (quant_config.use_fp8_w8a8 and quant_config.is_block_quantized)
+
+ @staticmethod
+ def _supports_current_device() -> bool:
+ return (
+ current_platform.is_cuda()
+ and (
+ current_platform.is_device_capability((9, 0))
+ or current_platform.is_device_capability_family(100)
+ )
+ and has_flashinfer_cutlass_fused_moe()
)
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ return True
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ # The following are supported by FlashInferExperts:
+ # * unquantized
+ # * fp8 static per-tensor on 9.0+
+ # * fp8 block on 9.0
+ # * nvfp4 on 10.0+
+
+ p = current_platform
+ scheme = (weight_key, activation_key)
+ return (
+ (
+ scheme
+ in [
+ (None, None),
+ (kFp8StaticTensorSym, kFp8StaticTensorSym),
+ ]
+ )
+ or (
+ (scheme == (kFp8Static128BlockSym, kFp8Dynamic128Sym))
+ and (p.is_device_capability((9, 0)))
+ )
+ or (
+ (scheme == (kNvfp4Static, kNvfp4Dynamic))
+ and (p.is_device_capability_family(100))
+ )
+ )
+
+ @staticmethod
+ def _supports_activation(activation: str) -> bool:
+ return activation in ["silu", "relu2_no_mul"]
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ # FLASHINFER_CUTLASS currently uses its down P/F, which does not
+ # work with SP. This will be removed in follow up after we get
+ # rid of the FlashInfer specific P/F function.
+ return (
+ moe_parallel_config.dp_size == 1
+ or moe_parallel_config.dp_size == moe_parallel_config.ep_size
+ )
+
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
+
def supports_expert_map(self) -> bool:
return False
@@ -231,85 +300,3 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
# No support for LoRA in flashinfer_cutlass_fused_moe.
# See TODOs in flashinfer functions runMoe and runMoeMinLantency.
raise NotImplementedError("LoRA is not supported for flashinfer_cutlass_moe")
-
-
-def flashinfer_cutlass_moe_fp4(
- hidden_states: torch.Tensor,
- w1: torch.Tensor,
- w2: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- quant_config: FusedMoEQuantConfig,
- inplace: bool = False,
- activation: str = "silu",
- global_num_experts: int = -1,
- expert_map: torch.Tensor | None = None,
- apply_router_weight_on_input: bool = False,
-) -> torch.Tensor:
- fused_experts = mk.FusedMoEModularKernel(
- create_flashinfer_prepare_finalize(
- use_dp=False, use_nvfp4=True, enable_alltoallv=False
- ),
- FlashInferExperts(
- out_dtype=hidden_states.dtype,
- quant_config=quant_config,
- use_dp=False,
- ),
- )
-
- return fused_experts(
- hidden_states=hidden_states,
- w1=w1,
- w2=w2,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- inplace=inplace,
- activation=activation,
- global_num_experts=global_num_experts,
- expert_map=expert_map,
- apply_router_weight_on_input=apply_router_weight_on_input,
- )
-
-
-def flashinfer_cutlass_moe(
- hidden_states: torch.Tensor,
- w1: torch.Tensor,
- w2: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- quant_config: FusedMoEQuantConfig,
- inplace: bool = False,
- activation: str = "silu",
- global_num_experts: int = -1,
- expert_map: torch.Tensor | None = None,
- apply_router_weight_on_input: bool = False,
- tp_rank: int = 0,
- tp_size: int = 1,
- ep_rank: int = 0,
- ep_size: int = 1,
- use_dp: bool = False,
-) -> torch.Tensor:
- fused_experts = mk.FusedMoEModularKernel(
- create_flashinfer_prepare_finalize(use_dp=use_dp),
- FlashInferExperts(
- out_dtype=hidden_states.dtype,
- quant_config=quant_config,
- tp_rank=tp_rank,
- tp_size=tp_size,
- ep_rank=ep_rank,
- ep_size=ep_size,
- ),
- )
-
- return fused_experts(
- hidden_states=hidden_states,
- w1=w1,
- w2=w2,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- inplace=inplace,
- activation=activation,
- global_num_experts=global_num_experts,
- expert_map=expert_map,
- apply_router_weight_on_input=apply_router_weight_on_input,
- )
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
index 3bb5a23ab..e5d8a7ace 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
@@ -3,7 +3,12 @@
import torch
-from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
+ FusedMoEParallelConfig,
+ RoutingMethodType,
+)
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
calculate_tile_tokens_dim,
@@ -11,8 +16,107 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+ kFp8Dynamic128Sym,
+ kFp8Static128BlockSym,
+ kFp8StaticTensorSym,
+)
+from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
+#
+# Methods used by the oracle for kernel selection.
+#
+
+
+def _supports_current_device() -> bool:
+ """Supports only Blackwell-family GPUs."""
+ p = current_platform
+ # Add check flashinfer trtllm is available
+ return p.is_cuda() and p.is_device_capability_family(100)
+
+
+def _supports_no_act_and_mul() -> bool:
+ """Does not support non-gated MoE (i.e. Nanotron-Mini)."""
+ return False
+
+
+def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+) -> bool:
+ """Supports Fp8 per-tensor and Fp8 block."""
+ SUPPORTED_W_A = [
+ (kFp8Static128BlockSym, kFp8Dynamic128Sym),
+ (kFp8StaticTensorSym, kFp8StaticTensorSym),
+ ]
+ return (weight_key, activation_key) in SUPPORTED_W_A
+
+
+def _supports_activation(activation: str) -> bool:
+ """Supports silu activation only."""
+ return activation in ["silu"]
+
+
+def _supports_routing_method(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ routing_method: RoutingMethodType,
+) -> bool:
+ """Monolithic kernels need to express router support."""
+ if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
+ # NOTE(rob): potentially allow others here. This is a conservative list.
+ return routing_method in [
+ RoutingMethodType.DeepSeekV3,
+ RoutingMethodType.Renormalize,
+ RoutingMethodType.RenormalizeNaive,
+ ]
+ elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
+ # NOTE(rob): kernel requires Llama4.
+ return routing_method == RoutingMethodType.Llama4
+
+ else:
+ raise ValueError("Unsupported quantization scheme.")
+
+
+def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ """Supports TRTLLM Kernel does not support EPLB."""
+ return not moe_parallel_config.enable_eplb
+
+
+def is_supported_config_trtllm(
+ moe_config: FusedMoEConfig,
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ activation_format: mk.FusedMoEActivationFormat,
+) -> tuple[bool, str | None]:
+ """
+ This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config
+ """
+
+ def _make_reason(reason: str) -> str:
+ return f"kernel does not support {reason}"
+
+ if not _supports_current_device():
+ return False, _make_reason("current device")
+ elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()):
+ return False, _make_reason("no act_and_mul MLP layer")
+ elif not _supports_activation(moe_config.activation):
+ return False, _make_reason(f"{moe_config.activation} activation")
+ elif not _supports_quant_scheme(weight_key, activation_key):
+ return False, _make_reason("quantization scheme")
+ elif not _supports_parallel_config(moe_config.moe_parallel_config):
+ return False, _make_reason("parallel config")
+ elif not _supports_routing_method(
+ weight_key, activation_key, moe_config.routing_method
+ ):
+ return False, _make_reason("routing method")
+ elif activation_format != mk.FusedMoEActivationFormat.Standard:
+ return False, _make_reason("activation format")
+
+ return True, None
+
def flashinfer_fused_moe_blockscale_fp8(
routing_logits: torch.Tensor,
diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
index fb9346439..8e45c0e41 100644
--- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
@@ -5,7 +5,11 @@
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
+ FusedMoEParallelConfig,
+ FusedMoEQuantConfig,
+)
from vllm.model_executor.layers.fused_moe.fused_moe import try_get_optimal_moe_config
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
@@ -17,7 +21,17 @@ from vllm.model_executor.layers.fused_moe.utils import (
normalize_batched_scales_shape,
normalize_scales_shape,
)
-from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+ group_broadcast,
+ kFp8Dynamic128Sym,
+ kFp8DynamicTensorSym,
+ kFp8DynamicTokenSym,
+ kFp8Static128BlockSym,
+ kFp8StaticChannelSym,
+ kFp8StaticTensorSym,
+)
+from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
@@ -633,25 +647,62 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
+ moe_config: FusedMoEConfig,
+ quant_config: FusedMoEQuantConfig,
max_num_tokens: int,
num_dispatchers: int,
- quant_config: FusedMoEQuantConfig,
):
- super().__init__(quant_config)
+ super().__init__(
+ moe_config=moe_config,
+ quant_config=quant_config,
+ max_num_tokens=max_num_tokens,
+ num_dispatchers=num_dispatchers,
+ )
assert not self.quant_config.use_int8_w8a8, "NYI"
assert not self.quant_config.use_int8_w8a16, "NYI"
assert not self.quant_config.use_int4_w4a16, "NYI"
assert self.quant_config.ocp_mx_scheme is None, "NYI"
- self.max_num_tokens = max_num_tokens
- self.num_dispatchers = num_dispatchers
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.BatchedExperts,
- mk.FusedMoEActivationFormat.BatchedExperts,
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.BatchedExperts
+
+ @staticmethod
+ def _supports_current_device() -> bool:
+ raise NotImplementedError(
+ "NaiveBatchedExperts is not yet used by an Oracle. "
+ "This method should not be called."
+ )
+
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ raise NotImplementedError(
+ "NaiveBatchedExperts is not yet used by an Oracle. "
+ "This method should not be called."
+ )
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ raise NotImplementedError(
+ "NaiveBatchedExperts is not yet used by an Oracle. "
+ "This method should not be called."
+ )
+
+ @staticmethod
+ def _supports_activation(activation: str) -> bool:
+ raise NotImplementedError(
+ "NaiveBatchedExperts is not yet used by an Oracle. "
+ "This method should not be called."
+ )
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ raise NotImplementedError(
+ "NaiveBatchedExperts is not yet used by an Oracle. "
+ "This method should not be called."
)
def supports_chunking(self) -> bool:
@@ -675,6 +726,8 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
+ assert self.num_dispatchers is not None
+ assert self.max_num_tokens is not None
num_dp = self.num_dispatchers
num_experts = local_num_experts
workspace13 = (num_experts, self.max_num_tokens * num_dp, K)
@@ -826,29 +879,69 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
+ moe_config: FusedMoEConfig,
+ quant_config: FusedMoEQuantConfig,
max_num_tokens: int,
num_dispatchers: int,
- quant_config: FusedMoEQuantConfig,
):
- super().__init__(quant_config)
+ super().__init__(
+ moe_config=moe_config,
+ quant_config=quant_config,
+ max_num_tokens=max_num_tokens,
+ num_dispatchers=num_dispatchers,
+ )
assert not self.quant_config.use_int8_w8a8, "NYI"
assert not self.quant_config.use_int8_w8a16, "NYI"
assert not self.quant_config.use_int4_w4a16, "NYI"
assert self.quant_config.ocp_mx_scheme is None, "NYI"
- assert max_num_tokens > 0
- assert num_dispatchers > 0
- self.max_num_tokens = max_num_tokens
- self.num_dispatchers = num_dispatchers
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.BatchedExperts,
- mk.FusedMoEActivationFormat.BatchedExperts,
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.BatchedExperts
+
+ @staticmethod
+ def _supports_current_device() -> bool:
+ return current_platform.is_cuda_alike()
+
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ return False
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ p = current_platform
+ device_supports_fp8 = (p.is_rocm() and p.rocm.on_gfx9()) or (
+ p.is_cuda() and p.has_device_capability((8, 9))
)
+ SUPPORTED_W_A_FP8 = [
+ (kFp8Static128BlockSym, kFp8Dynamic128Sym),
+ (kFp8StaticChannelSym, kFp8DynamicTokenSym),
+ (kFp8StaticTensorSym, kFp8StaticTensorSym),
+ (kFp8StaticTensorSym, kFp8DynamicTensorSym),
+ ]
+ return (weight_key, activation_key) == (None, None) or (
+ device_supports_fp8 and (weight_key, activation_key) in SUPPORTED_W_A_FP8
+ )
+
+ @staticmethod
+ def _supports_activation(activation: str) -> bool:
+ return activation in [
+ "silu",
+ "gelu",
+ "swigluoai",
+ "silu_no_mul",
+ "gelu_no_mul",
+ "relu2_no_mul",
+ ]
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ return True
+
def supports_chunking(self) -> bool:
return False
@@ -870,6 +963,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
+ assert self.num_dispatchers is not None
+ assert self.max_num_tokens is not None
num_dp = self.num_dispatchers
num_experts = local_num_experts
max_num_tokens = self.max_num_tokens
diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
index be9dddb87..603c8fc96 100644
--- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -8,7 +8,11 @@ import torch
import vllm._custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
+ FusedMoEParallelConfig,
+ FusedMoEQuantConfig,
+)
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
batched_moe_align_block_size,
moe_align_block_size,
@@ -27,6 +31,13 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_moe_intermediate_size,
marlin_quant_input,
)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+ kFp8Static128BlockSym,
+ kFp8StaticChannelSym,
+ kFp8StaticTensorSym,
+ kNvfp4Static,
+)
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
@@ -522,7 +533,10 @@ def batched_fused_marlin_moe(
class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
+ moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
+ max_num_tokens: int | None = None,
+ num_dispatchers: int | None = None,
w13_g_idx: torch.Tensor | None = None,
w2_g_idx: torch.Tensor | None = None,
w13_g_idx_sort_indices: torch.Tensor | None = None,
@@ -541,7 +555,51 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
self.w13_g_idx_sort_indices = w13_g_idx_sort_indices
self.w2_g_idx_sort_indices = w2_g_idx_sort_indices
self.is_k_full = is_k_full
- super().__init__(quant_config)
+ super().__init__(
+ moe_config=moe_config,
+ quant_config=quant_config,
+ max_num_tokens=max_num_tokens,
+ num_dispatchers=num_dispatchers,
+ )
+
+ @staticmethod
+ def _supports_current_device() -> bool:
+ p = current_platform
+ return p.is_cuda() and p.has_device_capability((7, 5))
+
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ return False
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ # TODO(rob): add int4, mxfp4, int8 as integrations
+ # are migrated to use the oracle one-by-one.
+ SUPPORTED_W = [
+ kFp8Static128BlockSym,
+ kFp8StaticChannelSym,
+ kFp8StaticTensorSym,
+ kNvfp4Static,
+ ]
+ return weight_key in SUPPORTED_W
+
+ @staticmethod
+ def _supports_activation(activation: str) -> bool:
+ return activation in [
+ "silu",
+ "gelu",
+ "swigluoai",
+ "silu_no_mul",
+ "gelu_no_mul",
+ "relu2_no_mul",
+ ]
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ return True
@property
def quant_type_id(self) -> int:
@@ -587,38 +645,15 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
class MarlinExperts(MarlinExpertsBase):
- def __init__(
- self,
- quant_config: FusedMoEQuantConfig,
- w13_g_idx: torch.Tensor | None = None,
- w2_g_idx: torch.Tensor | None = None,
- w13_g_idx_sort_indices: torch.Tensor | None = None,
- w2_g_idx_sort_indices: torch.Tensor | None = None,
- is_k_full: bool = True,
- ):
- super().__init__(
- quant_config,
- w13_g_idx,
- w2_g_idx,
- w13_g_idx_sort_indices,
- w2_g_idx_sort_indices,
- is_k_full,
- )
-
def supports_expert_map(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.Standard,
- mk.FusedMoEActivationFormat.Standard,
- )
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
def supports_chunking(self) -> bool:
return True
@@ -714,9 +749,10 @@ class MarlinExperts(MarlinExpertsBase):
class BatchedMarlinExperts(MarlinExpertsBase):
def __init__(
self,
+ moe_config: FusedMoEConfig,
+ quant_config: FusedMoEQuantConfig,
max_num_tokens: int,
num_dispatchers: int,
- quant_config: FusedMoEQuantConfig,
w13_g_idx: torch.Tensor | None = None,
w2_g_idx: torch.Tensor | None = None,
w13_g_idx_sort_indices: torch.Tensor | None = None,
@@ -724,15 +760,16 @@ class BatchedMarlinExperts(MarlinExpertsBase):
is_k_full: bool = True,
):
super().__init__(
- quant_config,
- w13_g_idx,
- w2_g_idx,
- w13_g_idx_sort_indices,
- w2_g_idx_sort_indices,
- is_k_full,
+ moe_config=moe_config,
+ quant_config=quant_config,
+ max_num_tokens=max_num_tokens,
+ num_dispatchers=num_dispatchers,
+ w13_g_idx=w13_g_idx,
+ w2_g_idx=w2_g_idx,
+ w13_g_idx_sort_indices=w13_g_idx_sort_indices,
+ w2_g_idx_sort_indices=w2_g_idx_sort_indices,
+ is_k_full=is_k_full,
)
- self.max_num_tokens = max_num_tokens
- self.num_dispatchers = num_dispatchers
def supports_expert_map(self) -> bool:
return True
@@ -740,14 +777,9 @@ class BatchedMarlinExperts(MarlinExpertsBase):
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceDelegate()
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.BatchedExperts,
- mk.FusedMoEActivationFormat.BatchedExperts,
- )
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.BatchedExperts
def supports_chunking(self) -> bool:
return False
@@ -763,9 +795,11 @@ class BatchedMarlinExperts(MarlinExpertsBase):
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
+ assert self.num_dispatchers is not None
+ assert self.max_num_tokens is not None
num_dispatchers = self.num_dispatchers
num_experts = local_num_experts
- max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens
+ max_num_tokens = self.max_num_tokens
workspace13 = (num_experts * max_num_tokens * num_dispatchers, max(K, N * 2))
workspace2 = (num_experts * max_num_tokens * num_dispatchers, N)
output = (num_experts, max_num_tokens * num_dispatchers, K)
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index d069d81f5..7e7d59fb9 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -19,13 +19,11 @@ from vllm.model_executor.layers.batch_invariant import (
)
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
+ FusedMoEConfig,
+ FusedMoEParallelConfig,
FusedMoEQuantConfig,
_get_config_dtype_str,
)
-from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
- _valid_deep_gemm,
- deep_gemm_moe_fp8,
-)
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size,
)
@@ -44,9 +42,16 @@ from vllm.model_executor.layers.fused_moe.utils import (
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+ kFp8Dynamic128Sym,
+ kFp8DynamicTokenSym,
+ kFp8Static128BlockSym,
+ kFp8StaticChannelSym,
+ kFp8StaticTensorSym,
+)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
-from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
logger = init_logger(__name__)
@@ -1534,66 +1539,36 @@ def fused_experts(
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
quant_config: FusedMoEQuantConfig | None = None,
- allow_deep_gemm: bool = False,
) -> torch.Tensor:
if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
- # For now, disable DeepGemm for small N (<= 512) until better
- # permute/unpermute ops are available.
- # However, on B200, we use DeepGemm for all cases because they only support
- # E8M0 scale, which means we requantize the weight and input to the specific
- # scale. Fallen back to cutlass or triton for some cases would cause
- # accuracy issue.
- if (
- allow_deep_gemm
- and quant_config.use_fp8_w8a8
- and (is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2))
- ):
- assert quant_config is not None
- return deep_gemm_moe_fp8(
- hidden_states=hidden_states,
- w1=w1,
- w2=w2,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- inplace=inplace,
- activation=activation,
- global_num_experts=global_num_experts,
- expert_map=expert_map,
- w1_scale=quant_config.w1_scale,
- w2_scale=quant_config.w2_scale,
- a1_scale=quant_config.a1_scale,
- a2_scale=quant_config.a2_scale,
- apply_router_weight_on_input=apply_router_weight_on_input,
- )
- else:
- return dispatch_fused_experts_func(inplace)(
- hidden_states=hidden_states,
- w1=w1,
- w2=w2,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- activation=activation,
- apply_router_weight_on_input=apply_router_weight_on_input,
- use_fp8_w8a8=quant_config.use_fp8_w8a8,
- use_int8_w8a8=quant_config.use_int8_w8a8,
- use_int8_w8a16=quant_config.use_int8_w8a16,
- use_int4_w4a16=quant_config.use_int4_w4a16,
- ocp_mx_scheme=quant_config.ocp_mx_scheme,
- per_channel_quant=quant_config.per_act_token_quant,
- global_num_experts=global_num_experts,
- expert_map=expert_map,
- w1_scale=quant_config.w1_scale,
- w2_scale=quant_config.w2_scale,
- w1_zp=quant_config.w1_zp,
- w2_zp=quant_config.w2_zp,
- a1_scale=quant_config.a1_scale,
- a2_scale=quant_config.a2_scale,
- block_shape=quant_config.block_shape,
- w1_bias=quant_config.w1_bias,
- w2_bias=quant_config.w2_bias,
- )
+ return dispatch_fused_experts_func(inplace)(
+ hidden_states=hidden_states,
+ w1=w1,
+ w2=w2,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ activation=activation,
+ apply_router_weight_on_input=apply_router_weight_on_input,
+ use_fp8_w8a8=quant_config.use_fp8_w8a8,
+ use_int8_w8a8=quant_config.use_int8_w8a8,
+ use_int8_w8a16=quant_config.use_int8_w8a16,
+ use_int4_w4a16=quant_config.use_int4_w4a16,
+ ocp_mx_scheme=quant_config.ocp_mx_scheme,
+ per_channel_quant=quant_config.per_act_token_quant,
+ global_num_experts=global_num_experts,
+ expert_map=expert_map,
+ w1_scale=quant_config.w1_scale,
+ w2_scale=quant_config.w2_scale,
+ w1_zp=quant_config.w1_zp,
+ w2_zp=quant_config.w2_zp,
+ a1_scale=quant_config.a1_scale,
+ a2_scale=quant_config.a2_scale,
+ block_shape=quant_config.block_shape,
+ w1_bias=quant_config.w1_bias,
+ w2_bias=quant_config.w2_bias,
+ )
def _get_config_quant_dtype(
@@ -1924,19 +1899,53 @@ def fused_experts_impl(
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
+ moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
):
- super().__init__(quant_config)
+ super().__init__(moe_config, quant_config)
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.Standard,
- mk.FusedMoEActivationFormat.Standard,
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
+
+ @staticmethod
+ def _supports_current_device() -> bool:
+ return current_platform.is_cuda_alike()
+
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ return False
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ p = current_platform
+ device_supports_fp8 = (p.is_rocm() and p.rocm.on_gfx9()) or (
+ p.is_cuda() and p.has_device_capability((8, 9))
)
+ if not device_supports_fp8:
+ return (weight_key, activation_key) == (None, None)
+
+ SUPPORTED_W_A = [
+ (None, None),
+ (kFp8Static128BlockSym, kFp8Dynamic128Sym),
+ (kFp8StaticChannelSym, kFp8DynamicTokenSym),
+ (kFp8StaticTensorSym, kFp8DynamicTokenSym),
+ (kFp8StaticTensorSym, kFp8StaticTensorSym),
+ ]
+ return (weight_key, activation_key) in SUPPORTED_W_A
+
+ @staticmethod
+ def _supports_activation(activation: str) -> bool:
+ return activation in ["silu", "gelu", "swigluoai"]
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ return True
+
def supports_chunking(self) -> bool:
return True
@@ -2111,11 +2120,43 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
class TritonWNA16Experts(TritonExperts):
- def __init__(
- self,
- quant_config: FusedMoEQuantConfig,
- ):
- super().__init__(quant_config)
+ @staticmethod
+ def _supports_current_device() -> bool:
+ raise NotImplementedError(
+ "TritonWNA16Experts is not yet used by an Oracle. "
+ "This method should not be called."
+ )
+
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ raise NotImplementedError(
+ "TritonWNA16Experts is not yet used by an Oracle. "
+ "This method should not be called."
+ )
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ raise NotImplementedError(
+ "TritonWNA16Experts is not yet used by an Oracle. "
+ "This method should not be called."
+ )
+
+ @staticmethod
+ def _supports_activation(activation: str) -> bool:
+ raise NotImplementedError(
+ "TritonWNA16Experts is not yet used by an Oracle. "
+ "This method should not be called."
+ )
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ raise NotImplementedError(
+ "TritonWNA16Experts is not yet used by an Oracle. "
+ "This method should not be called."
+ )
def apply(
self,
@@ -2254,10 +2295,12 @@ class TritonWNA16Experts(TritonExperts):
def modular_triton_fused_moe(
- quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None
+ moe_config: FusedMoEConfig,
+ quant_config: FusedMoEQuantConfig,
+ shared_experts: torch.nn.Module | None = None,
) -> mk.FusedMoEModularKernel:
return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
- TritonExperts(quant_config),
+ TritonExperts(moe_config, quant_config),
shared_experts,
)
diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
index c4bc1824a..b209820cd 100644
--- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
@@ -9,12 +9,16 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
+ FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+)
from vllm.triton_utils import tl, triton
from vllm.utils.import_utils import has_triton_kernels
@@ -241,8 +245,43 @@ def make_routing_data(
class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
- def __init__(self, quant_config: FusedMoEQuantConfig):
- super().__init__(quant_config)
+ @staticmethod
+ def _supports_current_device() -> bool:
+ raise NotImplementedError(
+ "OAITritonExperts is not yet used by an Oracle. "
+ "This method should not be called."
+ )
+
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ raise NotImplementedError(
+ "OAITritonExperts is not yet used by an Oracle. "
+ "This method should not be called."
+ )
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ raise NotImplementedError(
+ "OAITritonExperts is not yet used by an Oracle. "
+ "This method should not be called."
+ )
+
+ @staticmethod
+ def _supports_activation(activation: str) -> bool:
+ raise NotImplementedError(
+ "OAITritonExperts is not yet used by an Oracle. "
+ "This method should not be called."
+ )
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ raise NotImplementedError(
+ "OAITritonExperts is not yet used by an Oracle. "
+ "This method should not be called."
+ )
def supports_expert_map(self) -> bool:
return True
@@ -297,19 +336,9 @@ class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
class OAITritonExperts(BaseOAITritonExperts):
- def __init__(self, quant_config: FusedMoEQuantConfig):
- # TODO (varun) : Enable activation quantization
- assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
- super().__init__(quant_config)
-
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.Standard,
- mk.FusedMoEActivationFormat.Standard,
- )
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
def supports_chunking(self) -> bool:
return True
@@ -391,19 +420,9 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
One use case for it is to inject LoRA modules on the activation and moe_sum.
"""
- def __init__(self, quant_config: FusedMoEQuantConfig):
- # TODO (varun) : Enable activation quantization
- assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
- super().__init__(quant_config)
-
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.Standard,
- mk.FusedMoEActivationFormat.Standard,
- )
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
def supports_chunking(self) -> bool:
return True
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 4f6604530..2583a3a11 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -330,7 +330,6 @@ class FusedMoE(CustomOp):
is_sequence_parallel=False,
expert_mapping: list[tuple[str, str, int, str]] | None = None,
n_shared_experts: int | None = None,
- routing_method_type: RoutingMethodType | None = None,
router_logits_dtype: torch.dtype | None = None,
):
super().__init__()
@@ -519,10 +518,43 @@ class FusedMoE(CustomOp):
self.apply_router_weight_on_input = apply_router_weight_on_input
self.activation = activation
+ # TODO(bnell): in next PR move capture back to layer
+ capture: Callable[[torch.Tensor], None] | None = None
+ if (
+ self.vllm_config.model_config is not None
+ and self.vllm_config.model_config.enable_return_routed_experts
+ ):
+ # In dummy runs, the capturer is not initialized.
+ capturer = RoutedExpertsCapturer.get_instance()
+ if capturer is not None:
+ capture = lambda topk_ids: capturer.capture(self.layer_id, topk_ids)
+
+ self.router = create_fused_moe_router(
+ top_k=top_k,
+ global_num_experts=self.global_num_experts,
+ eplb_state=self.eplb_state,
+ renormalize=renormalize,
+ use_grouped_topk=use_grouped_topk,
+ num_expert_group=num_expert_group,
+ topk_group=topk_group,
+ custom_routing_function=custom_routing_function,
+ scoring_func=scoring_func,
+ routed_scaling_factor=routed_scaling_factor,
+ e_score_correction_bias=e_score_correction_bias,
+ num_fused_shared_experts=self.num_fused_shared_experts,
+ enable_eplb=enable_eplb,
+ # TODO(bnell): once we can construct the MK at init time, we
+ # can make this a value.
+ indices_type_getter=lambda: self.quant_method.topk_indices_dtype,
+ capture=capture,
+ )
+ self.routing_method_type: RoutingMethodType = self.router.routing_method_type
+
self.moe_config: FusedMoEConfig = FusedMoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
+ intermediate_size_per_partition=self.intermediate_size_per_partition,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=moe_in_dtype,
@@ -531,6 +563,9 @@ class FusedMoE(CustomOp):
has_bias=has_bias,
is_act_and_mul=is_act_and_mul,
is_lora_enabled=vllm_config.lora_config is not None,
+ activation=activation,
+ device=vllm_config.device_config.device,
+ routing_method=self.routing_method_type,
)
self.moe_config_use_flashinfer_cutlass_kernels = (
self.moe_config.use_flashinfer_cutlass_kernels
@@ -594,39 +629,6 @@ class FusedMoE(CustomOp):
self.batched_hidden_states: torch.Tensor | None = None
self.batched_router_logits: torch.Tensor | None = None
- # TODO(bnell): in next PR move capture back to layer
- capture: Callable[[torch.Tensor], None] | None = None
- if (
- self.vllm_config.model_config is not None
- and self.vllm_config.model_config.enable_return_routed_experts
- ):
- # In dummy runs, the capturer is not initialized.
- capturer = RoutedExpertsCapturer.get_instance()
- if capturer is not None:
- capture = lambda topk_ids: capturer.capture(self.layer_id, topk_ids)
-
- self.router = create_fused_moe_router(
- top_k=top_k,
- global_num_experts=self.global_num_experts,
- eplb_state=self.eplb_state,
- renormalize=renormalize,
- use_grouped_topk=use_grouped_topk,
- num_expert_group=num_expert_group,
- topk_group=topk_group,
- custom_routing_function=custom_routing_function,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- e_score_correction_bias=e_score_correction_bias,
- num_fused_shared_experts=self.num_fused_shared_experts,
- enable_eplb=enable_eplb,
- # TODO(bnell): once we can construct the MK at init time, we
- # can make this a value.
- indices_type_getter=lambda: self.quant_method.topk_indices_dtype,
- routing_method_type=routing_method_type,
- capture=capture,
- )
- self.routing_method_type: RoutingMethodType = self.router.routing_method_type
-
# Note: maybe_init_modular_kernel should only be called by
# prepare_communication_buffer_for_model.
# This is called after all weight loading and post-processing, so it
diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py
index 962d0fe78..6aff6401e 100644
--- a/vllm/model_executor/layers/fused_moe/modular_kernel.py
+++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -13,6 +13,7 @@ import vllm.envs as envs
from vllm.forward_context import get_forward_context, is_forward_context_available
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
@@ -22,6 +23,9 @@ from vllm.model_executor.layers.fused_moe.utils import (
count_expert_num_tokens,
disable_inplace,
)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+)
from vllm.utils.math_utils import cdiv
from vllm.v1.worker.ubatching import (
dbo_enabled,
@@ -374,18 +378,51 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
def __init__(
self,
+ moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
+ max_num_tokens: int | None = None,
+ num_dispatchers: int | None = None,
):
"""
+ moe_config: MoE layer configuration.
quant_config: Quantization parameters for this experts instance.
"""
- self.quant_config = quant_config
+ if self.activation_format() == FusedMoEActivationFormat.Standard and (
+ max_num_tokens is not None or num_dispatchers is not None
+ ):
+ raise ValueError(
+ "max_num_tokens and num_dispatchers should only be set for "
+ "BatchedExperts activation format."
+ )
+ elif self.activation_format() == FusedMoEActivationFormat.BatchedExperts and (
+ max_num_tokens is None or num_dispatchers is None
+ ):
+ raise ValueError(
+ "max_num_tokens and num_dispatchers must be set for "
+ "BatchedExperts activation format."
+ )
- @property
+ self.moe_config = moe_config
+ self.quant_config = quant_config
+ self.max_num_tokens = max_num_tokens
+ self.num_dispatchers = num_dispatchers
+
+ @staticmethod
+ def expects_unquantized_inputs(
+ moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig
+ ) -> bool:
+ """
+ Whether or not the PrepareFinalize should defer input quantization
+ in the prepare step. If True, then the Experts kernel will
+ execute the input quantization itself.
+
+ Sample subclasses that override are AITER and FlashInfer CUTLASS.
+ """
+ return False
+
+ @staticmethod
@abstractmethod
- def activation_formats(
- self,
- ) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]:
+ def activation_format() -> FusedMoEActivationFormat:
"""
A property which is a tuple of the input and output activation formats
for the 'apply' method.
@@ -435,6 +472,78 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
return E, M, N, K, topk
+ #
+ # Various helpers for registering support for various features.
+ # Used by the oracle to select a particular kernel for a deployment.
+ #
+
+ @staticmethod
+ def is_supported_config(
+ cls: type["FusedMoEPermuteExpertsUnpermute"],
+ moe_config: FusedMoEConfig,
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ activation_format: FusedMoEActivationFormat,
+ ) -> tuple[bool, str | None]:
+ def _make_reason(reason: str) -> str:
+ return f"kernel does not support {reason}"
+
+ if not cls._supports_current_device():
+ return False, _make_reason("current device")
+ elif not (moe_config.is_act_and_mul or cls._supports_no_act_and_mul()):
+ return False, _make_reason("no act_and_mul MLP layer")
+ elif not cls._supports_activation(moe_config.activation):
+ return False, _make_reason(f"{moe_config.activation} activation")
+ elif not cls._supports_quant_scheme(weight_key, activation_key):
+ return False, _make_reason("quantization scheme")
+ elif not cls._supports_parallel_config(moe_config.moe_parallel_config):
+ return False, _make_reason("parallel config")
+ elif activation_format != cls.activation_format():
+ return False, _make_reason(f"{activation_format.value} activation format")
+ return True, None
+
+ @staticmethod
+ @abstractmethod
+ def _supports_current_device() -> bool:
+ """
+ Whether the kernel supports the current device type
+ (compute cability and current platform).
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ @abstractmethod
+ def _supports_no_act_and_mul() -> bool:
+ """
+ Whether the kernel supports act_and_mul=False, i.e.
+ non-gated MoE models like Nemotron-Nano.
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ @abstractmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ raise NotImplementedError
+
+ @staticmethod
+ @abstractmethod
+ def _supports_activation(activation: str) -> bool:
+ """
+ Whether the kernel supports a particular act function.
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ @abstractmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ """
+ Whether the kernel supports deployment in expert parallel.
+ """
+ raise NotImplementedError
+
#
# Various helpers for accessing quantization parameters from the
# quant_config.
@@ -715,12 +824,12 @@ class FusedMoEModularKernel(torch.nn.Module):
self._post_init_setup()
assert (
- prepare_finalize.activation_format == fused_experts.activation_formats[0]
+ prepare_finalize.activation_format == fused_experts.activation_format()
), (
f"{prepare_finalize.__class__.__name__}."
f"{prepare_finalize.activation_format} == "
f"{fused_experts.__class__.__name__}."
- f"{fused_experts.activation_formats[0]}"
+ f"{fused_experts.activation_format()}"
)
def _post_init_setup(self):
diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py
index 6872b542f..cdf2d291b 100644
--- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py
+++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py
@@ -14,6 +14,12 @@ from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config,
fp8_w8a16_moe_quant_config,
)
+from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
+ is_supported_config_trtllm,
+)
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+ MoEPrepareAndFinalizeNoEP,
+)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
get_flashinfer_moe_backend,
@@ -26,133 +32,307 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_fp8_moe_layer_for_marlin,
)
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
- cutlass_group_gemm_supported,
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
)
-from vllm.platforms import current_platform
-from vllm.utils.deep_gemm import is_deep_gemm_supported
-from vllm.utils.flashinfer import has_flashinfer_moe
-from vllm.utils.import_utils import has_deep_gemm
logger = init_logger(__name__)
class Fp8MoeBackend(Enum):
- NONE = 0
- FLASHINFER_TRTLLM = 1
- FLASHINFER_CUTLASS = 2
- DEEPGEMM = 3
- MARLIN = 4
- TRITON = 5
- AITER = 6
- VLLM_CUTLASS = 7
+ NONE = "NONE"
+ FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM"
+ FLASHINFER_CUTLASS = "FLASHINFER_CUTLASS"
+ DEEPGEMM = "DEEPGEMM"
+ BATCHED_DEEPGEMM = "BATCHED_DEEPGEMM"
+ MARLIN = "MARLIN"
+ TRITON = "TRITON"
+ BATCHED_TRITON = "BATCHED_TRITON"
+ AITER = "AITER"
+ VLLM_CUTLASS = "VLLM_CUTLASS"
+ BATCHED_VLLM_CUTLASS = "BATCHED_VLLM_CUTLASS"
+
+
+def backend_to_kernel_cls(
+ backend: Fp8MoeBackend,
+) -> type[mk.FusedMoEPermuteExpertsUnpermute]:
+ if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
+ raise NotImplementedError
+
+ elif backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
+ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
+ FlashInferExperts,
+ )
+
+ return FlashInferExperts
+
+ elif backend == Fp8MoeBackend.DEEPGEMM:
+ from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
+ TritonOrDeepGemmExperts,
+ )
+
+ return TritonOrDeepGemmExperts
+
+ elif backend == Fp8MoeBackend.BATCHED_DEEPGEMM:
+ from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
+ BatchedDeepGemmExperts,
+ )
+
+ return BatchedDeepGemmExperts
+
+ elif backend == Fp8MoeBackend.MARLIN:
+ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+ MarlinExperts,
+ )
+
+ return MarlinExperts
+
+ elif backend == Fp8MoeBackend.TRITON:
+ from vllm.model_executor.layers.fused_moe.fused_moe import (
+ TritonExperts,
+ )
+
+ return TritonExperts
+
+ elif backend == Fp8MoeBackend.BATCHED_TRITON:
+ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
+ BatchedTritonExperts,
+ )
+
+ return BatchedTritonExperts
+
+ elif backend == Fp8MoeBackend.AITER:
+ from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+ AiterExperts,
+ )
+
+ return AiterExperts
+
+ elif backend == Fp8MoeBackend.VLLM_CUTLASS:
+ from vllm.model_executor.layers.fused_moe.triton_cutlass_moe import (
+ TritonOrCutlassExperts,
+ )
+
+ return TritonOrCutlassExperts
+
+ elif backend == Fp8MoeBackend.BATCHED_VLLM_CUTLASS:
+ from vllm.model_executor.layers.fused_moe.cutlass_moe import (
+ CutlassBatchedExpertsFp8,
+ )
+
+ return CutlassBatchedExpertsFp8
+
+ else:
+ raise ValueError(f"Unknown FP8 MoE backend: {backend.value}")
def select_fp8_moe_backend(
- block_quant: bool,
- tp_size: int,
- with_lora_support: bool,
- is_act_and_mul: bool,
+ config: FusedMoEConfig,
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
allow_vllm_cutlass: bool = False,
-) -> Fp8MoeBackend:
+) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]:
"""
Select the primary FP8 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
"""
- # TODO(rob): in a future PR, we will query each mk for
- # supported features and return the mk directly, just like
- # we do for the Attention Backend.
+ k_cls: type[mk.FusedMoEPermuteExpertsUnpermute] | None = None
- if with_lora_support:
- return Fp8MoeBackend.TRITON
+ if config.is_lora_enabled:
+ return Fp8MoeBackend.TRITON, backend_to_kernel_cls(Fp8MoeBackend.TRITON)
- def _make_log_backend(backend_name: str):
- return f"Using {backend_name} backend for FP8 MoE"
+ # NOTE: the kernels are selected in the following order.
+ AVAILABLE_BACKENDS = [
+ Fp8MoeBackend.AITER,
+ Fp8MoeBackend.FLASHINFER_TRTLLM,
+ Fp8MoeBackend.FLASHINFER_CUTLASS,
+ Fp8MoeBackend.DEEPGEMM,
+ Fp8MoeBackend.BATCHED_DEEPGEMM,
+ Fp8MoeBackend.VLLM_CUTLASS,
+ Fp8MoeBackend.BATCHED_VLLM_CUTLASS,
+ Fp8MoeBackend.TRITON,
+ Fp8MoeBackend.BATCHED_TRITON,
+ Fp8MoeBackend.MARLIN,
+ ]
- # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
- if (
- current_platform.is_cuda()
- and (
- current_platform.is_device_capability_family(100)
- or current_platform.is_device_capability(90)
- )
- and envs.VLLM_USE_FLASHINFER_MOE_FP8
- and has_flashinfer_moe()
- ):
- backend = get_flashinfer_moe_backend()
- if backend == FlashinferMoeBackend.TENSORRT_LLM:
- logger.info_once(_make_log_backend("FlashInfer TRTLLM"))
- if not is_act_and_mul:
- raise ValueError(
- "FlashInfer TRTLLM FP8 MoE backend only supports "
- "act_and_mul gate_up_project fusion. Please set "
- "VLLM_USE_FLASHINFER_MOE_FP8=throughput to use the "
- "FlashInfer CUTLASS backend instead."
- )
- return Fp8MoeBackend.FLASHINFER_TRTLLM
- else:
- if block_quant and current_platform.is_device_capability_family(100):
- raise ValueError(
- "FlashInfer FP8 MoE throughput backend does not "
- "support block quantization on SM100. Please use "
- "VLLM_FLASHINFER_MOE_BACKEND=latency to use the "
- "FlashInfer TRTLLM backend instead."
- )
- logger.info_once(_make_log_backend("FlashInfer CUTLASS"))
- return Fp8MoeBackend.FLASHINFER_CUTLASS
+ # NOTE(rob): We need to peak into the P/F selection to determine
+ # if we are using the batched or standard expert format, which
+ # if not ideal. Once we unify TP + DP/EP, we can select P/F first.
+ activation_format = (
+ mk.FusedMoEActivationFormat.BatchedExperts
+ if config.moe_parallel_config.use_batched_activation_format
+ else mk.FusedMoEActivationFormat.Standard
+ )
- # weight-only path for older GPUs without native FP8
- if (
- current_platform.is_cuda() and not current_platform.has_device_capability(89)
- ) or envs.VLLM_TEST_FORCE_FP8_MARLIN:
- logger.info_once(_make_log_backend("Marlin"), scope="local")
- return Fp8MoeBackend.MARLIN
-
- # Determine if we should use DeepGEMM with block-quantized weights:
- # - If explicitly set by user, respect their choice
- # - If not explicitly set (default), disable when TP size is >= 8
- moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM
- if not envs.is_set("VLLM_MOE_USE_DEEP_GEMM") and tp_size >= 8:
- moe_use_deep_gemm = False
- logger.info_once(
- "DeepGEMM MoE is disabled by default when TP size is >= 8. "
- "Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.",
- scope="local",
+ def _make_log_backend(backend: Fp8MoeBackend):
+ available_backend_strs = [b.value for b in AVAILABLE_BACKENDS]
+ return (
+ f"Using {backend.value} Fp8 MoE backend out "
+ f"of potential backends: {available_backend_strs}."
)
- use_deep_gemm = envs.VLLM_USE_DEEP_GEMM
- if not is_deep_gemm_supported():
- use_deep_gemm = False
- logger.info_once(
- "DeepGEMM is disabled because the platform does not support it.",
- scope="local",
- )
-
- if use_deep_gemm and moe_use_deep_gemm and block_quant and is_act_and_mul:
- if not has_deep_gemm():
- logger.warning_once(
- "DeepGEMM backend requested but not available.", scope="local"
+ def _make_log_unsupported(backend: Fp8MoeBackend, reason: str | None) -> str:
+ if reason:
+ return (
+ f"FP8 MoE backend {backend.value} does not support the "
+ f"deployment configuration since {reason}."
+ )
+ else:
+ return (
+ f"FP8 MoE backend '{backend.value}' does not support the "
+ "deployment configuration."
)
- elif is_deep_gemm_supported():
- logger.info_once(_make_log_backend("DeepGEMM"), scope="local")
- return Fp8MoeBackend.DEEPGEMM
- if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE:
- logger.info_once(_make_log_backend("ROCm AITER"), scope="local")
- return Fp8MoeBackend.AITER
+ def _return_or_raise(
+ backend: Fp8MoeBackend,
+ config: FusedMoEConfig,
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ activation_format: mk.FusedMoEActivationFormat,
+ ) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]:
+ k_cls = backend_to_kernel_cls(backend)
+ supported, reason = k_cls.is_supported_config(
+ k_cls, config, weight_key, activation_key, activation_format
+ )
+ if supported:
+ logger.info_once(_make_log_backend(backend), scope="local")
+ return backend, k_cls
+ raise ValueError(_make_log_unsupported(backend, reason))
- if (
- allow_vllm_cutlass
- and not block_quant
- and cutlass_group_gemm_supported()
- and is_act_and_mul
- ):
- logger.info_once(_make_log_backend("vLLM CUTLASS"), scope="local")
- return Fp8MoeBackend.VLLM_CUTLASS
+ # Handle explicit FlashInfer FP8 configuration.
+ if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP8"):
+ if not envs.VLLM_USE_FLASHINFER_MOE_FP8:
+ # If the user rejects FlashInfer remove those backends.
+ AVAILABLE_BACKENDS.remove(Fp8MoeBackend.FLASHINFER_TRTLLM)
+ AVAILABLE_BACKENDS.remove(Fp8MoeBackend.FLASHINFER_CUTLASS)
- # default to Triton
- logger.info_once(_make_log_backend("Triton"), scope="local")
- return Fp8MoeBackend.TRITON
+ elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
+ # If user is explicit about backend, validate it.
+ fi_backend = get_flashinfer_moe_backend()
+
+ if fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
+ backend = Fp8MoeBackend.FLASHINFER_TRTLLM
+ supported, reason = is_supported_config_trtllm(
+ config, weight_key, activation_key, activation_format
+ )
+ if supported:
+ logger.info_once(_make_log_backend(backend))
+ return backend, None
+ else:
+ raise ValueError(_make_log_unsupported(backend, reason))
+
+ elif fi_backend == FlashinferMoeBackend.CUTLASS:
+ backend = Fp8MoeBackend.FLASHINFER_CUTLASS
+ return _return_or_raise(
+ backend, config, weight_key, activation_key, activation_format
+ )
+
+ else:
+ assert fi_backend == FlashinferMoeBackend.CUTEDSL
+ raise ValueError("FlashInfer MaskedGEMM not supported for FP8")
+
+ else:
+ # If the user is not explicit about the backend, try both.
+ for backend in [
+ Fp8MoeBackend.FLASHINFER_TRTLLM,
+ Fp8MoeBackend.FLASHINFER_CUTLASS,
+ ]:
+ if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
+ k_cls = None
+ supported, reason = is_supported_config_trtllm(
+ config,
+ weight_key,
+ activation_key,
+ activation_format,
+ )
+ else:
+ k_cls = backend_to_kernel_cls(backend)
+ supported, reason = k_cls.is_supported_config(
+ k_cls,
+ config,
+ weight_key,
+ activation_key,
+ activation_format,
+ )
+
+ if supported:
+ logger.info_once(_make_log_backend(backend), scope="local")
+ return backend, k_cls
+ else:
+ logger.debug_once(
+ _make_log_unsupported(backend, reason), scope="local"
+ )
+
+ raise NotImplementedError(
+ "Found VLLM_USE_FLASHINFER_MOE_FP8=1, but no "
+ "FlashInfer FP8 MoE backend supports the configuration."
+ )
+
+ # Handle explicit DeepGEMM FP8 configuration.
+ if envs.is_set("VLLM_USE_DEEP_GEMM") or envs.is_set("VLLM_MOE_USE_DEEP_GEMM"):
+ if not envs.VLLM_USE_DEEP_GEMM or not envs.VLLM_MOE_USE_DEEP_GEMM:
+ AVAILABLE_BACKENDS.remove(Fp8MoeBackend.DEEPGEMM)
+ AVAILABLE_BACKENDS.remove(Fp8MoeBackend.BATCHED_DEEPGEMM)
+ else:
+ backend = (
+ Fp8MoeBackend.DEEPGEMM
+ if activation_format == mk.FusedMoEActivationFormat.Standard
+ else Fp8MoeBackend.BATCHED_DEEPGEMM
+ )
+ return _return_or_raise(
+ backend, config, weight_key, activation_key, activation_format
+ )
+
+ # Handle explicit MARLIN FP8 configuration.
+ if envs.VLLM_TEST_FORCE_FP8_MARLIN:
+ backend = Fp8MoeBackend.MARLIN
+ return _return_or_raise(
+ backend, config, weight_key, activation_key, activation_format
+ )
+
+ # Handle explicit AITER FP8 configuration.
+ if envs.is_set("VLLM_ROCM_USE_AITER") or envs.is_set("VLLM_ROCM_USE_AITER_MOE"):
+ if not envs.VLLM_ROCM_USE_AITER or not envs.VLLM_ROCM_USE_AITER_MOE:
+ AVAILABLE_BACKENDS.remove(Fp8MoeBackend.AITER)
+ else:
+ backend = Fp8MoeBackend.AITER
+ return _return_or_raise(
+ backend, config, weight_key, activation_key, activation_format
+ )
+
+ if not allow_vllm_cutlass:
+ AVAILABLE_BACKENDS.remove(Fp8MoeBackend.VLLM_CUTLASS)
+ AVAILABLE_BACKENDS.remove(Fp8MoeBackend.BATCHED_VLLM_CUTLASS)
+
+ # Select kernels in order of backend.
+ for backend in AVAILABLE_BACKENDS:
+ if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
+ k_cls = None
+ supported, reason = is_supported_config_trtllm(
+ config,
+ weight_key,
+ activation_key,
+ activation_format,
+ )
+ else:
+ k_cls = backend_to_kernel_cls(backend)
+ supported, reason = k_cls.is_supported_config(
+ k_cls,
+ config,
+ weight_key,
+ activation_key,
+ activation_format,
+ )
+
+ if supported:
+ logger.info_once(_make_log_backend(backend), scope="local")
+ return backend, k_cls
+ else:
+ logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
+
+ raise NotImplementedError(
+ "No FP8 MoE backend supports the deployment configuration."
+ )
def convert_to_fp8_moe_kernel_format(
@@ -166,7 +346,7 @@ def convert_to_fp8_moe_kernel_format(
w2_input_scale: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
block_quant = hasattr(layer, "weight_block_size")
- if fp8_backend == Fp8MoeBackend.DEEPGEMM:
+ if fp8_backend in [Fp8MoeBackend.DEEPGEMM, Fp8MoeBackend.BATCHED_DEEPGEMM]:
assert block_quant
w13, w2, w13_scale, w2_scale = prepare_fp8_moe_layer_for_deepgemm(
w13,
@@ -199,6 +379,14 @@ def convert_to_fp8_moe_kernel_format(
w2_input_scale=w2_input_scale,
is_trtllm=(fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM),
)
+ else:
+ if fp8_backend not in [
+ Fp8MoeBackend.TRITON,
+ Fp8MoeBackend.BATCHED_TRITON,
+ Fp8MoeBackend.VLLM_CUTLASS,
+ Fp8MoeBackend.BATCHED_VLLM_CUTLASS,
+ ]:
+ raise ValueError(f"Unsupported FP8 MoE backend: {fp8_backend.value}")
return w13, w2, w13_scale, w2_scale
@@ -210,6 +398,8 @@ def make_fp8_moe_quant_config(
a1_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
block_shape: list[int] | None = None,
+ per_act_token_quant: bool = False,
+ per_out_ch_quant: bool = False,
) -> FusedMoEQuantConfig | None:
"""
Create FusedMoEQuantConfig for the specifed FP8 Backend.
@@ -262,102 +452,76 @@ def make_fp8_moe_quant_config(
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
+ per_act_token_quant=per_act_token_quant,
+ per_out_ch_quant=per_out_ch_quant,
)
+def make_fp8_moe_kernel_for_mkm(
+ moe_config: FusedMoEConfig,
+ quant_config: FusedMoEQuantConfig,
+ experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
+ prepare_finalize: mk.FusedMoEPrepareAndFinalize,
+) -> mk.FusedMoEPermuteExpertsUnpermute:
+ if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
+ max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
+ assert max_num_tokens_per_rank is not None
+ experts = experts_cls(
+ moe_config=moe_config,
+ quant_config=quant_config,
+ max_num_tokens=max_num_tokens_per_rank,
+ num_dispatchers=prepare_finalize.num_dispatchers(),
+ )
+ else:
+ experts = experts_cls(
+ moe_config=moe_config,
+ quant_config=quant_config,
+ )
+
+ logger.debug_once("Using %s", experts.__class__.__name__)
+ return experts
+
+
def make_fp8_moe_kernel(
- layer: torch.nn.Module,
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
fp8_backend: Fp8MoeBackend,
+ experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
) -> tuple[mk.FusedMoEModularKernel, bool]:
- # Delayed import is required since the oracle is imported
- # by CPU backends which cannot import all of these experts.
- # TODO: update the experts to make this not happen.
- from vllm.model_executor.layers.fused_moe.prepare_finalize import (
- MoEPrepareAndFinalizeNoEP,
+ # TODO(rob): unify after we merge tp and dp/ep.
+ if (
+ moe_config.moe_parallel_config.use_all2all_kernels
+ and moe_config.moe_parallel_config.all2all_backend
+ not in ["allgather_reducescatter", "naive"]
+ ):
+ raise ValueError(
+ "Fp8 Oracle should not create non-naive A2A P/F. "
+ "This should happen via the ModularKernelMethod."
+ )
+
+ # Create Prepare/Finalize.
+ prepare_finalize = MoEPrepareAndFinalizeNoEP(
+ defer_input_quant=experts_cls.expects_unquantized_inputs(
+ moe_config, moe_quant_config
+ ),
)
- # NOTE(rob): this is a WIP refactor. We are first migrating
- # all of the kernels in the TP case to use mk. Once this is
- # done, then we will initialzie the TP case and DP/EP case
- # via the same code path (i.e. via maybe_init_modular_kernel).
- # NOTE(rob): in progress migrating all into this format.
- use_inplace = True
- if fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
- from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
- FlashInferExperts,
- )
+ # Create Experts.
+ experts = experts_cls(
+ moe_config=moe_config,
+ quant_config=moe_quant_config,
+ )
- kernel = mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(
- defer_input_quant=moe_quant_config.is_block_quantized
- ),
- FlashInferExperts(
- out_dtype=layer.orig_dtype,
- quant_config=moe_quant_config,
- ep_rank=moe_config.ep_rank,
- ep_size=moe_config.ep_size,
- tp_rank=moe_config.tp_rank,
- tp_size=moe_config.tp_size,
- use_dp=(moe_config.dp_size > 1),
- use_deepseek_fp8_block_scale=moe_quant_config.is_block_quantized,
- ),
- )
- use_inplace = False
+ # NOTE(rob): we only want the mk to control the shared_expert
+ # if using all2all (for SBO). bnell is making this explict in
+ # the new MoE runner class.
+ kernel = mk.FusedMoEModularKernel(
+ prepare_finalize,
+ experts,
+ shared_experts=None,
+ moe_parallel_config=moe_config.moe_parallel_config,
+ )
- elif fp8_backend == Fp8MoeBackend.AITER:
- from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
- AiterExperts,
- )
-
- kernel = mk.FusedMoEModularKernel(
- # TODO: make defer_input_quant an attr of the AiterExperts
- MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
- AiterExperts(quant_config=moe_quant_config),
- )
- elif fp8_backend == Fp8MoeBackend.MARLIN:
- from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
- MarlinExperts,
- )
-
- kernel = mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(),
- MarlinExperts(quant_config=moe_quant_config),
- )
- elif fp8_backend == Fp8MoeBackend.VLLM_CUTLASS:
- from vllm.model_executor.layers.fused_moe.triton_cutlass_moe import (
- TritonOrCutlassExperts,
- )
-
- kernel = mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(),
- TritonOrCutlassExperts(
- out_dtype=moe_config.in_dtype,
- e=layer.local_num_experts,
- n=layer.intermediate_size_per_partition,
- k=layer.hidden_size,
- device=layer.w13_weight.device,
- quant_config=moe_quant_config,
- ),
- )
- elif fp8_backend == Fp8MoeBackend.DEEPGEMM:
- from vllm.model_executor.layers.fused_moe import (
- TritonOrDeepGemmExperts,
- )
-
- kernel = mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(),
- TritonOrDeepGemmExperts(quant_config=moe_quant_config),
- )
- else:
- from vllm.model_executor.layers.fused_moe.fused_moe import (
- TritonExperts,
- )
-
- assert fp8_backend == Fp8MoeBackend.TRITON
- kernel = mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(),
- TritonExperts(quant_config=moe_quant_config),
- )
- return kernel, use_inplace
+ # TODO(rob): update inplace logic to be part of the kernel.
+ inplace = fp8_backend != Fp8MoeBackend.FLASHINFER_CUTLASS
+ return kernel, inplace
diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
index f2d69cf09..897631e23 100644
--- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
+++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
@@ -14,21 +14,11 @@ from vllm.model_executor.layers.fused_moe.config import (
nvfp4_moe_quant_config,
nvfp4_w4a16_moe_quant_config,
)
-from vllm.model_executor.layers.fused_moe.cutlass_moe import (
- CutlassExpertsFp4,
-)
-from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
- FlashInferExperts,
-)
-from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
- MarlinExperts,
-)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
- is_flashinfer_fp4_cutedsl_moe_available,
- is_flashinfer_fp4_cutlass_moe_available,
+ is_supported_config_trtllm,
prepare_nvfp4_moe_layer_for_fi_or_cutlass,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
@@ -36,27 +26,26 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
get_flashinfer_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
- is_fp4_marlin_supported,
prepare_nvfp4_moe_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
- cutlass_fp4_supported,
+ QuantKey,
)
logger = init_logger(__name__)
class NvFp4MoeBackend(Enum):
- FLASHINFER_CUTLASS = "FlashInfer CUTLASS"
- FLASHINFER_TRTLLM = "FlashInfer TRTLLM"
- FLASHINFER_CUTEDSL = "FlashInfer CUTEDSL"
- VLLM_CUTLASS = "vLLM CUTASS"
- MARLIN = "vLLM MARLIN"
+ FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM"
+ FLASHINFER_CUTLASS = "FLASHINFER_CUTLASS"
+ FLASHINFER_CUTEDSL = "FLASHINFER_CUTEDSL"
+ VLLM_CUTLASS = "VLLM_CUTLASS"
+ MARLIN = "MARLIN"
FLASHINFER_NVFP4_MOE_BACKENDS = [
- NvFp4MoeBackend.FLASHINFER_CUTLASS,
NvFp4MoeBackend.FLASHINFER_TRTLLM,
+ NvFp4MoeBackend.FLASHINFER_CUTLASS,
NvFp4MoeBackend.FLASHINFER_CUTEDSL,
]
@@ -72,44 +61,208 @@ def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool:
# of all experts in Expert Parallel Mode when all experts are not
# on the same rank.
- return backend in [
- NvFp4MoeBackend.FLASHINFER_CUTLASS,
+ return backend in FLASHINFER_NVFP4_MOE_BACKENDS
+
+
+def backend_to_kernel_cls(
+ backend: NvFp4MoeBackend,
+) -> type[mk.FusedMoEPermuteExpertsUnpermute]:
+ if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
+ raise NotImplementedError(
+ "FLASHINFER_TRTLLM doesn't support Modular Kernel Interface"
+ )
+
+ elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
+ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
+ FlashInferExperts,
+ )
+
+ return FlashInferExperts
+
+ elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL:
+ from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
+ FlashInferCuteDSLExperts,
+ )
+
+ return FlashInferCuteDSLExperts
+
+ elif backend == NvFp4MoeBackend.VLLM_CUTLASS:
+ from vllm.model_executor.layers.fused_moe.cutlass_moe import (
+ CutlassExpertsFp4,
+ )
+
+ return CutlassExpertsFp4
+
+ elif backend == NvFp4MoeBackend.MARLIN:
+ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+ MarlinExperts,
+ )
+
+ return MarlinExperts
+ else:
+ raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}")
+
+
+def select_nvfp4_moe_backend(
+ config: FusedMoEConfig,
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]:
+ """
+ Select the primary NvFP4 MoE backend
+ Note: Shape-specific fallbacks may still occur at runtime.
+ """
+
+ # NOTE: the kernels are selected in the following order.
+ AVAILABLE_BACKENDS = [
NvFp4MoeBackend.FLASHINFER_TRTLLM,
+ NvFp4MoeBackend.FLASHINFER_CUTEDSL,
+ NvFp4MoeBackend.FLASHINFER_CUTLASS,
+ NvFp4MoeBackend.VLLM_CUTLASS,
+ NvFp4MoeBackend.MARLIN,
]
+ # NOTE(rob): this is kind of a hack. We need to peak into
+ # the prepare-finalize selection to determine if we are using
+ # the batched or standard expert format.
+ use_batched = (
+ config.moe_parallel_config.use_deepep_ll_kernels
+ or config.moe_parallel_config.use_pplx_kernels
+ )
+ activation_format = (
+ mk.FusedMoEActivationFormat.BatchedExperts
+ if use_batched
+ else mk.FusedMoEActivationFormat.Standard
+ )
-def select_nvfp4_moe_backend() -> NvFp4MoeBackend:
def _make_log_backend(backend: NvFp4MoeBackend):
- return f"Using {backend.value} backend for NvFp4 MoE"
-
- if cutlass_fp4_supported() and not envs.VLLM_TEST_FORCE_FP8_MARLIN:
- allow_flashinfer = (
- is_flashinfer_fp4_cutlass_moe_available()
- or is_flashinfer_fp4_cutedsl_moe_available()
+ available_backend_strs = [b.value for b in AVAILABLE_BACKENDS]
+ return (
+ f"Using '{backend.value}' NvFp4 MoE backend out "
+ f"of potential backends: {available_backend_strs}."
)
- if allow_flashinfer and envs.VLLM_USE_FLASHINFER_MOE_FP4:
- backend = fi_2_vllm_backend_map[get_flashinfer_moe_backend()]
+
+ def _make_log_unsupported(backend: NvFp4MoeBackend, reason: str | None) -> str:
+ if reason:
+ return (
+ f"NvFp4 MoE backend '{backend.value}' does not support the "
+ f"deployment configuration since {reason}."
+ )
else:
- backend = NvFp4MoeBackend.VLLM_CUTLASS
- elif is_fp4_marlin_supported():
- backend = NvFp4MoeBackend.MARLIN
- else:
- raise ValueError("No NvFp4 kernel backend available for NvFp4 MoE.")
+ return (
+ f"NvFp4 MoE backend '{backend.value}' does not support the "
+ "deployment configuration."
+ )
- # Log warning if FI backend requested but not available.
- if (
- backend not in FLASHINFER_NVFP4_MOE_BACKENDS
- and envs.VLLM_USE_FLASHINFER_MOE_FP4
- ):
- logger.warning_once(
- "Requested FlashInfer backend for NvFp4 MoE, but it's not available. "
- "Falling back to %s for NvFp4 MoE",
- backend.value,
- scope="local",
+ def _return_or_raise(
+ backend: NvFp4MoeBackend,
+ config: FusedMoEConfig,
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ activation_format: mk.FusedMoEActivationFormat,
+ ) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]:
+ k_cls = backend_to_kernel_cls(backend)
+ supported, reason = k_cls.is_supported_config(
+ k_cls, config, weight_key, activation_key, activation_format
)
- else:
- logger.info_once(_make_log_backend(backend), scope="local")
- return backend
+ if supported:
+ logger.info_once(_make_log_backend(backend))
+ return backend, k_cls
+ raise ValueError(_make_log_unsupported(backend, reason))
+
+ if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP4"):
+ if not envs.VLLM_USE_FLASHINFER_MOE_FP4:
+ # If the user rejects FlashInfer remove those backends.
+ for b in FLASHINFER_NVFP4_MOE_BACKENDS:
+ AVAILABLE_BACKENDS.remove(b)
+
+ elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
+ # If user is explicit about backend, validate it.
+ fi_backend = get_flashinfer_moe_backend()
+
+ if fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
+ backend = NvFp4MoeBackend.FLASHINFER_TRTLLM
+ supported, reason = is_supported_config_trtllm(
+ config, weight_key, activation_key, activation_format
+ )
+ if supported:
+ logger.info_once(_make_log_backend(backend))
+ return backend, None
+ else:
+ raise ValueError(_make_log_unsupported(backend, reason))
+ else:
+ backend = fi_2_vllm_backend_map[fi_backend]
+ return _return_or_raise(
+ backend, config, weight_key, activation_key, activation_format
+ )
+ else:
+ # If the user is not explicit about the backend, try each.
+ for backend in FLASHINFER_NVFP4_MOE_BACKENDS:
+ if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
+ k_cls = None
+ supported, reason = is_supported_config_trtllm(
+ config,
+ weight_key,
+ activation_key,
+ activation_format,
+ )
+ else:
+ k_cls = backend_to_kernel_cls(backend)
+ supported, reason = k_cls.is_supported_config(
+ k_cls,
+ config,
+ weight_key,
+ activation_key,
+ activation_format,
+ )
+ if supported:
+ logger.info_once(_make_log_backend(backend), scope="local")
+ return backend, None
+ else:
+ logger.debug_once(
+ _make_log_unsupported(backend, reason), scope="local"
+ )
+
+ raise NotImplementedError(
+ "Found VLLM_USE_FLASHINFER_MOE_FP4=1, but no "
+ "FlashInfer NVFP4 MoE backend supports the configuration."
+ )
+
+ if envs.VLLM_TEST_FORCE_FP8_MARLIN:
+ backend = NvFp4MoeBackend.MARLIN
+ return _return_or_raise(
+ backend, config, weight_key, activation_key, activation_format
+ )
+
+ # Select kernels in order of backend.
+ for backend in AVAILABLE_BACKENDS:
+ if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
+ k_cls = None # type: ignore[assignment]
+ supported, reason = is_supported_config_trtllm(
+ config,
+ weight_key,
+ activation_key,
+ activation_format,
+ )
+ else:
+ k_cls = backend_to_kernel_cls(backend)
+ supported, reason = k_cls.is_supported_config(
+ k_cls,
+ config,
+ weight_key,
+ activation_key,
+ activation_format,
+ )
+
+ if supported:
+ logger.info_once(_make_log_backend(backend), scope="local")
+ return backend, k_cls
+ else:
+ logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
+
+ raise NotImplementedError(
+ "No NvFp4 MoE backend supports the deployment configuration."
+ )
def convert_to_nvfp4_moe_kernel_format(
@@ -238,55 +391,69 @@ def make_nvfp4_moe_quant_config(
)
-def make_nvfp4_moe_kernel(
- backend: NvFp4MoeBackend,
- quant_config: FusedMoEQuantConfig,
+def make_nvfp4_moe_kernel_for_mkm(
moe_config: FusedMoEConfig,
-) -> mk.FusedMoEModularKernel | None:
- assert moe_config.dp_size == 1
-
- UNSUPPORTED_BACKENDS = [
- # TRTLLM does not use the modular kernl abstraction.
- NvFp4MoeBackend.FLASHINFER_TRTLLM,
- # CUTEDSL is used with BATCHED (masked) format only.
- # TODO: add here once we support dp/ep via the oracle.
- NvFp4MoeBackend.FLASHINFER_CUTEDSL,
- ]
-
- if backend in UNSUPPORTED_BACKENDS:
- return None
-
- elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
- return mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
- FlashInferExperts(
- out_dtype=moe_config.in_dtype,
- quant_config=quant_config,
- ep_rank=moe_config.ep_rank,
- ep_size=moe_config.ep_size,
- tp_rank=moe_config.tp_rank,
- tp_size=moe_config.tp_size,
- use_dp=False,
- use_deepseek_fp8_block_scale=False,
- ),
+ quant_config: FusedMoEQuantConfig,
+ experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
+ prepare_finalize: mk.FusedMoEPrepareAndFinalize,
+) -> mk.FusedMoEPermuteExpertsUnpermute:
+ if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
+ max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
+ assert max_num_tokens_per_rank is not None
+ experts = experts_cls(
+ moe_config=moe_config,
+ quant_config=quant_config,
+ max_num_tokens=max_num_tokens_per_rank,
+ num_dispatchers=prepare_finalize.num_dispatchers(),
)
-
- elif backend == NvFp4MoeBackend.VLLM_CUTLASS:
- return mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
- CutlassExpertsFp4(
- out_dtype=moe_config.in_dtype,
- # TODO(rob): see what impact this has on expert map?
- max_experts_per_worker=moe_config.num_experts,
- quant_config=quant_config,
- ),
- )
-
- elif backend == NvFp4MoeBackend.MARLIN:
- return mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(),
- MarlinExperts(quant_config=quant_config),
- )
-
else:
- raise ValueError(f"Unknown NvFp4 MoE backend: {backend}")
+ experts = experts_cls(
+ moe_config=moe_config,
+ quant_config=quant_config,
+ )
+
+ logger.debug_once("Using %s", experts.__class__.__name__)
+ return experts
+
+
+def make_nvfp4_moe_kernel(
+ moe_quant_config: FusedMoEQuantConfig,
+ moe_config: FusedMoEConfig,
+ experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
+) -> mk.FusedMoEModularKernel:
+ # TODO(rob): unify after we merge tp and dp/ep.
+ if (
+ moe_config.moe_parallel_config.use_all2all_kernels
+ and moe_config.moe_parallel_config.all2all_backend
+ not in ["allgather_reducescatter", "naive"]
+ ):
+ raise ValueError(
+ "NvFP4 Oracle should not create non-naive A2A P/F. "
+ "This should happen via the ModularKernelMethod."
+ )
+
+ # Create Prepare/Finalize.
+ prepare_finalize = MoEPrepareAndFinalizeNoEP(
+ defer_input_quant=experts_cls.expects_unquantized_inputs(
+ moe_config, moe_quant_config
+ ),
+ )
+
+ # Create Experts.
+ experts = experts_cls(
+ moe_config=moe_config,
+ quant_config=moe_quant_config,
+ )
+
+ # NOTE(rob): we only want the mk to control the shared_expert
+ # if using all2all (for SBO). bnell is making this explict in
+ # the new MoE runner class.
+ kernel = mk.FusedMoEModularKernel(
+ prepare_finalize,
+ experts,
+ shared_experts=None,
+ moe_parallel_config=moe_config.moe_parallel_config,
+ )
+
+ # TODO(rob): update inplace logic to be part of the kernel.
+ return kernel
diff --git a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py
index 138bfac28..14c3f84e6 100644
--- a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py
+++ b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py
@@ -123,7 +123,6 @@ def convert_to_unquantized_kernel_format(
def make_unquantized_moe_kernel(
- layer: torch.nn.Module,
backend: UnquantizedMoeBackend,
quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
@@ -141,12 +140,8 @@ def make_unquantized_moe_kernel(
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
FlashInferExperts(
- out_dtype=layer.params_dtype,
+ moe_config=moe_config,
quant_config=quant_config,
- tp_rank=moe_config.moe_parallel_config.tp_rank,
- tp_size=moe_config.moe_parallel_config.tp_size,
- ep_rank=moe_config.moe_parallel_config.ep_rank,
- ep_size=moe_config.moe_parallel_config.ep_size,
),
)
use_inplace = False
@@ -157,13 +152,19 @@ def make_unquantized_moe_kernel(
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
- AiterExperts(quant_config),
+ AiterExperts(
+ moe_config=moe_config,
+ quant_config=quant_config,
+ ),
)
elif backend == UnquantizedMoeBackend.TRITON:
from vllm.model_executor.layers.fused_moe import TritonExperts
kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
- TritonExperts(quant_config),
+ TritonExperts(
+ moe_config=moe_config,
+ quant_config=quant_config,
+ ),
)
return kernel, use_inplace
diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index b78794c6b..6d8e5ca79 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -9,11 +9,21 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
+ FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+ kFp8Dynamic128Sym,
+ kFp8DynamicTensorSym,
+ kFp8DynamicTokenSym,
+ kFp8Static128BlockSym,
+ kFp8StaticChannelSym,
+ kFp8StaticTensorSym,
+)
class QuantMethod(IntEnum):
@@ -269,17 +279,49 @@ def rocm_aiter_fused_experts(
class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
- def __init__(self, quant_config):
- super().__init__(quant_config)
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.Standard,
- mk.FusedMoEActivationFormat.Standard,
- )
+ @staticmethod
+ def expects_unquantized_inputs(
+ fused_moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig
+ ) -> bool:
+ # AITER fused MoE kernels handle input quantization internally.
+ return True
+
+ @staticmethod
+ def _supports_current_device() -> bool:
+ return rocm_aiter_ops.is_fused_moe_enabled()
+
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ return False
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ # TODO(rob): AITER also supports MXFP4, which is not
+ # yet supported via an Oracle. Once it is, we will add
+ # MXFP4 to this list.
+ SUPPORTED_W_A = [
+ (None, None),
+ (kFp8Static128BlockSym, kFp8Dynamic128Sym),
+ (kFp8StaticTensorSym, kFp8StaticTensorSym),
+ (kFp8StaticTensorSym, kFp8DynamicTensorSym),
+ (kFp8StaticChannelSym, kFp8DynamicTokenSym),
+ ]
+ return (weight_key, activation_key) in SUPPORTED_W_A
+
+ @staticmethod
+ def _supports_activation(activation: str) -> bool:
+ return activation in ["silu", "gelu"]
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ return True
def supports_expert_map(self):
return True
diff --git a/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py b/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py
index a19dfb62b..17c59352f 100644
--- a/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py
+++ b/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py
@@ -34,6 +34,11 @@ class CustomRoutingRouter(BaseRouter):
@property
def routing_method_type(self) -> RoutingMethodType:
+ from vllm.model_executor.models.llama4 import Llama4MoE
+
+ # NOTE: FLASHINFER_TRTLLM support the Llama4 router.
+ if self.custom_routing_function == Llama4MoE.custom_routing_function:
+ return RoutingMethodType.Llama4
return RoutingMethodType.Custom
def _compute_routing(
diff --git a/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py b/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py
index e5b6de02f..1c908a2b4 100644
--- a/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py
+++ b/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py
@@ -261,7 +261,6 @@ class GroupedTopKRouter(BaseRouter):
num_fused_shared_experts: int = 0,
enable_eplb: bool = False,
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
- routing_method_type: RoutingMethodType | None = None,
):
super().__init__(
top_k=top_k,
@@ -278,13 +277,12 @@ class GroupedTopKRouter(BaseRouter):
self.e_score_correction_bias = e_score_correction_bias
self.num_fused_shared_experts = num_fused_shared_experts
- # Determine routing method type
- if routing_method_type is not None:
- self._routing_method_type = routing_method_type
- elif scoring_func == "sigmoid":
+ if scoring_func == "sigmoid":
self._routing_method_type = RoutingMethodType.DeepSeekV3
else:
- self._routing_method_type = RoutingMethodType.TopK
+ # NOTE: this prohibits the FLASHINFER_TRTLLM kernels from
+ # being selected, since they only support DeepSeek-style.
+ self._routing_method_type = RoutingMethodType.Unspecified
@property
def routing_method_type(self) -> RoutingMethodType:
diff --git a/vllm/model_executor/layers/fused_moe/router/router_factory.py b/vllm/model_executor/layers/fused_moe/router/router_factory.py
index 8818373d8..cbe294e6b 100644
--- a/vllm/model_executor/layers/fused_moe/router/router_factory.py
+++ b/vllm/model_executor/layers/fused_moe/router/router_factory.py
@@ -6,7 +6,6 @@ import torch
import vllm.envs as envs
from vllm.distributed.eplb.eplb_state import EplbLayerState
-from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
from vllm.model_executor.layers.fused_moe.router.custom_routing_router import (
CustomRoutingRouter,
@@ -36,7 +35,6 @@ def create_fused_moe_router(
global_num_experts: int,
renormalize: bool = True,
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
- routing_method_type: RoutingMethodType | None = None,
# grouped topk parameters
use_grouped_topk: bool = False,
num_expert_group: int | None = None,
@@ -128,7 +126,6 @@ def create_fused_moe_router(
num_fused_shared_experts=num_fused_shared_experts,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
- routing_method_type=routing_method_type,
)
router.capture = capture
return router
diff --git a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py
index 09d5e45c1..f537f2f99 100644
--- a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py
@@ -5,7 +5,10 @@
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
+ FusedMoEQuantConfig,
+)
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fallback import FallbackExperts
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
@@ -17,19 +20,22 @@ class TritonOrCutlassExperts(FallbackExperts):
def __init__(
self,
- e: int,
- n: int,
- k: int,
- out_dtype: torch.dtype | None,
+ moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
- device: torch.dtype,
):
self.is_sm100 = current_platform.has_device_capability(100)
super().__init__(
- experts=CutlassExpertsFp8(e, n, k, out_dtype, quant_config, device),
- fallback_experts=TritonExperts(quant_config),
+ experts=CutlassExpertsFp8(moe_config, quant_config),
+ fallback_experts=TritonExperts(moe_config, quant_config),
)
+ @staticmethod
+ def get_clses() -> tuple[
+ type[mk.FusedMoEPermuteExpertsUnpermute],
+ type[mk.FusedMoEPermuteExpertsUnpermute],
+ ]:
+ return (CutlassExpertsFp8, TritonExperts)
+
def workspace_shapes(
self,
M: int,
diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
index 55b1e1211..7e41269dc 100644
--- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
@@ -4,7 +4,10 @@
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
+ FusedMoEQuantConfig,
+)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts,
_valid_deep_gemm,
@@ -20,12 +23,19 @@ from vllm.utils.deep_gemm import (
class TritonOrDeepGemmExperts(FallbackExperts):
"""DeepGemm with fallback to Triton for low latency shapes."""
- def __init__(self, quant_config: FusedMoEQuantConfig):
+ def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig):
super().__init__(
- experts=DeepGemmExperts(quant_config),
- fallback_experts=TritonExperts(quant_config),
+ experts=DeepGemmExperts(moe_config, quant_config),
+ fallback_experts=TritonExperts(moe_config, quant_config),
)
+ @staticmethod
+ def get_clses() -> tuple[
+ type[mk.FusedMoEPermuteExpertsUnpermute],
+ type[mk.FusedMoEPermuteExpertsUnpermute],
+ ]:
+ return (DeepGemmExperts, TritonExperts)
+
def workspace_shapes(
self,
M: int,
diff --git a/vllm/model_executor/layers/fused_moe/trtllm_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_moe.py
index c46f59564..29a3e9003 100644
--- a/vllm/model_executor/layers/fused_moe/trtllm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/trtllm_moe.py
@@ -6,37 +6,73 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
+ FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+)
class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
- moe: FusedMoEConfig,
+ moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
gemm1_alpha,
gemm1_beta,
gemm1_clamp_limit,
max_capture_size,
):
- super().__init__(quant_config)
- self.moe = moe
+ super().__init__(moe_config, quant_config)
self.gemm1_alpha = gemm1_alpha
self.gemm1_beta = gemm1_beta
self.gemm1_clamp_limit = gemm1_clamp_limit
self.max_capture_size = max_capture_size
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- return (
- mk.FusedMoEActivationFormat.Standard,
- mk.FusedMoEActivationFormat.Standard,
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
+
+ @staticmethod
+ def _supports_current_device() -> bool:
+ raise NotImplementedError(
+ "TrtLlmGenExperts is not yet used by an Oracle. "
+ "This method should not be called."
+ )
+
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ raise NotImplementedError(
+ "TrtLlmGenExperts is not yet used by an Oracle. "
+ "This method should not be called."
+ )
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ raise NotImplementedError(
+ "TrtLlmGenExperts is not yet used by an Oracle. "
+ "This method should not be called."
+ )
+
+ @staticmethod
+ def _supports_activation(activation: str) -> bool:
+ raise NotImplementedError(
+ "TrtLlmGenExperts is not yet used by an Oracle. "
+ "This method should not be called."
+ )
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ raise NotImplementedError(
+ "TrtLlmGenExperts is not yet used by an Oracle. "
+ "This method should not be called."
)
def supports_chunking(self) -> bool:
@@ -86,7 +122,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk = topk_ids.size(-1)
local_num_experts = w1.size(0)
intermediate_size = w2.size(1)
- local_expert_offset = self.moe.ep_rank * local_num_experts
+ local_expert_offset = self.moe_config.ep_rank * local_num_experts
x_quant = hidden_states
x_scale = a1q_scale
diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
index f8489ab06..2581a1e56 100644
--- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
+++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
@@ -96,13 +96,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
):
logger.debug("BatchedTritonExperts %s", self.moe)
return BatchedTritonExperts(
+ moe_config=self.moe,
+ quant_config=self.moe_quant_config,
max_num_tokens=self.moe.max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
- quant_config=self.moe_quant_config,
)
else:
logger.debug("TritonExperts %s", self.moe)
- return TritonExperts(self.moe_quant_config)
+ return TritonExperts(
+ moe_config=self.moe,
+ quant_config=self.moe_quant_config,
+ )
def create_weights(
self,
@@ -192,7 +196,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
assert self.moe_quant_config is not None
self.kernel, self.use_inplace = make_unquantized_moe_kernel(
- layer=layer,
backend=self.unquantized_backend,
quant_config=self.moe_quant_config,
moe_config=self.moe,
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index b1fb67208..dfdfc7ea2 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -739,6 +739,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
return BatchedMarlinExperts(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
+ moe_config=self.moe,
quant_config=self.moe_quant_config,
w13_g_idx=w13_g_idx,
w2_g_idx=w2_g_idx,
@@ -749,6 +750,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
else:
# Standard Marlin experts for AWQ
return MarlinExperts(
+ moe_config=self.moe,
quant_config=self.moe_quant_config,
w13_g_idx=w13_g_idx,
w2_g_idx=w2_g_idx,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 13a123ba6..b48acdcf3 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -19,7 +19,6 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEActivationFormat,
- FusedMoEConfig,
FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute,
FusedMoERouter,
@@ -27,9 +26,9 @@ from vllm.model_executor.layers.fused_moe import (
UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
FusedMoEQuantConfig,
- fp8_w8a8_moe_quant_config,
- fp8_w8a16_moe_quant_config,
+ RoutingMethodType,
int4_w4a16_moe_quant_config,
int4_w4afp8_moe_quant_config,
int8_w8a8_moe_quant_config,
@@ -45,15 +44,17 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
+ make_fp8_moe_kernel_for_mkm,
+ make_fp8_moe_quant_config,
select_fp8_moe_backend,
)
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
- FLASHINFER_NVFP4_MOE_BACKENDS,
NvFp4MoeBackend,
convert_to_nvfp4_moe_kernel_format,
is_global_sf_supported_for_nvfp4_backend,
make_mxfp4_moe_quant_config,
make_nvfp4_moe_kernel,
+ make_nvfp4_moe_kernel_for_mkm,
make_nvfp4_moe_quant_config,
select_nvfp4_moe_backend,
)
@@ -62,10 +63,12 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compress
WNA16_SUPPORTED_TYPES_MAP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
- build_flashinfer_fp4_cutlass_moe_prepare_finalize,
flashinfer_trtllm_fp4_moe,
flashinfer_trtllm_fp4_routed_moe,
- select_nvfp4_gemm_impl,
+)
+from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+ apply_fi_trtllm_fp8_per_tensor_moe,
+ build_flashinfer_fp8_cutlass_moe_prepare_finalize,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_input_tensor_strategy_moe,
@@ -79,12 +82,18 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_moe_permute_scales,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
- is_fp4_marlin_supported,
prepare_moe_fp4_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
convert_bf16_scales_to_fp8,
convert_packed_uint4b8_to_signed_int4_inplace,
+ kFp8Dynamic128Sym,
+ kFp8DynamicTokenSym,
+ kFp8Static128BlockSym,
+ kFp8StaticChannelSym,
+ kFp8StaticTensorSym,
+ kNvfp4Dynamic,
+ kNvfp4Static,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
normalize_e4m3fn_to_e4m3fnuz,
@@ -200,7 +209,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
f"or None for NVFP4A16, found {input_quant}",
)
return CompressedTensorsW4A4Nvfp4MoEMethod(
- layer.moe_config, layer_name, use_marlin=input_quant is None
+ layer.moe_config, layer_name, use_a16=(input_quant is None)
)
elif (
quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
@@ -234,6 +243,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
super().__init__(moe)
self.group_size = 32
self.mxfp4_backend = NvFp4MoeBackend.MARLIN
+ self.experts_cls = MarlinExperts
self.kernel: mk.FusedMoEModularKernel | None = None
def create_weights(
@@ -327,9 +337,9 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config is not None:
self.kernel = make_nvfp4_moe_kernel(
- backend=self.mxfp4_backend,
- quant_config=self.moe_quant_config,
+ moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
+ experts_cls=self.experts_cls,
)
def apply(
@@ -368,34 +378,30 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
self,
moe: FusedMoEConfig,
layer_name: str | None = None,
- use_marlin: bool = False,
+ use_a16: bool = False,
):
super().__init__(moe)
self.group_size = 16
- if use_marlin:
- if is_fp4_marlin_supported():
- self.nvfp4_backend = NvFp4MoeBackend.MARLIN
- else:
- raise ValueError(
- "Marlin FP4 MoE kernel requested but not ",
- "supported on current platform.",
- )
- else:
- self.nvfp4_backend = select_nvfp4_moe_backend()
- # TODO: move this type of check into the oracle.
- if not self.moe.is_act_and_mul and self.nvfp4_backend not in [
- NvFp4MoeBackend.FLASHINFER_CUTLASS,
- NvFp4MoeBackend.MARLIN,
- ]:
- raise NotImplementedError(
- "Non-gated activations are only supported by FlashInfer "
- f"CUTLASS and Marlin NvFP4 MoE backends, not {self.nvfp4_backend}."
- )
+ # Select experts implementation.
+ self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend(
+ config=self.moe,
+ weight_key=kNvfp4Static,
+ activation_key=None if use_a16 else kNvfp4Dynamic,
+ )
+
+ # Delay creation of the kernel until after process-weights.
+ self.kernel: mk.FusedMoEModularKernel | None = None
+
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend
)
- self.kernel: mk.FusedMoEModularKernel | None = None
+
+ @property
+ def topk_indices_dtype(self) -> torch.dtype | None:
+ if self.kernel is not None:
+ return self.kernel.prepare_finalize.topk_indices_dtype()
+ return None
def create_weights(
self,
@@ -571,35 +577,40 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
layer.w13_input_scale = a13_scale
layer.w2_input_scale = a2_scale
- # Initialize the kernel that will be called in apply().
+ # Setup modular kernel for TP case and naive DP/EP case.
+ # In non-naive DP/EP case, we will create a ModularKernelMethod.
+ # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
+ # in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
- use_dp = self.moe.dp_size > 1
- if self.moe_quant_config is not None and not use_dp:
+ if self.moe_quant_config and (
+ (not self.moe.moe_parallel_config.use_all2all_kernels)
+ or self.moe.moe_parallel_config.use_naive_all2all_kernels
+ ):
+ assert self.experts_cls is not None
self.kernel = make_nvfp4_moe_kernel(
- backend=self.nvfp4_backend,
- quant_config=self.moe_quant_config,
+ moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
+ experts_cls=self.experts_cls,
)
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
- UNSUPPORTED = [NvFp4MoeBackend.MARLIN, NvFp4MoeBackend.FLASHINFER_TRTLLM]
- if self.nvfp4_backend in UNSUPPORTED:
+ if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
return None
elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
- # TP case: avoid convert to ModularKernelMethod - to be refactored.
- if self.moe.dp_size == 1:
+ # For no-EP case, don't use the MKM framework.
+ if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
- # For now, fp4 moe only works with the flashinfer dispatcher.
- prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
- self.moe
+
+ prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
+ self.moe,
+ use_deepseek_fp8_block_scale=False,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
- else:
- return super().maybe_make_prepare_finalize(routing_tables)
+ return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl(
self,
@@ -607,14 +618,13 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None
- """Return the appropriate GEMM experts implementation."""
- experts = select_nvfp4_gemm_impl(
- self.moe,
- self.moe_quant_config,
- allow_flashinfer=(self.nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS),
+ assert self.experts_cls is not None
+ return make_nvfp4_moe_kernel_for_mkm(
+ moe_config=self.moe,
+ quant_config=self.moe_quant_config,
+ experts_cls=self.experts_cls,
+ prepare_finalize=prepare_finalize,
)
- logger.debug_once("Using %s", experts.__class__.__name__)
- return experts
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
@@ -727,33 +737,41 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
"For FP8 Fused MoE layer, we require either per tensor or "
"channelwise, dynamic per token quantization."
)
- self.fp8_backend = select_fp8_moe_backend(
- block_quant=self.block_quant,
- tp_size=moe.tp_size,
- with_lora_support=moe.is_lora_enabled,
- is_act_and_mul=moe.is_act_and_mul,
- # TODO(rob): enable selecting this externally.
+
+ ct2vllm_weight = {
+ QuantizationStrategy.CHANNEL: kFp8StaticChannelSym,
+ QuantizationStrategy.TENSOR: kFp8StaticTensorSym,
+ QuantizationStrategy.BLOCK: kFp8Static128BlockSym,
+ }
+ ct2vllm_act = {
+ QuantizationStrategy.TOKEN: kFp8DynamicTokenSym,
+ QuantizationStrategy.TENSOR: (
+ kFp8StaticTensorSym if self.static_input_scales else kFp8Dynamic128Sym
+ ),
+ }
+ weight_key = ct2vllm_weight[self.weight_quant.strategy]
+ if weight_key == kFp8Static128BlockSym:
+ activation_key = kFp8Dynamic128Sym
+ else:
+ activation_key = ct2vllm_act[self.input_quant.strategy]
+
+ # Select Fp8 MoE backend
+ self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
+ config=self.moe,
+ weight_key=weight_key,
+ activation_key=activation_key,
allow_vllm_cutlass=True,
)
- if self.fp8_backend != Fp8MoeBackend.MARLIN:
- per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
- per_channel_quant = (
- self.weight_quant.strategy == QuantizationStrategy.CHANNEL
- )
- if per_act_token != per_channel_quant:
- raise NotImplementedError(
- "For FP8 Fused MoE layers, per-token and per-channel must be "
- "used together."
- )
- # TODO(rob): hook this up in a follow up PR.
- if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
- raise NotImplementedError(
- "FlashInfer TRTLLM backend not supported for compressed-tensors yet."
- )
- self.disable_expert_map = False
+ # Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
+ @property
+ def topk_indices_dtype(self) -> torch.dtype | None:
+ if self.kernel is not None:
+ return self.kernel.prepare_finalize.topk_indices_dtype()
+ return None
+
def create_weights(
self,
layer: torch.nn.Module,
@@ -970,140 +988,75 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
replace_parameter(layer, "w13_weight_scale", w13_scale)
replace_parameter(layer, "w2_weight_scale", w2_scale)
+ # Setup modular kernel for TP case and naive DP/EP case.
+ # In non-naive DP/EP case, we will create a ModularKernelMethod.
+ # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
+ # in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
- if self.moe_quant_config:
+ if self.moe_quant_config and (
+ (not self.moe.moe_parallel_config.use_all2all_kernels)
+ or self.moe.moe_parallel_config.use_naive_all2all_kernels
+ ):
+ assert self.experts_cls is not None
self.kernel, self.use_inplace = make_fp8_moe_kernel(
- layer=layer,
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
+ experts_cls=self.experts_cls,
)
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
- if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]:
+ if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return None
- else:
- return super().maybe_make_prepare_finalize(routing_tables)
+ elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
+ # For no-EP case, don't use the MKM framework.
+ if not self.moe.moe_parallel_config.use_all2all_kernels:
+ return None
+
+ prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
+ self.moe,
+ use_deepseek_fp8_block_scale=self.block_quant,
+ )
+ logger.debug_once("%s", prepare_finalize.__class__.__name__)
+ return prepare_finalize
+ return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
- # cutlass path
assert self.moe_quant_config is not None
- if self.fp8_backend == Fp8MoeBackend.VLLM_CUTLASS:
- from vllm.model_executor.layers.fused_moe import (
- CutlassBatchedExpertsFp8,
- CutlassExpertsFp8,
- )
-
- experts: FusedMoEPermuteExpertsUnpermute
-
- num_dispatchers = prepare_finalize.num_dispatchers()
-
- if (
- prepare_finalize.activation_format
- == FusedMoEActivationFormat.BatchedExperts
- ):
- logger.debug("CutlassBatchedExpertsFp8(%s)", self.__class__.__name__)
- experts = CutlassBatchedExpertsFp8(
- max_experts_per_worker=self.moe.num_local_experts,
- num_dispatchers=num_dispatchers,
- out_dtype=self.moe.in_dtype,
- e=layer.local_num_experts,
- n=layer.intermediate_size_per_partition,
- k=layer.hidden_size,
- device=layer.w13_weight.device,
- quant_config=self.moe_quant_config,
- )
- else:
- logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
- experts = CutlassExpertsFp8(
- out_dtype=self.moe.in_dtype,
- e=layer.local_num_experts,
- n=layer.intermediate_size_per_partition,
- k=layer.hidden_size,
- device=layer.w13_weight.device,
- quant_config=self.moe_quant_config,
- )
-
- # TODO(rob): investigate disable_expert_map
- self.disable_expert_map = (
- num_dispatchers > 1 or not experts.supports_expert_map()
- )
-
- return experts
-
- from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
- BatchedDeepGemmExperts,
+ assert self.experts_cls is not None
+ return make_fp8_moe_kernel_for_mkm(
+ moe_config=self.moe,
+ quant_config=self.moe_quant_config,
+ experts_cls=self.experts_cls,
+ prepare_finalize=prepare_finalize,
)
- from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
- BatchedTritonExperts,
- )
- from vllm.model_executor.layers.fused_moe.fused_moe import (
- TritonExperts,
- )
- from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
- TritonOrDeepGemmExperts,
- )
-
- assert self.fp8_backend not in [Fp8MoeBackend.AITER, Fp8MoeBackend.MARLIN]
-
- if (
- prepare_finalize.activation_format
- == FusedMoEActivationFormat.BatchedExperts
- ):
- max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
- assert max_num_tokens_per_rank is not None
-
- if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
- logger.debug("BatchedDeepGemmExperts(%s)", self.__class__.__name__)
- return BatchedDeepGemmExperts(
- max_num_tokens=max_num_tokens_per_rank,
- num_dispatchers=prepare_finalize.num_dispatchers(),
- quant_config=self.moe_quant_config,
- )
- else:
- logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
- return BatchedTritonExperts(
- max_num_tokens=max_num_tokens_per_rank,
- num_dispatchers=prepare_finalize.num_dispatchers(),
- quant_config=self.moe_quant_config,
- )
-
- else:
- if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
- logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__)
- return TritonOrDeepGemmExperts(self.moe_quant_config)
- else:
- logger.debug("TritonExperts(%s)", self.__class__.__name__)
- return TritonExperts(self.moe_quant_config)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
- if self.fp8_backend == Fp8MoeBackend.MARLIN:
- return fp8_w8a16_moe_quant_config(
- w1_scale=layer.w13_weight_scale,
- w2_scale=layer.w2_weight_scale,
- block_shape=self.weight_block_size,
- )
+ w1_scale = layer.w13_weight_scale
+ w2_scale = layer.w2_weight_scale
+ a1_scale = layer.w13_input_scale
+ a2_scale = layer.w2_input_scale
- per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
- per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL
-
- return fp8_w8a8_moe_quant_config(
- w1_scale=layer.w13_weight_scale,
- w2_scale=layer.w2_weight_scale,
- a1_scale=layer.w13_input_scale,
- a2_scale=layer.w2_input_scale,
- per_act_token_quant=per_act_token,
- per_out_ch_quant=per_channel_quant,
- block_shape=layer.weight_block_size,
+ return make_fp8_moe_quant_config(
+ fp8_backend=self.fp8_backend,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ a1_scale=a1_scale,
+ a2_scale=a2_scale,
+ per_act_token_quant=(
+ self.input_quant.strategy == QuantizationStrategy.TOKEN
+ ),
+ per_out_ch_quant=(self.input_quant.strategy == QuantizationStrategy.TOKEN),
+ block_shape=self.weight_block_size,
)
def apply(
@@ -1113,6 +1066,56 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+ if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
+ if layer.enable_eplb:
+ raise NotImplementedError(
+ "EPLB not supported for `FlashInfer TRTLLM FP8 MoE`."
+ )
+ assert layer.activation == "silu"
+
+ if self.block_quant:
+ import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
+
+ e_score_correction_bias = (
+ layer.e_score_correction_bias.to(x.dtype)
+ if layer.e_score_correction_bias is not None
+ else None
+ )
+ routing_method_type = layer.routing_method_type
+ return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
+ routing_logits=router_logits.to(torch.float32)
+ if routing_method_type == RoutingMethodType.DeepSeekV3
+ else router_logits,
+ routing_bias=e_score_correction_bias,
+ x=x,
+ w13_weight=layer.w13_weight,
+ w13_weight_scale_inv=layer.w13_weight_scale,
+ w2_weight=layer.w2_weight,
+ w2_weight_scale_inv=layer.w2_weight_scale,
+ global_num_experts=layer.global_num_experts,
+ top_k=layer.top_k,
+ num_expert_group=layer.num_expert_group,
+ topk_group=layer.topk_group,
+ intermediate_size=layer.intermediate_size_per_partition,
+ expert_offset=layer.ep_rank * layer.local_num_experts,
+ local_num_experts=layer.local_num_experts,
+ block_shape=self.weight_block_size,
+ routing_method_type=routing_method_type,
+ routed_scaling=layer.routed_scaling_factor,
+ )
+ else:
+ return apply_fi_trtllm_fp8_per_tensor_moe(
+ layer=layer,
+ hidden_states=x,
+ router_logits=router_logits,
+ routing_bias=layer.e_score_correction_bias,
+ global_num_experts=layer.global_num_experts,
+ top_k=layer.top_k,
+ num_expert_group=layer.num_expert_group,
+ topk_group=layer.topk_group,
+ apply_router_weight_on_input=layer.apply_router_weight_on_input,
+ )
+
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
@@ -1130,7 +1133,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
global_num_experts=layer.global_num_experts,
# TODO(rob): investigate the disable_expert_map introduced by:
# https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501
- expert_map=None if self.disable_expert_map else layer.expert_map,
+ expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
@@ -1596,6 +1599,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
return BatchedMarlinExperts(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
+ moe_config=self.moe,
quant_config=self.moe_quant_config,
w13_g_idx=layer.w13_weight_g_idx,
w2_g_idx=layer.w2_weight_g_idx,
@@ -1605,6 +1609,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
)
else:
return MarlinExperts(
+ moe_config=self.moe,
quant_config=self.moe_quant_config,
w13_g_idx=layer.w13_weight_g_idx,
w2_g_idx=layer.w2_weight_g_idx,
@@ -1854,7 +1859,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer.w13_weight = layer.w13_weight_packed
layer.w2_weight = layer.w2_weight_packed
- return TritonWNA16Experts(quant_config=self.moe_quant_config)
+ return TritonWNA16Experts(
+ moe_config=self.moe, quant_config=self.moe_quant_config
+ )
else:
raise NotImplementedError(
"TritonExperts requires Triton. "
@@ -2467,6 +2474,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
c_strides2=self.a_strides1_c_strides2,
s_strides1=self.s_strides1,
s_strides2=self.s_strides2,
+ moe_config=self.moe,
quant_config=self.moe_quant_config,
group_size=self.group_size,
)
@@ -2505,6 +2513,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight_packed,
topk_weights,
topk_ids,
+ moe_config=self.moe,
quant_config=self.moe_quant_config,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 5adcd09b0..d9fa02c6b 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -19,7 +19,6 @@ from vllm.model_executor.layers.batch_invariant import (
)
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
- FusedMoEActivationFormat,
FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
@@ -35,6 +34,7 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
+ make_fp8_moe_kernel_for_mkm,
make_fp8_moe_quant_config,
select_fp8_moe_backend,
)
@@ -55,7 +55,6 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
- select_cutlass_fp8_gemm_impl,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
@@ -79,8 +78,10 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
is_layer_skipped,
+ kFp8Dynamic128Sym,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
+ kFp8Static128BlockSym,
kFp8StaticTensorSym,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -658,38 +659,36 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.weight_scale_name = (
"weight_scale_inv" if self.block_quant else "weight_scale"
)
- self.fp8_backend = select_fp8_moe_backend(
- block_quant=self.block_quant,
- tp_size=layer.moe_parallel_config.tp_size,
- with_lora_support=self.moe.is_lora_enabled,
- is_act_and_mul=self.moe.is_act_and_mul,
- )
- if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
- if self.block_quant and self.weight_block_size != [128, 128]:
- raise NotImplementedError(
- "FlashInfer CUTLASS FP8 MoE backend only supports block "
- "size [128, 128]."
- )
- if layer.activation != "silu":
- raise NotImplementedError(
- "FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
- "activation function, but got {layer.activation}."
- )
- dynamic_per_token = (
- not self.block_quant and self.quant_config.activation_scheme != "static"
- )
- if dynamic_per_token and self.fp8_backend in [
- Fp8MoeBackend.FLASHINFER_TRTLLM,
- Fp8MoeBackend.FLASHINFER_CUTLASS,
- ]:
- raise NotImplementedError(
- "FlashInfer FP8 MoE backend does not support dynamic per token "
- "activation quantization."
+ # Set weight key and activation key for kernel compatibility
+ if self.block_quant:
+ weight_key = kFp8Static128BlockSym
+ activation_key = kFp8Dynamic128Sym
+ else:
+ weight_key = kFp8StaticTensorSym
+ activation_key = (
+ kFp8StaticTensorSym
+ if self.quant_config.activation_scheme == "static"
+ else kFp8DynamicTensorSym
)
+ # Select Fp8 MoE backend
+ self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
+ config=self.moe,
+ weight_key=weight_key,
+ activation_key=activation_key,
+ allow_vllm_cutlass=False,
+ )
+
+ # Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
+ @property
+ def topk_indices_dtype(self) -> torch.dtype | None:
+ if self.kernel is not None:
+ return self.kernel.prepare_finalize.topk_indices_dtype()
+ return None
+
def create_weights(
self,
layer: Module,
@@ -842,14 +841,21 @@ class Fp8MoEMethod(FusedMoEMethodBase):
replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)
- # Setup modular kernel for TP case.
+ # Setup modular kernel for TP case and naive DP/EP case.
+ # In non-naive DP/EP case, we will create a ModularKernelMethod.
+ # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
+ # in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
- if self.moe_quant_config:
+ if self.moe_quant_config and (
+ (not self.moe.moe_parallel_config.use_all2all_kernels)
+ or self.moe.moe_parallel_config.use_naive_all2all_kernels
+ ):
+ assert self.experts_cls is not None
self.kernel, self.use_inplace = make_fp8_moe_kernel(
- layer=layer,
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
+ experts_cls=self.experts_cls,
)
def process_weights_after_loading(self, layer: Module) -> None:
@@ -904,13 +910,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
- if self.fp8_backend in [
- Fp8MoeBackend.AITER,
- Fp8MoeBackend.MARLIN,
- Fp8MoeBackend.FLASHINFER_TRTLLM,
- ]:
+ if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return None
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
+ # For no-EP case, don't use the MKM framework.
+ if not self.moe.moe_parallel_config.use_all2all_kernels:
+ return None
+
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=self.block_quant,
@@ -924,73 +930,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
prepare_finalize: FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
- from vllm.model_executor.layers.fused_moe import (
- BatchedDeepGemmExperts,
- BatchedTritonExperts,
- TritonExperts,
- TritonOrDeepGemmExperts,
- )
-
- if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]:
- raise NotImplementedError(
- "Marlin and ROCm AITER are not supported with all2all yet."
- )
-
assert self.moe_quant_config is not None
-
- if (
- prepare_finalize.activation_format
- == FusedMoEActivationFormat.BatchedExperts
- ):
- max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
- assert max_num_tokens_per_rank is not None
-
- experts_impl = (
- BatchedDeepGemmExperts
- if self.fp8_backend == Fp8MoeBackend.DEEPGEMM
- else BatchedTritonExperts
- )
- logger.debug(
- "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
- experts_impl.__name__,
- self.__class__.__name__,
- max_num_tokens_per_rank,
- self.weight_block_size,
- False,
- )
- return experts_impl(
- max_num_tokens=max_num_tokens_per_rank,
- num_dispatchers=prepare_finalize.num_dispatchers(),
- quant_config=self.moe_quant_config,
- )
- elif self.moe.is_lora_enabled:
- return TritonExperts(quant_config=self.moe_quant_config)
- elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
- # Select GEMM experts with block-scale when weights are block-quantized
- experts = select_cutlass_fp8_gemm_impl(
- self.moe,
- self.moe_quant_config,
- use_deepseek_fp8_block_scale=self.block_quant,
- )
- logger.debug_once("Using %s", experts.__class__.__name__)
- return experts
- elif self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
- logger.debug(
- "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
- self.__class__.__name__,
- self.weight_block_size,
- False,
- )
- return TritonOrDeepGemmExperts(self.moe_quant_config)
- else:
- assert self.fp8_backend == Fp8MoeBackend.TRITON
- logger.debug(
- "TritonExperts(%s): block_size=%s, per_act_token=%s",
- self.__class__.__name__,
- self.weight_block_size,
- False,
- )
- return TritonExperts(self.moe_quant_config)
+ assert self.experts_cls is not None
+ return make_fp8_moe_kernel_for_mkm(
+ moe_config=self.moe,
+ quant_config=self.moe_quant_config,
+ experts_cls=self.experts_cls,
+ prepare_finalize=prepare_finalize,
+ )
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
@@ -1067,7 +1014,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
routed_scaling=layer.routed_scaling_factor,
)
else:
- result = apply_fi_trtllm_fp8_per_tensor_moe(
+ return apply_fi_trtllm_fp8_per_tensor_moe(
layer=layer,
hidden_states=x,
router_logits=router_logits,
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 00cd635b2..ec295005b 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -875,6 +875,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
return BatchedMarlinExperts(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
+ moe_config=self.moe,
quant_config=self.moe_quant_config,
w13_g_idx=w13_g_idx,
w2_g_idx=w2_g_idx,
@@ -885,6 +886,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
else:
# Standard Marlin experts for GPTQ
return MarlinExperts(
+ moe_config=self.moe,
quant_config=self.moe_quant_config,
w13_g_idx=w13_g_idx,
w2_g_idx=w2_g_idx,
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 91dfa03b8..0de9cb88d 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -27,15 +27,16 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
+ make_fp8_moe_kernel_for_mkm,
make_fp8_moe_quant_config,
select_fp8_moe_backend,
)
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
- FLASHINFER_NVFP4_MOE_BACKENDS,
NvFp4MoeBackend,
convert_to_nvfp4_moe_kernel_format,
is_global_sf_supported_for_nvfp4_backend,
make_nvfp4_moe_kernel,
+ make_nvfp4_moe_kernel_for_mkm,
make_nvfp4_moe_quant_config,
select_nvfp4_moe_backend,
)
@@ -57,12 +58,10 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
flashinfer_trtllm_fp4_moe,
flashinfer_trtllm_fp4_routed_moe,
- select_nvfp4_gemm_impl,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
- select_cutlass_fp8_gemm_impl,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
@@ -84,6 +83,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
kFp8StaticTokenSym,
+ kNvfp4Dynamic,
+ kNvfp4Static,
swizzle_blockscale,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -728,14 +729,23 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
super().__init__(moe_config)
self.quant_config = quant_config
assert self.quant_config.is_checkpoint_fp8_serialized
- self.fp8_backend = select_fp8_moe_backend(
- block_quant=False,
- tp_size=moe_config.moe_parallel_config.tp_size,
- with_lora_support=self.moe.is_lora_enabled,
- is_act_and_mul=self.moe.is_act_and_mul,
+
+ # Select Fp8 MoE backend
+ self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
+ config=self.moe,
+ weight_key=kFp8StaticTensorSym,
+ activation_key=kFp8StaticTensorSym,
)
+
+ # Delay creation of the kernel until after process-weights.
self.kernel: mk.FusedMoEModularKernel | None = None
+ @property
+ def topk_indices_dtype(self) -> torch.dtype | None:
+ if self.kernel is not None:
+ return self.kernel.prepare_finalize.topk_indices_dtype()
+ return None
+
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
@@ -744,8 +754,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return None
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
- # TP case: avoid convert to ModularKernelMethod - to be refactored.
- if self.moe.dp_size == 1:
+ # For no-EP case, don't use the MKM framework.
+ if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
@@ -762,12 +772,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None
- experts = select_cutlass_fp8_gemm_impl(
- self.moe,
- self.moe_quant_config,
+ assert self.experts_cls is not None
+ return make_fp8_moe_kernel_for_mkm(
+ moe_config=self.moe,
+ quant_config=self.moe_quant_config,
+ experts_cls=self.experts_cls,
+ prepare_finalize=prepare_finalize,
)
- logger.debug_once("Using %s", experts.__class__.__name__)
- return experts
def create_weights(
self,
@@ -876,14 +887,15 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
replace_parameter(layer, "w13_weight_scale", w13_scale)
replace_parameter(layer, "w2_weight_scale", w2_scale)
- # Setup modular kernel for TP case.
+ # Setup modular kernel.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
+ assert self.experts_cls is not None
self.kernel, self.use_inplace = make_fp8_moe_kernel(
- layer=layer,
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
+ experts_cls=self.experts_cls,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
@@ -1335,32 +1347,35 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
) -> None:
super().__init__(moe_config)
self.quant_config = quant_config
- self.nvfp4_backend = select_nvfp4_moe_backend()
- # TODO: move this type of check into the oracle.
- if not self.moe.is_act_and_mul and self.nvfp4_backend not in [
- NvFp4MoeBackend.FLASHINFER_CUTLASS,
- NvFp4MoeBackend.MARLIN,
- ]:
- raise NotImplementedError(
- "Non-gated activations are only supported by FlashInfer "
- f"CUTLASS and Marlin NvFP4 MoE backends, not {self.nvfp4_backend}."
- )
+ # Select experts implementation.
+ self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend(
+ config=self.moe,
+ weight_key=kNvfp4Static,
+ activation_key=kNvfp4Dynamic,
+ )
+
+ # Delay creation of the kernel until after process-weights.
+ self.kernel: mk.FusedMoEModularKernel | None = None
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend
)
- self.kernel: mk.FusedMoEModularKernel | None = None
+
+ @property
+ def topk_indices_dtype(self) -> torch.dtype | None:
+ if self.kernel is not None:
+ return self.kernel.prepare_finalize.topk_indices_dtype()
+ return None
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
- UNSUPPORTED = [NvFp4MoeBackend.MARLIN, NvFp4MoeBackend.FLASHINFER_TRTLLM]
- if self.nvfp4_backend in UNSUPPORTED:
+ if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
return None
elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
- # TP case: avoid convert to ModularKernelMethod - to be refactored.
- if self.moe.dp_size == 1:
+ # For no-EP case, don't use the MKM framework.
+ if not self.moe.moe_parallel_config.use_all2all_kernels:
return None
# For now, fp4 moe only works with the flashinfer dispatcher.
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
@@ -1377,13 +1392,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None
- experts = select_nvfp4_gemm_impl(
- self.moe,
- self.moe_quant_config,
- allow_flashinfer=self.nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS,
+ assert self.experts_cls is not None
+ return make_nvfp4_moe_kernel_for_mkm(
+ moe_config=self.moe,
+ quant_config=self.moe_quant_config,
+ experts_cls=self.experts_cls,
+ prepare_finalize=prepare_finalize,
)
- logger.debug_once("Using %s", experts.__class__.__name__)
- return experts
def uses_weight_scale_2_pattern(self) -> bool:
"""
@@ -1554,13 +1569,20 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
replace_parameter(layer, "w2_weight_scale_2", w2_scale_2)
replace_parameter(layer, "w2_input_scale", a2_scale)
+ # Setup modular kernel for TP case and naive DP/EP case.
+ # In non-naive DP/EP case, we will create a ModularKernelMethod.
+ # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
+ # in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
- use_dp = self.moe.dp_size > 1
- if self.moe_quant_config is not None and not use_dp:
+ if self.moe_quant_config and (
+ (not self.moe.moe_parallel_config.use_all2all_kernels)
+ or self.moe.moe_parallel_config.use_naive_all2all_kernels
+ ):
+ assert self.experts_cls is not None
self.kernel = make_nvfp4_moe_kernel(
- backend=self.nvfp4_backend,
- quant_config=self.moe_quant_config,
+ moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
+ experts_cls=self.experts_cls,
)
@property
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index ecd13e5c7..18dd3e40b 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -853,6 +853,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config,
+ moe_config=self.moe,
)
else:
raise NotImplementedError(
@@ -875,11 +876,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
}
return TrtLlmGenExperts(self.moe, self.moe_quant_config, **kwargs)
elif self.mxfp4_backend == Mxfp4Backend.MARLIN:
- return MarlinExperts(self.moe_quant_config)
+ return MarlinExperts(self.moe, self.moe_quant_config)
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
if self.moe.is_lora_enabled:
- return UnfusedOAITritonExperts(self.moe_quant_config)
- return OAITritonExperts(self.moe_quant_config)
+ return UnfusedOAITritonExperts(self.moe, self.moe_quant_config)
+ return OAITritonExperts(self.moe, self.moe_quant_config)
else:
raise NotImplementedError(
f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP"
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
index 272b13861..ea5884e0f 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
@@ -11,19 +11,16 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
- FusedMoEQuantConfig,
+ FusedMoEParallelConfig,
RoutingMethodType,
)
-from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
- FlashInferCuteDSLExperts,
-)
-from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
- FlashInferExperts,
-)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
create_flashinfer_prepare_finalize,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+ kNvfp4Dynamic,
+ kNvfp4Static,
swizzle_blockscale,
)
from vllm.platforms import current_platform
@@ -47,6 +44,86 @@ __all__ = [
"build_flashinfer_fp4_cutlass_moe_prepare_finalize",
]
+#
+# Methods used by the oracle for kernel selection.
+#
+
+
+def _supports_current_device() -> bool:
+ """Supports only Blackwell-family GPUs."""
+ p = current_platform
+ return p.is_cuda() and p.is_device_capability_family(100)
+
+
+def _supports_no_act_and_mul() -> bool:
+ """Does not support non-gated MoE (i.e. Nemotron-Nano)."""
+ return False
+
+
+def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+) -> bool:
+ """Supports Nvfp4 quantization."""
+ SUPPORTED_W_A = [
+ (kNvfp4Static, kNvfp4Dynamic),
+ ]
+ return (weight_key, activation_key) in SUPPORTED_W_A
+
+
+def _supports_activation(activation: str) -> bool:
+ """Supports silu activation only."""
+ return activation in ["silu"]
+
+
+def _supports_routing_method(
+ routing_method: RoutingMethodType,
+) -> bool:
+ """Monolithic kernels need to express router support."""
+ # NOTE(rob): potentially allow others here. This is a conservative list.
+ return routing_method in [
+ RoutingMethodType.DeepSeekV3,
+ RoutingMethodType.Renormalize,
+ RoutingMethodType.RenormalizeNaive,
+ RoutingMethodType.Llama4,
+ ]
+
+
+def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ """Supports EP."""
+ return True
+
+
+def is_supported_config_trtllm(
+ moe_config: FusedMoEConfig,
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ activation_format: mk.FusedMoEActivationFormat,
+) -> tuple[bool, str | None]:
+ """
+ This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config
+ """
+
+ def _make_reason(reason: str) -> str:
+ return f"kernel does not support {reason}"
+
+ if not _supports_current_device():
+ return False, _make_reason("current device")
+ elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()):
+ return False, _make_reason("no act_and_mul MLP layer")
+ elif not _supports_activation(moe_config.activation):
+ return False, _make_reason(f"{moe_config.activation} activation")
+ elif not _supports_quant_scheme(weight_key, activation_key):
+ return False, _make_reason("quantization scheme")
+ elif not _supports_parallel_config(moe_config.moe_parallel_config):
+ return False, _make_reason("parallel config")
+ elif not _supports_routing_method(moe_config.routing_method):
+ return False, _make_reason("routing method")
+ elif activation_format != mk.FusedMoEActivationFormat.Standard:
+ return False, _make_reason("activation format")
+
+ return True, None
+
def is_flashinfer_fp4_cutlass_moe_available() -> bool:
"""Return `True` when FlashInfer CUTLASS NV-FP4 kernels can be used."""
@@ -96,37 +173,6 @@ def build_flashinfer_fp4_cutlass_moe_prepare_finalize(
)
-def select_nvfp4_gemm_impl(
- moe: FusedMoEConfig,
- moe_quant_config: FusedMoEQuantConfig,
- allow_flashinfer: bool,
-) -> mk.FusedMoEPermuteExpertsUnpermute:
- """Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""
-
- if allow_flashinfer:
- if envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm":
- return FlashInferCuteDSLExperts(
- out_dtype=moe.in_dtype,
- quant_config=moe_quant_config,
- )
- elif envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput":
- return FlashInferExperts(
- out_dtype=moe.in_dtype,
- quant_config=moe_quant_config,
- ep_rank=moe.moe_parallel_config.ep_rank,
- ep_size=moe.moe_parallel_config.ep_size,
- tp_rank=moe.moe_parallel_config.tp_rank,
- tp_size=moe.moe_parallel_config.tp_size,
- use_dp=moe.moe_parallel_config.dp_size > 1,
- )
-
- # native cutlass experts currently don't support DP; TP case won't call this
- raise ValueError(
- "CutlassExpertsFp4 doesn't support DP. Use flashinfer CUTLASS "
- "Fused MoE backend instead (set VLLM_USE_FLASHINFER_MOE_FP4=1)"
- )
-
-
def prepare_static_weights_for_trtllm_fp4_moe(
# args_dequant,
# args,
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
index 799854479..2cc17b12f 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
@@ -9,10 +9,6 @@ from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
- FusedMoEQuantConfig,
-)
-from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
- FlashInferExperts,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
create_flashinfer_prepare_finalize,
@@ -203,33 +199,6 @@ def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
)
-def select_cutlass_fp8_gemm_impl(
- moe: FusedMoEConfig | None,
- quant_config: FusedMoEQuantConfig,
- out_dtype: torch.dtype | None = None,
- use_deepseek_fp8_block_scale: bool = False,
-) -> mk.FusedMoEPermuteExpertsUnpermute:
- """Return a GEMM *experts* implementation for fused-MoE layers"""
-
- if moe is not None:
- return FlashInferExperts(
- out_dtype=moe.in_dtype,
- quant_config=quant_config,
- ep_rank=moe.moe_parallel_config.ep_rank,
- ep_size=moe.moe_parallel_config.ep_size,
- tp_rank=moe.moe_parallel_config.tp_rank,
- tp_size=moe.moe_parallel_config.tp_size,
- use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
- )
-
- assert out_dtype is not None, "If moe config is None, out_dtype must be passed"
- return FlashInferExperts(
- out_dtype=out_dtype,
- quant_config=quant_config,
- use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
- )
-
-
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
backend_map = {
"throughput": FlashinferMoeBackend.CUTLASS,
diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py
index 91fc8760b..bc7458444 100644
--- a/vllm/model_executor/layers/quantization/utils/quant_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py
@@ -48,6 +48,7 @@ class GroupShape(_GroupShape):
# Aliases for common quantization group shapes
PER_TENSOR: ClassVar["GroupShape"]
PER_TOKEN: ClassVar["GroupShape"]
+ PER_CHANNEL: ClassVar["GroupShape"]
def is_per_tensor(self) -> bool:
return self.row == -1 and self.col == -1
@@ -55,12 +56,16 @@ class GroupShape(_GroupShape):
def is_per_token(self) -> bool:
return self.row == 1 and self.col == -1
+ def is_per_channel(self) -> bool:
+ return self.row == -1 and self.col == 1
+
def is_per_group(self) -> bool:
return self.row == 1 and self.col >= 1
GroupShape.PER_TENSOR = GroupShape(-1, -1)
GroupShape.PER_TOKEN = GroupShape(1, -1)
+GroupShape.PER_CHANNEL = GroupShape(-1, 1)
@dataclass(frozen=True)
@@ -77,16 +82,12 @@ class ScaleDesc:
group_shape: GroupShape
def __str__(self):
- group_shape = (
- "per_tensor"
- if self.group_shape == GroupShape.PER_TENSOR
- else (
- "per_token"
- if self.group_shape == GroupShape.PER_TOKEN
- else str(self.group_shape)
- )
- )
-
+ d = {
+ GroupShape.PER_TENSOR: "per_tensor",
+ GroupShape.PER_TOKEN: "per_token",
+ GroupShape.PER_CHANNEL: "per_channel",
+ }
+ group_shape = d.get(self.group_shape, str(self.group_shape))
return (
f"{fx.graph.dtype_abbrs[self.dtype]},"
f"{'static' if self.static else 'dynamic'},{group_shape}"
@@ -126,15 +127,28 @@ kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True)
kStaticTokenScale = ScaleDesc(torch.float32, True, GroupShape.PER_TOKEN)
kFp8StaticTokenSym = QuantKey(FP8_DTYPE, kStaticTokenScale, symmetric=True)
+kStaticChannelScale = ScaleDesc(torch.float32, True, GroupShape.PER_CHANNEL)
+kFp8StaticChannelSym = QuantKey(FP8_DTYPE, kStaticChannelScale, symmetric=True)
+
kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True)
-kNvfp4GroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16))
-kNvfp4Quant = QuantKey(FP4_DTYPE, scale=kNvfp4GroupScale, scale2=kStaticTensorScale)
+kNvfp4DynamicGroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16))
+kNvfp4Dynamic = QuantKey(
+ FP4_DTYPE, scale=kNvfp4DynamicGroupScale, scale2=kStaticTensorScale
+)
+
+kNvfp4StaticGroupScale = ScaleDesc(FP8_DTYPE, True, GroupShape(1, 16))
+kNvfp4Static = QuantKey(
+ FP4_DTYPE, scale=kNvfp4StaticGroupScale, scale2=kStaticTensorScale
+)
kDynamic128Scale = ScaleDesc(torch.float32, False, GroupShape(1, 128))
kFp8Dynamic128Sym = QuantKey(FP8_DTYPE, kDynamic128Scale, symmetric=True)
+kStatic128BlockScale = ScaleDesc(torch.float32, True, GroupShape(128, 128))
+kFp8Static128BlockSym = QuantKey(FP8_DTYPE, kStatic128BlockScale, symmetric=True)
+
kDynamic64Scale = ScaleDesc(torch.float32, False, GroupShape(1, 64))
kFp8Dynamic64Sym = QuantKey(FP8_DTYPE, kDynamic64Scale, symmetric=True)
diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py
index f2f354604..8e49ccea5 100644
--- a/vllm/model_executor/models/qwen3_moe.py
+++ b/vllm/model_executor/models/qwen3_moe.py
@@ -43,7 +43,6 @@ from vllm.distributed import (
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
-from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
@@ -172,7 +171,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
- routing_method_type=RoutingMethodType.Renormalize,
)
self.gate = ReplicatedLinear(
diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py
index 0f73a7746..e244e6474 100644
--- a/vllm/model_executor/models/qwen3_next.py
+++ b/vllm/model_executor/models/qwen3_next.py
@@ -34,7 +34,6 @@ from vllm.model_executor.layers.fla.ops import (
fused_recurrent_gated_delta_rule,
)
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
-from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.layernorm import (
GemmaRMSNorm as Qwen3NextRMSNorm,
)
@@ -181,7 +180,6 @@ class Qwen3NextSparseMoeBlock(nn.Module):
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
- routing_method_type=RoutingMethodType.Renormalize,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py
index 212a725d4..21409de86 100644
--- a/vllm/model_executor/warmup/deep_gemm_warmup.py
+++ b/vllm/model_executor/warmup/deep_gemm_warmup.py
@@ -128,11 +128,15 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool:
"""
Return True if the input module/layer could be processed with DeepGEMM.
"""
+
+ # FIXME: this logic is brittle and incorrect - since we
+ # could use DeepGEMM with for than just Fp8LinearMethod
block_size = get_mk_alignment_for_contiguous_layout()[0]
if not (
isinstance(module, LinearBase)
and isinstance(module.quant_method, Fp8LinearMethod)
and module.quant_method.block_quant
+ and not module.quant_method.use_marlin
):
return False
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 7a0aff80e..a76d01c5b 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -29,7 +29,7 @@ from vllm.model_executor.layers.batch_invariant import (
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
- kNvfp4Quant,
+ kNvfp4Dynamic,
)
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
@@ -1184,7 +1184,7 @@ class FlashInferImpl(AttentionImpl):
return (
self.support_trtllm_attn
and self.kv_cache_dtype.startswith("fp8")
- and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)
+ and quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic)
)
# FlashInfer requires attention sinks to be float32