diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 53f638d50..b5b919c17 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -634,6 +634,46 @@ steps: - pip install helion - pytest -v -s kernels/helion/ + +- label: Kernels FP8 MoE Test (1 H100) + timeout_in_minutes: 90 + gpu: h100 + num_gpus: 1 + optional: true + commands: + - pytest -v -s kernels/moe/test_cutlass_moe.py + - pytest -v -s kernels/moe/test_flashinfer.py + - pytest -v -s kernels/moe/test_gpt_oss_triton_kernels.py + - pytest -v -s kernels/moe/test_modular_oai_triton_moe.py + - pytest -v -s kernels/moe/test_moe.py + # - pytest -v -s kernels/moe/test_block_fp8.py - failing on main + - pytest -v -s kernels/moe/test_block_int8.py + - pytest -v -s kernels/moe/test_triton_moe_no_act_mul.py + - pytest -v -s kernels/moe/test_triton_moe_ptpc_fp8.py + +- label: Kernels FP8 MoE Test (2 H100s) + timeout_in_minutes: 90 + gpu: h100 + num_gpus: 2 + optional: true + commands: + - pytest -v -s kernels/moe/test_deepep_deepgemm_moe.py + - pytest -v -s kernels/moe/test_deepep_moe.py + - pytest -v -s kernels/moe/test_pplx_cutlass_moe.py + # - pytest -v -s kernels/moe/test_pplx_moe.py - failing on main + +- label: Kernels Fp4 MoE Test (B200) + timeout_in_minutes: 60 + gpu: b200 + num_gpus: 1 + optional: true + commands: + - pytest -v -s kernels/moe/test_cutedsl_moe.py + - pytest -v -s kernels/moe/test_flashinfer_moe.py + - pytest -v -s kernels/moe/test_nvfp4_moe.py + - pytest -v -s kernels/moe/test_ocp_mx_moe.py + + - label: Model Executor Test # 23min timeout_in_minutes: 35 torch_nightly: true diff --git a/benchmarks/kernels/benchmark_cutlass_moe_fp8.py b/benchmarks/kernels/benchmark_cutlass_moe_fp8.py index 9c6edee7b..f1234d821 100644 --- a/benchmarks/kernels/benchmark_cutlass_moe_fp8.py +++ b/benchmarks/kernels/benchmark_cutlass_moe_fp8.py @@ -9,6 +9,7 @@ but use different quantization strategies and backends. 
import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from tests.kernels.moe.utils import make_dummy_moe_config from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 @@ -138,12 +139,13 @@ def bench_run( fn = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), CutlassExpertsFp8( - out_dtype=a.dtype, - e=num_experts, - n=n, - k=k, + moe_config=make_dummy_moe_config( + num_experts=num_experts, + hidden_dim=k, + intermediate_size_per_partition=n, + in_dtype=a.dtype, + ), quant_config=quant_config, - device=w1.device, ), ) diff --git a/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py b/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py index 10a3e3eab..4894d37c4 100644 --- a/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py +++ b/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py @@ -12,6 +12,7 @@ import torch import torch.utils.benchmark as benchmark import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from tests.kernels.moe.utils import make_dummy_moe_config from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe.config import ( @@ -198,8 +199,7 @@ def bench_run( kernel = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(defer_input_quant=True), CutlassExpertsFp4( - out_dtype=dtype, - max_experts_per_worker=e, + make_dummy_moe_config(), quant_config=quant_config, ), ) @@ -244,8 +244,7 @@ def bench_run( kernel = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(defer_input_quant=True), CutlassExpertsFp4( - out_dtype=dtype, - max_experts_per_worker=e, + make_dummy_moe_config(), quant_config=quant_config, ), ) diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index b30a12638..7b5daa62e 100644 --- 
a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -6,6 +6,7 @@ import torch.utils.benchmark as benchmark from benchmark_shapes import WEIGHT_SHAPES_MOE import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from tests.kernels.moe.utils import make_dummy_moe_config from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config @@ -134,13 +135,13 @@ def bench_run( fn = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), CutlassExpertsFp8( - out_dtype=a.dtype, - # NOTE(rob): w2 is shaped as [E, hidden, intermediate] - e=w2.shape[0], - n=w2.shape[2], - k=w2.shape[1], + moe_config=make_dummy_moe_config( + num_experts=w2.shape[0], + hidden_dim=w2.shape[1], + intermediate_size_per_partition=w2.shape[2], + in_dtype=a.dtype, + ), quant_config=quant_config, - device=w1.device, ), ) @@ -166,13 +167,13 @@ def bench_run( fn = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), CutlassExpertsFp8( - out_dtype=a.dtype, - # NOTE(rob): w2 is shaped as [E, hidden, intermediate] - e=w2.shape[0], - n=w2.shape[2], - k=w2.shape[1], + moe_config=make_dummy_moe_config( + num_experts=w2.shape[0], + hidden_dim=w2.shape[1], + intermediate_size_per_partition=w2.shape[2], + in_dtype=a.dtype, + ), quant_config=quant_config, - device=w1.device, ), ) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 27a8d9973..90ddee9be 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -16,10 +16,16 @@ import torch from ray.experimental.tqdm_ray import tqdm from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, FusedMoEQuantConfig, + RoutingMethodType, _get_config_dtype_str, ) from vllm.model_executor.layers.fused_moe.fused_moe import * +from 
vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( + TritonOrDeepGemmExperts, +) from vllm.platforms import current_platform from vllm.transformers_utils.config import get_config from vllm.triton_utils import triton @@ -194,10 +200,33 @@ def benchmark_config( block_shape=block_quant_shape, ) + deep_gemm_experts = mk.FusedMoEModularKernel( + prepare_finalize=MoEPrepareAndFinalizeNoEP(), + fused_experts=TritonOrDeepGemmExperts( + moe_config=FusedMoEConfig( + num_experts=num_experts, + experts_per_token=topk, + hidden_dim=hidden_size, + intermediate_size_per_partition=shard_intermediate_size, + num_local_experts=num_experts, + activation="silu", + parallel_config=FusedMoEParallelConfig.make_no_parallel(), + in_dtype=init_dtype, + routing_method=RoutingMethodType.TopK, + ), + quant_config=quant_config, + ), + ) + with override_config(config): topk_weights, topk_ids, token_expert_indices = fused_topk( x, input_gating, topk, renormalize=not use_deep_gemm ) + + if use_deep_gemm: + return deep_gemm_experts( + x, w1, w2, topk_weights, topk_ids, inplace=True + ) return fused_experts( x, w1, @@ -206,7 +235,6 @@ def benchmark_config( topk_ids, inplace=True, quant_config=quant_config, - allow_deep_gemm=use_deep_gemm, ) # JIT compilation & warmup diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index 448816c28..022c4f2e8 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -85,10 +85,10 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels |--------|-------------------|--------------|---------------|---------------------|-----------------------|---------|--------| | triton | standard | all1 | G,A,T | silu, gelu,
swigluoai,
silu_no_mul,
gelu_no_mul | Y | Y | [`fused_experts`][vllm.model_executor.layers.fused_moe.fused_moe.fused_experts],
[`TritonExperts`][vllm.model_executor.layers.fused_moe.fused_moe.TritonExperts] | | triton (batched) | batched | all1 | G,A,T | silu, gelu | 6 | Y | [`BatchedTritonExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedTritonExperts] | -| deep gemm | standard,
batched | fp8 | G(128),A,T | silu, gelu | 6 | Y | [`deep_gemm_moe_fp8`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.deep_gemm_moe_fp8],
[`DeepGemmExperts`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.DeepGemmExperts],
[`BatchedDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe.BatchedDeepGemmExperts] | +| deep gemm | standard,
batched | fp8 | G(128),A,T | silu, gelu | 6 | Y |
[`DeepGemmExperts`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.DeepGemmExperts],
[`BatchedDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe.BatchedDeepGemmExperts] | | cutlass_fp4 | standard,
batched | nvfp4 | A,T | silu | Y | Y | [`CutlassExpertsFp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp4] | | cutlass_fp8 | standard,
batched | fp8 | A,T | silu, gelu | Y | Y | [`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],
[`CutlasBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] | -| flashinfer | standard | nvfp4,
fp8 | T | 5 | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],
[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] | +| flashinfer | standard | nvfp4,
fp8 | T | 5 | N | Y | [`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] | | gpt oss triton | standard | N/A | N/A | 5 | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],
[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] | | marlin | standard,
batched | 3 / N/A | 3 / N/A | silu,
swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],
[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],
[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] | | trtllm | standard | mxfp4,
nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] | diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 1f1dafcca..c6f245669 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -43,7 +43,7 @@ from vllm.forward_context import get_forward_context, set_forward_context from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8StaticTensorSym, - kNvfp4Quant, + kNvfp4Dynamic, ) from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer @@ -215,7 +215,7 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel): class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel): """Test model for AttentionNvfp4QuantPattern fusion.""" - quant_key = kNvfp4Quant + quant_key = kNvfp4Dynamic def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -468,7 +468,7 @@ def test_attention_quant_pattern( # Note: for fp8, fully_replaced=False because query quant ops remain in graph. # Only output quant ops are fused into attention. 
- test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Quant) + test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Dynamic) # access the underlying `AttnFusionPass` on the `LazyInitPass` assert attn_pass.pass_.matched_count == sum(attn_fusion_supported) diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index b71487275..f8ae83ca3 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -44,7 +44,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, kFp8StaticTensorSym, - kNvfp4Quant, + kNvfp4Dynamic, ) from vllm.platforms import current_platform @@ -134,11 +134,11 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module): def ops_in_model_before(self): return [ SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul, - QUANT_OPS[kNvfp4Quant], + QUANT_OPS[kNvfp4Dynamic], ] def ops_in_model_after(self): - return [FUSED_OPS[kNvfp4Quant]] + return [FUSED_OPS[kNvfp4Dynamic]] class TestSiluMulGroupFp8QuantModel(torch.nn.Module): diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Llama-4-Scout-Fp8-ModelOpt-triton.yaml b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Llama-4-Scout-Fp8-ModelOpt-triton.yaml index 9d62c542a..9e13797bb 100644 --- a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Llama-4-Scout-Fp8-ModelOpt-triton.yaml +++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Llama-4-Scout-Fp8-ModelOpt-triton.yaml @@ -3,3 +3,6 @@ accuracy_threshold: 0.92 num_questions: 1319 num_fewshot: 5 server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel" +env: + VLLM_USE_FLASHINFER_MOE_FP8: "0" + VLLM_USE_DEEP_GEMM: "0" diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ll.yaml 
b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ll.yaml index 54e6ab7b3..5f10684d2 100644 --- a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ll.yaml +++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ll.yaml @@ -6,4 +6,3 @@ server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enab env: VLLM_USE_DEEP_GEMM: "1" VLLM_USE_DEEP_GEMM_MOE: "1" - VLLM_USE_DEEP_GEMM_E8M0: "0" diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-CT-Block-deepgemm-deepep-ll.yaml b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-CT-Block-deepgemm-deepep-ll.yaml index 1d4cbfe96..37e6039e9 100644 --- a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-CT-Block-deepgemm-deepep-ll.yaml +++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-CT-Block-deepgemm-deepep-ll.yaml @@ -6,4 +6,3 @@ server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enab env: VLLM_USE_DEEP_GEMM: "1" VLLM_USE_DEEP_GEMM_MOE: "1" - VLLM_USE_DEEP_GEMM_E8M0: "0" diff --git a/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-Fp8-ModelOpt-triton.yaml b/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-Fp8-ModelOpt-triton.yaml index 80e279edc..ae6bf6755 100644 --- a/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-Fp8-ModelOpt-triton.yaml +++ b/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-Fp8-ModelOpt-triton.yaml @@ -3,3 +3,5 @@ accuracy_threshold: 0.92 num_questions: 1319 num_fewshot: 5 server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2" +env: + VLLM_USE_FLASHINFER_MOE_FP8: "0" diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-fi-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-fi-cutlass.yaml index 080c8d338..74820cd28 100644 --- 
a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-fi-cutlass.yaml +++ b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-fi-cutlass.yaml @@ -4,7 +4,5 @@ num_questions: 1319 num_fewshot: 5 server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2" env: - VLLM_USE_DEEP_GEMM: "0" - VLLM_USE_DEEP_GEMM_MOE: "0" VLLM_USE_FLASHINFER_MOE_FP8: "1" VLLM_FLASHINFER_MOE_BACKEND: "throughput" diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-fi-trtllm.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-fi-trtllm.yaml index a656cc7c3..d745c9b5b 100644 --- a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-fi-trtllm.yaml +++ b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-fi-trtllm.yaml @@ -4,7 +4,5 @@ num_questions: 1319 num_fewshot: 5 server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2" env: - VLLM_USE_DEEP_GEMM: "0" - VLLM_USE_DEEP_GEMM_MOE: "0" VLLM_USE_FLASHINFER_MOE_FP8: "1" VLLM_FLASHINFER_MOE_BACKEND: "latency" diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-marlin.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-marlin.yaml index f2273bf2c..c3d86e6bf 100644 --- a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-marlin.yaml +++ b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-marlin.yaml @@ -4,6 +4,4 @@ num_questions: 1319 num_fewshot: 5 server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2" env: - VLLM_USE_DEEP_GEMM: "0" - VLLM_USE_DEEP_GEMM_MOE: "0" VLLM_TEST_FORCE_FP8_MARLIN: "1" diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-triton.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-triton.yaml index ed61e9b89..1b2d72160 100644 --- a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-triton.yaml +++ 
b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-triton.yaml @@ -4,5 +4,5 @@ num_questions: 1319 num_fewshot: 5 server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2" env: + VLLM_USE_FLASHINFER_MOE_FP8: "0" VLLM_USE_DEEP_GEMM: "0" - VLLM_USE_DEEP_GEMM_MOE: "0" diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-fi-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-fi-cutlass.yaml index db18dd01b..48ab58c46 100644 --- a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-fi-cutlass.yaml +++ b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-fi-cutlass.yaml @@ -4,7 +4,5 @@ num_questions: 1319 num_fewshot: 5 server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2" env: - VLLM_USE_DEEP_GEMM: "0" - VLLM_USE_DEEP_GEMM_MOE: "0" VLLM_USE_FLASHINFER_MOE_FP8: "1" VLLM_FLASHINFER_MOE_BACKEND: "throughput" diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-marlin.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-marlin.yaml index 3d82d2e22..46eee7421 100644 --- a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-marlin.yaml +++ b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-marlin.yaml @@ -4,6 +4,4 @@ num_questions: 1319 num_fewshot: 5 server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2" env: - VLLM_USE_DEEP_GEMM: "0" - VLLM_USE_DEEP_GEMM_MOE: "0" VLLM_TEST_FORCE_FP8_MARLIN: "1" diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-triton.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-triton.yaml index 5621217de..3e30d4d15 100644 --- a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-triton.yaml +++ b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-triton.yaml @@ -4,5 +4,5 @@ num_questions: 1319 num_fewshot: 5 server_args: 
"--enforce-eager --max-model-len 8192 --tensor-parallel-size 2" env: + VLLM_USE_FLASHINFER_MOE_FP8: "0" VLLM_USE_DEEP_GEMM: "0" - VLLM_USE_DEEP_GEMM_MOE: "0" diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-CT-vllm-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-CT-vllm-cutlass.yaml index b1ccadedd..0d7884928 100644 --- a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-CT-vllm-cutlass.yaml +++ b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-CT-vllm-cutlass.yaml @@ -3,3 +3,5 @@ accuracy_threshold: 0.88 num_questions: 1319 num_fewshot: 5 server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2" +env: + VLLM_USE_FLASHINFER_MOE_FP4: "0" diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-ModelOpt-vllm-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-ModelOpt-vllm-cutlass.yaml index 49a1589fc..a340b6fda 100644 --- a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-ModelOpt-vllm-cutlass.yaml +++ b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-ModelOpt-vllm-cutlass.yaml @@ -3,3 +3,5 @@ accuracy_threshold: 0.88 num_questions: 1319 num_fewshot: 5 server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2" +env: + VLLM_USE_FLASHINFER_MOE_FP4: "0" diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py index ac6d54b71..350e1be2f 100644 --- a/tests/kernels/moe/modular_kernel_tools/common.py +++ b/tests/kernels/moe/modular_kernel_tools/common.py @@ -26,6 +26,7 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, + RoutingMethodType, ) from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx @@ -574,10 +575,14 @@ def make_modular_kernel( num_experts=config.E, experts_per_token=config.topk, hidden_dim=config.K, + intermediate_size_per_partition=config.N, 
num_local_experts=config.num_local_experts, moe_parallel_config=moe_parallel_config, in_dtype=config.dtype, max_num_tokens=next_power_of_2(config.M), + activation="silu", + device=vllm_config.device_config.device, + routing_method=RoutingMethodType.DeepSeekV3, ) # make modular kernel diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index 99b168dc7..04a654e89 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -425,84 +425,26 @@ def make_fused_experts( num_dispatchers: int, N: int, ) -> mk.FusedMoEPermuteExpertsUnpermute: - batch_kwargs = { - "max_num_tokens": moe.max_num_tokens, - "num_dispatchers": num_dispatchers, - } - quant_kwargs = { - "quant_config": quant_config, - } - deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()} + if ( + fused_experts_type.activation_format() + == mk.FusedMoEActivationFormat.BatchedExperts + ): + kwargs = { + "moe_config": moe, + "quant_config": quant_config, + "max_num_tokens": moe.max_num_tokens, + "num_dispatchers": num_dispatchers, + } + else: + kwargs = { + "moe_config": moe, + "quant_config": quant_config, + } torch.set_printoptions(threshold=0, edgeitems=0, linewidth=10000) - if fused_experts_type == BatchedDeepGemmExperts: - kwargs = batch_kwargs | quant_kwargs - print(f"Making BatchedDeepGemmExperts {kwargs} ...") - experts = BatchedDeepGemmExperts(**kwargs) - elif fused_experts_type == BatchedTritonExperts: - kwargs = batch_kwargs | quant_kwargs - print(f"Making BatchedTritonExperts {kwargs} ...") - experts = BatchedTritonExperts(**kwargs) - elif fused_experts_type == DeepGemmExperts: - print(f"Making DeepGemmExperts {quant_config} ...") - experts = DeepGemmExperts(quant_config) - elif fused_experts_type == TritonExperts: - kwargs = quant_kwargs - print(f"Making TritonExperts {kwargs} ...") - experts = TritonExperts(**kwargs) - elif fused_experts_type == 
TritonOrDeepGemmExperts: - kwargs = quant_kwargs | deepgemm_kwargs - print(f"Making TritonOrDeepGemmExperts {kwargs} ...") - experts = TritonOrDeepGemmExperts(**kwargs) - elif fused_experts_type == NaiveBatchedExperts: - kwargs = batch_kwargs | quant_kwargs - print(f"Making NaiveBatchedExperts {kwargs} ...") - experts = NaiveBatchedExperts(**kwargs) - elif fused_experts_type == CutlassExpertsFp8: - strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim) - kwargs = { - "out_dtype": moe.in_dtype, - "ab_strides1": strides[0], - "ab_strides2": strides[1], - "c_strides1": strides[2], - "c_strides2": strides[3], - } | quant_kwargs - print(f"Making CutlassExpertsFp8 {kwargs} ...") - experts = CutlassExpertsFp8(**kwargs) - elif fused_experts_type == CutlassBatchedExpertsFp8: - strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim) - kwargs = { - "max_experts_per_worker": moe.num_local_experts, - "num_dispatchers": num_dispatchers, - "out_dtype": moe.in_dtype, - "ab_strides1": strides[0], - "ab_strides2": strides[1], - "c_strides1": strides[2], - "c_strides2": strides[3], - } | quant_kwargs - print(f"Making CutlassBatchedExpertsFp8 {kwargs} ...") - experts = CutlassBatchedExpertsFp8(**kwargs) - elif fused_experts_type == CutlassExpertsFp4: - kwargs = { - "max_experts_per_worker": moe.num_local_experts, - "num_dispatchers": num_dispatchers, - "out_dtype": moe.in_dtype, - } | quant_kwargs - print(f"Making CutlassExpertsFp4 {kwargs} ...") - experts = CutlassExpertsFp4(**kwargs) - elif fused_experts_type == FlashInferExperts: - kwargs = { - "out_dtype": moe.in_dtype, - "ep_rank": moe.ep_rank, - "ep_size": moe.ep_size, - "tp_rank": moe.tp_rank, - "tp_size": moe.tp_size, - } | quant_kwargs - print(f"Making FlashInferExperts {kwargs} ...") - experts = FlashInferExperts(**kwargs) - else: - raise RuntimeError(f"Unknown fused experts type: {fused_experts_type}") + print(f"Making {fused_experts_type.__name__} {kwargs} ...") + experts =
fused_experts_type(**kwargs) torch.set_printoptions(threshold=1000, edgeitems=5, linewidth=80) diff --git a/tests/kernels/moe/test_batched_deepgemm.py b/tests/kernels/moe/test_batched_deepgemm.py index 0ba3d8d4c..081a5fd0b 100644 --- a/tests/kernels/moe/test_batched_deepgemm.py +++ b/tests/kernels/moe/test_batched_deepgemm.py @@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularK from vllm.utils.deep_gemm import calc_diff, is_deep_gemm_supported from .test_deepgemm import make_block_quant_fp8_weights +from .utils import make_dummy_moe_config BLOCK_SIZE = [128, 128] @@ -71,6 +72,7 @@ def test_batched_deepgemm_vs_triton( max_num_tokens=max_num_tokens, num_dispatchers=1, quant_config=quant_config, + moe_config=make_dummy_moe_config(), ) mk_triton = FusedMoEModularKernel(prep_finalize, triton_experts) @@ -89,6 +91,7 @@ def test_batched_deepgemm_vs_triton( max_num_tokens=max_num_tokens, num_dispatchers=1, quant_config=quant_config, + moe_config=make_dummy_moe_config(), ) mk_deepgemm = FusedMoEModularKernel(prep_finalize, deepgemm_experts) diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py index 63fbbfeec..508df9e32 100644 --- a/tests/kernels/moe/test_block_fp8.py +++ b/tests/kernels/moe/test_block_fp8.py @@ -4,7 +4,12 @@ import pytest import torch -from tests.kernels.moe.utils import make_test_quant_config, make_test_weights +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from tests.kernels.moe.utils import ( + make_dummy_moe_config, + make_test_quant_config, + make_test_weights, +) from tests.kernels.quant_utils import ( native_per_token_group_quant_fp8, native_w8a8_block_matmul, @@ -15,13 +20,21 @@ from vllm.model_executor.layers.fused_moe import ( fused_experts, fused_topk, ) +from vllm.model_executor.layers.fused_moe.config import ( + fp8_w8a8_moe_quant_config, +) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( _valid_deep_gemm_shape, - 
deep_gemm_moe_fp8, ) from vllm.model_executor.layers.fused_moe.fused_moe import ( modular_triton_fused_moe, ) +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP, +) +from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( + TritonOrDeepGemmExperts, +) from vllm.platforms import current_platform from vllm.utils.deep_gemm import ( get_mk_alignment_for_contiguous_layout, @@ -161,7 +174,7 @@ def test_w8a8_block_fp8_fused_moe( block_shape=block_size, ) - m_fused_moe = modular_triton_fused_moe(quant_config) + m_fused_moe = modular_triton_fused_moe(make_dummy_moe_config(), quant_config) topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False) @@ -236,6 +249,29 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch) topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False) + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + ) + + deep_gemm_experts = mk.FusedMoEModularKernel( + prepare_finalize=MoEPrepareAndFinalizeNoEP(), + fused_experts=TritonOrDeepGemmExperts( + moe_config=make_dummy_moe_config(), + quant_config=quant_config, + ), + ) + + def deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids): + return deep_gemm_experts( + hidden_states=a, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + ) + # Set the context to avoid lots of warning spam. 
with set_current_vllm_config(vllm_config): ref_out = torch_w8a8_block_fp8_moe( diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index d5b1c2cf0..3a5a66a38 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -8,6 +8,7 @@ import pytest import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from tests.kernels.moe.utils import make_dummy_moe_config from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk @@ -193,16 +194,18 @@ def run_with_expert_maps( out_tensor = torch.zeros_like(cutlass_moe_kwargs["hidden_states"]) for kwargs, new_quant_config in slice_experts(): + w2 = kwargs["w2"] + a = kwargs["hidden_states"] kernel = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), CutlassExpertsFp8( - out_dtype=kwargs["hidden_states"].dtype, - # NOTE(rob): w2 is shaped as [E, hidden, intermediate] - e=kwargs["w2"].shape[0], # type: ignore[union-attr] - n=kwargs["w2"].shape[2], # type: ignore[union-attr] - k=kwargs["w2"].shape[1], # type: ignore[union-attr] + moe_config=make_dummy_moe_config( + num_experts=w2.shape[0], + hidden_dim=w2.shape[1], + intermediate_size_per_partition=w2.shape[2], + in_dtype=a.dtype, + ), quant_config=new_quant_config, - device="cuda", ), ) out_tensor = out_tensor + kernel(**kwargs) @@ -249,19 +252,19 @@ def run_8_bit( "topk_ids": topk_ids, } - num_experts = moe_tensors.w1.size(0) + num_experts = moe_tensors.w1.size(0) # type: ignore[attr-defined] with_ep = num_local_experts is not None or num_local_experts == num_experts if not with_ep: kernel = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), CutlassExpertsFp8( - out_dtype=moe_tensors.a.dtype, - # NOTE(rob): w2 is shaped as [E, hidden, intermediate] - e=moe_tensors.w2_q.shape[0], # type: ignore[union-attr] - n=moe_tensors.w2_q.shape[2], # type: 
ignore[union-attr] - k=moe_tensors.w2_q.shape[1], # type: ignore[union-attr] + moe_config=make_dummy_moe_config( + num_experts=moe_tensors.w2_q.shape[0], # type: ignore[union-attr] + hidden_dim=moe_tensors.w2_q.shape[1], # type: ignore[union-attr] + intermediate_size_per_partition=moe_tensors.w2_q.shape[2], # type: ignore[union-attr] + in_dtype=moe_tensors.a.dtype, + ), quant_config=quant_config, - device="cuda", ), ) return kernel(**kwargs) diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 8987b688a..1bf5ced2e 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -33,7 +33,7 @@ from vllm.v1.worker.workspace import init_workspace_manager from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch -from .utils import make_test_weights +from .utils import make_dummy_moe_config, make_test_weights if has_deep_ep(): from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( @@ -192,6 +192,7 @@ def make_ll_modular_kernel( max_num_tokens=max_tokens_per_rank, num_dispatchers=pgi.world_size // dp_size, quant_config=quant_config, + moe_config=make_dummy_moe_config(), ) mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts) return mk @@ -219,7 +220,10 @@ def make_ht_modular_kernel( block_shape=test_config.block_size, ) - fused_experts = DeepGemmExperts(quant_config) + fused_experts = DeepGemmExperts( + moe_config=make_dummy_moe_config(), + quant_config=quant_config, + ) mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts) return mk @@ -349,9 +353,6 @@ def triton_impl( topk_ids=topk_ids, inplace=False, quant_config=quant_config, - # Make sure this is set to False so we - # don't end up comparing the same implementation. 
- allow_deep_gemm=False, ) diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py index e57e0d720..f740f5bf9 100644 --- a/tests/kernels/moe/test_deepep_moe.py +++ b/tests/kernels/moe/test_deepep_moe.py @@ -10,11 +10,14 @@ import pytest import torch.distributed from torch.distributed import ProcessGroup +from tests.kernels.moe.utils import make_dummy_moe_config from vllm import _custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import TritonExperts -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.quantization.utils.fp8_utils import ( @@ -160,15 +163,21 @@ def make_modular_kernel( num_dispatchers = pgi.world_size // dp_size + moe_config = make_dummy_moe_config() + if low_latency_mode: assert not quant_config.per_act_token_quant, "not supported in ll mode" fused_experts = BatchedTritonExperts( max_num_tokens=MAX_TOKENS_PER_RANK, num_dispatchers=num_dispatchers, + moe_config=moe_config, quant_config=quant_config, ) else: - fused_experts = TritonExperts(quant_config=quant_config) + fused_experts = TritonExperts( + moe_config=moe_config, + quant_config=quant_config, + ) mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts) return mk diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py index 442b561f8..729b54753 100644 --- a/tests/kernels/moe/test_deepgemm.py +++ b/tests/kernels/moe/test_deepgemm.py @@ -11,10 +11,19 @@ import math import pytest import torch -from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config - # 
vLLM fused-expert reference (Triton fallback + DeepGEMM option) +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from tests.kernels.moe.utils import make_dummy_moe_config +from vllm.model_executor.layers.fused_moe.config import ( + fp8_w8a8_moe_quant_config, +) from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP, +) +from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( + TritonOrDeepGemmExperts, +) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, ) @@ -100,6 +109,14 @@ def run_single_case(m, n, k, topk, num_experts, block_size): block_shape=block_size, ) + deep_gemm_experts = mk.FusedMoEModularKernel( + prepare_finalize=MoEPrepareAndFinalizeNoEP(), + fused_experts=TritonOrDeepGemmExperts( + moe_config=make_dummy_moe_config(), + quant_config=quant_config, + ), + ) + # triton reference out_triton = fused_experts( hidden_states=tokens_bf16, @@ -109,19 +126,16 @@ def run_single_case(m, n, k, topk, num_experts, block_size): topk_ids=topk_ids, inplace=False, quant_config=quant_config, - allow_deep_gemm=False, ) # DeepGemm - out_deepgemm = fused_experts( + out_deepgemm = deep_gemm_experts( hidden_states=tokens_bf16, w1=w1, w2=w2, topk_weights=topk_weights, topk_ids=topk_ids, inplace=False, - quant_config=quant_config, - allow_deep_gemm=True, ) diff = calc_diff(out_deepgemm, out_triton) assert diff < 0.001, f"Diff exceeded 1%: {diff}" @@ -147,20 +161,19 @@ def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch, workspace_i with monkeypatch.context() as mp: mp.setenv("VLLM_USE_DEEP_GEMM", "1") - _fused_moe_mod = importlib.import_module( - "vllm.model_executor.layers.fused_moe.fused_moe" - ) + _DeepGemmExperts = importlib.import_module( + "vllm.model_executor.layers.fused_moe.deep_gemm_moe" + ).DeepGemmExperts call_counter = {"cnt": 0} - orig_fn = 
_fused_moe_mod.deep_gemm_moe_fp8 + orig_fn = _DeepGemmExperts.apply - def _spy_deep_gemm_moe_fp8(*args, **kwargs): + def _spy_apply(*args, **kwargs): call_counter["cnt"] += 1 return orig_fn(*args, **kwargs) - monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8", _spy_deep_gemm_moe_fp8) - + monkeypatch.setattr(_DeepGemmExperts, "apply", _spy_apply) if topk > num_experts: pytest.skip(f"topk={topk} > num_experts={num_experts}") diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py index bb2f6b873..e5752e404 100644 --- a/tests/kernels/moe/test_flashinfer.py +++ b/tests/kernels/moe/test_flashinfer.py @@ -8,7 +8,10 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, FusedMoEQuantConfig, + RoutingMethodType, fp8_w8a8_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( @@ -116,18 +119,7 @@ class TestData: layer.w13_weight_scale = w13_weight_scale layer.w2_weight_scale = w2_weight_scale # Setup dummy config. 
- layer.moe_parallel_config = mk.FusedMoEParallelConfig( - tp_size=1, - pcp_size=1, - dp_size=1, - ep_size=1, - tp_rank=0, - pcp_rank=0, - dp_rank=0, - ep_rank=0, - use_ep=False, - all2all_backend="naive", - ) + layer.moe_parallel_config = mk.FusedMoEParallelConfig.make_no_parallel() # flashinfer expects swapped rows for w13 layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) @@ -238,6 +230,8 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( ): set_random_seed(7) monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") + assert activation in ["silu", "relu2_no_mul"] + is_act_and_mul = activation == "silu_and_mul" with set_current_vllm_config(vllm_config): td = TestData.make_moe_tensors_8bit( m, k, n, e, is_trtllm=False, activation=activation @@ -285,19 +279,30 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( td.layer.get_fused_moe_quant_config = get_fused_moe_quant_config td.layer.quant_method = td.layer + moe_config = FusedMoEConfig( + num_experts=e, + experts_per_token=topk, + hidden_dim=k, + intermediate_size_per_partition=n, + num_local_experts=e, + activation=activation, + device="cuda", + moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(), + in_dtype=torch.bfloat16, + is_act_and_mul=is_act_and_mul, + routing_method=RoutingMethodType.TopK, + ) + kernel = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP( - defer_input_quant=quant_config.is_block_quantized + defer_input_quant=FlashInferExperts.expects_unquantized_inputs( + moe_config=moe_config, + quant_config=quant_config, + ) ), FlashInferExperts( - out_dtype=td.layer.orig_dtype, + moe_config=moe_config, quant_config=quant_config, - ep_rank=td.layer.moe_parallel_config.ep_rank, - ep_size=td.layer.moe_parallel_config.ep_size, - tp_rank=td.layer.moe_parallel_config.tp_rank, - tp_size=td.layer.moe_parallel_config.tp_size, - use_dp=False, - use_deepseek_fp8_block_scale=False, ), ) diff --git a/tests/kernels/moe/test_flashinfer_moe.py b/tests/kernels/moe/test_flashinfer_moe.py index 
ee23e7e65..66ceb0c64 100644 --- a/tests/kernels/moe/test_flashinfer_moe.py +++ b/tests/kernels/moe/test_flashinfer_moe.py @@ -13,14 +13,19 @@ from tests.kernels.utils import torch_moe from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + RoutingMethodType, +) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe, ) -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( - create_flashinfer_prepare_finalize, -) from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP, +) from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.torch_utils import set_random_seed @@ -86,9 +91,28 @@ def test_flashinfer_fp4_moe_no_graph( assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q) + moe_config = FusedMoEConfig( + num_experts=e, + experts_per_token=topk, + hidden_dim=k, + intermediate_size_per_partition=n, + num_local_experts=e, + activation=activation, + device="cuda", + moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(), + in_dtype=dtype, + is_act_and_mul=is_gated_act, + routing_method=RoutingMethodType.TopK, + ) + flashinfer_experts = FusedMoEModularKernel( - create_flashinfer_prepare_finalize(use_dp=False, use_nvfp4=True), - FlashInferExperts(out_dtype=dtype, quant_config=quant_config), + MoEPrepareAndFinalizeNoEP( + defer_input_quant=FlashInferExperts.expects_unquantized_inputs( + moe_config=moe_config, + quant_config=quant_config, + ) + ), + FlashInferExperts(moe_config=moe_config, quant_config=quant_config), ) fi_activation = 
{"silu_and_mul": "silu", "relu2": "relu2_no_mul"}[activation] diff --git a/tests/kernels/moe/test_modular_oai_triton_moe.py b/tests/kernels/moe/test_modular_oai_triton_moe.py index 8733ba4d8..38022e0e6 100644 --- a/tests/kernels/moe/test_modular_oai_triton_moe.py +++ b/tests/kernels/moe/test_modular_oai_triton_moe.py @@ -36,6 +36,8 @@ from vllm.model_executor.layers.utils import shuffle_weight from vllm.platforms import current_platform from vllm.utils.torch_utils import set_random_seed +from .utils import make_dummy_moe_config + MNK = [ (1, 512, 384), (1, 2880, 2880), @@ -174,9 +176,9 @@ def oai_triton_moe_impl( ) if unfused: - fused_experts = UnfusedOAITritonExperts(quant_config) + fused_experts = UnfusedOAITritonExperts(make_dummy_moe_config(), quant_config) else: - fused_experts = OAITritonExperts(quant_config) + fused_experts = OAITritonExperts(make_dummy_moe_config(), quant_config) mk = FusedMoEModularKernel(MoEPrepareAndFinalizeNoEP(), fused_experts) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 849f9e7f1..e34c78074 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -18,7 +18,7 @@ from transformers import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock import vllm.model_executor.layers.fused_moe # noqa -from tests.kernels.moe.utils import fused_moe +from tests.kernels.moe.utils import fused_moe, make_dummy_moe_config from tests.kernels.utils import opcheck, stack_and_dev, torch_experts, torch_moe from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig, set_current_vllm_config @@ -332,7 +332,7 @@ def test_fused_moe( # quant_config = FUSED_MOE_UNQUANTIZED_CONFIG - m_fused_moe_fn = modular_triton_fused_moe(quant_config) + m_fused_moe_fn = modular_triton_fused_moe(make_dummy_moe_config(), quant_config) def m_fused_moe( a: torch.Tensor, @@ -437,7 +437,7 @@ def test_naive_block_assignment_moe( # quant_config = 
FUSED_MOE_UNQUANTIZED_CONFIG - m_fused_moe_fn = modular_triton_fused_moe(quant_config) + m_fused_moe_fn = modular_triton_fused_moe(make_dummy_moe_config(), quant_config) def m_fused_moe( a: torch.Tensor, diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py index 4dd4223db..0011149e3 100644 --- a/tests/kernels/moe/test_nvfp4_moe.py +++ b/tests/kernels/moe/test_nvfp4_moe.py @@ -4,7 +4,7 @@ import pytest import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from tests.kernels.moe.utils import make_test_weights +from tests.kernels.moe.utils import make_dummy_moe_config, make_test_weights from tests.kernels.quantization.nvfp4_utils import ( FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX, @@ -92,8 +92,7 @@ def test_cutlass_fp4_moe_no_graph( kernel = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(defer_input_quant=True), CutlassExpertsFp4( - out_dtype=dtype, - max_experts_per_worker=e, + moe_config=make_dummy_moe_config(), quant_config=quant_config, ), ) diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py index a2ab94e37..ef37c1c74 100644 --- a/tests/kernels/moe/test_pplx_cutlass_moe.py +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -9,12 +9,18 @@ from tests.kernels.utils import torch_experts from vllm import _custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe import fused_topk -from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + RoutingMethodType, + fp8_w8a8_moe_quant_config, +) from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassBatchedExpertsFp8 from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.platforms import current_platform from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import 
set_random_seed +from vllm.v1.worker.workspace import init_workspace_manager from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch @@ -79,6 +85,8 @@ def pplx_cutlass_moe( PplxPrepareAndFinalize, ) + init_workspace_manager(torch.cuda.current_device()) + assert torch.cuda.current_device() == pgi.local_rank num_tokens, hidden_dim = a.shape @@ -132,28 +140,23 @@ def pplx_cutlass_moe( num_dispatchers=num_dispatchers, ) - ab_strides1 = torch.full( - (num_local_experts,), hidden_dim, device="cuda", dtype=torch.int64 - ) - ab_strides2 = torch.full( - (num_local_experts,), intermediate_dim, device="cuda", dtype=torch.int64 - ) - c_strides1 = torch.full( - (num_local_experts,), 2 * intermediate_dim, device="cuda", dtype=torch.int64 - ) - c_strides2 = torch.full( - (num_local_experts,), hidden_dim, device="cuda", dtype=torch.int64 - ) + def make_moe_config() -> FusedMoEConfig: + return FusedMoEConfig( + num_experts=num_experts, + experts_per_token=topk, + hidden_dim=hidden_dim, + intermediate_size_per_partition=intermediate_dim, + num_local_experts=num_local_experts, + moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(), + activation="silu", + in_dtype=torch.bfloat16, + device="cuda", + routing_method=RoutingMethodType.Llama4, + ) experts = CutlassBatchedExpertsFp8( - num_local_experts, - num_dispatchers, - out_dtype, - ab_strides1, - ab_strides2, - c_strides1, - c_strides2, - fp8_w8a8_moe_quant_config( + moe_config=make_moe_config(), + quant_config=fp8_w8a8_moe_quant_config( per_act_token_quant=per_act_token, per_out_ch_quant=per_out_ch, w1_scale=chunk_by_rank(w1_scale, rank, world_size), @@ -162,6 +165,8 @@ def pplx_cutlass_moe( if per_act_token else a1_scale[rank], ), + max_num_tokens=max_num_tokens, + num_dispatchers=num_dispatchers, ) fused_cutlass_experts = FusedMoEModularKernel( diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index c08a54f0e..08519087e 100644 --- 
a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -29,6 +29,7 @@ except ImportError: from tests.kernels.moe.modular_kernel_tools.parallel_utils import _set_vllm_config from tests.kernels.moe.utils import ( + make_dummy_moe_config, make_shared_experts, make_test_weights, naive_batched_moe, @@ -584,6 +585,7 @@ def pplx_moe( max_num_tokens=max_num_tokens, num_dispatchers=prepare_finalize.num_dispatchers(), quant_config=quant_config, + moe_config=make_dummy_moe_config(), ) fused_experts = FusedMoEModularKernel( diff --git a/tests/kernels/moe/test_routing.py b/tests/kernels/moe/test_routing.py index 93aa6aa5c..f623f943f 100644 --- a/tests/kernels/moe/test_routing.py +++ b/tests/kernels/moe/test_routing.py @@ -6,7 +6,6 @@ import pytest import torch from vllm.distributed.eplb.eplb_state import EplbLayerState -from vllm.model_executor.layers.fused_moe.config import RoutingMethodType from vllm.model_executor.layers.fused_moe.router.router_factory import ( create_fused_moe_router, ) @@ -385,17 +384,11 @@ def test_grouped_topk( global_num_experts, ) - routing_method_type = None - if scoring_func == "llama4": - routing_method_type = RoutingMethodType.Llama4 - scoring_func = "sigmoid" - router = create_fused_moe_router( use_grouped_topk=True, num_expert_group=num_expert_group, topk_group=topk_group, scoring_func=scoring_func, - routing_method_type=routing_method_type, e_score_correction_bias=e_score_correction_bias, routed_scaling_factor=routed_scaling_factor, top_k=top_k, diff --git a/tests/kernels/moe/test_triton_moe_no_act_mul.py b/tests/kernels/moe/test_triton_moe_no_act_mul.py index 12d5180f9..ab15f898b 100644 --- a/tests/kernels/moe/test_triton_moe_no_act_mul.py +++ b/tests/kernels/moe/test_triton_moe_no_act_mul.py @@ -10,6 +10,7 @@ equals N (not N // 2 like gated activations). 
import pytest import torch +from tests.kernels.moe.utils import make_dummy_moe_config from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, ) @@ -78,7 +79,10 @@ def test_triton_experts_no_mul_activation( m, n, k, NUM_EXPERTS, topk ) - experts = TritonExperts(FUSED_MOE_UNQUANTIZED_CONFIG) + experts = TritonExperts( + moe_config=make_dummy_moe_config(), + quant_config=FUSED_MOE_UNQUANTIZED_CONFIG, + ) ws1_shape, ws2_shape, out_shape = experts.workspace_shapes( M=m, @@ -151,7 +155,10 @@ def test_workspace_shapes_no_mul_vs_gated(): M, N, K, topk = 64, 256, 128, 2 - experts = TritonExperts(FUSED_MOE_UNQUANTIZED_CONFIG) + experts = TritonExperts( + moe_config=make_dummy_moe_config(), + quant_config=FUSED_MOE_UNQUANTIZED_CONFIG, + ) ws1_no_mul, _, out_no_mul = experts.workspace_shapes( M, N, K, topk, 8, 8, None, SILU_NO_MUL @@ -187,7 +194,10 @@ def test_adjust_n_for_activation(): """Test the adjust_N_for_activation method.""" from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts - experts = TritonExperts(FUSED_MOE_UNQUANTIZED_CONFIG) + experts = TritonExperts( + moe_config=make_dummy_moe_config(), + quant_config=FUSED_MOE_UNQUANTIZED_CONFIG, + ) N = 256 diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py index f0c8c8033..4883085cb 100644 --- a/tests/kernels/moe/utils.py +++ b/tests/kernels/moe/utils.py @@ -8,7 +8,12 @@ from tests.kernels.quant_utils import per_block_cast_to_int8 from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantConfig, + RoutingMethodType, +) from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( 
BatchedPrepareAndFinalize, BatchedTritonExperts, @@ -20,6 +25,34 @@ from vllm.utils.deep_gemm import per_block_cast_to_fp8 from vllm.utils.math_utils import round_up +def make_dummy_moe_config( + num_experts: int = 1, + experts_per_token: int = 1, + hidden_dim: int = 1, + intermediate_size_per_partition: int = 1, + in_dtype: torch.dtype = torch.bfloat16, +) -> FusedMoEConfig: + """ + This is a dummy config for the mk constructor interface + as most kernels like DeepGEMM, CUTLASSFp4, Triton, MARLIN + do not actually use this config. + + CUTLASSFp8 needs to set some params for workshapes. + """ + return FusedMoEConfig( + num_experts=num_experts, + experts_per_token=experts_per_token, + hidden_dim=hidden_dim, + intermediate_size_per_partition=intermediate_size_per_partition, + num_local_experts=num_experts, + moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(), + activation="silu", + in_dtype=in_dtype, + device="cuda", + routing_method=RoutingMethodType.TopK, + ) + + def triton_moe( a: torch.Tensor, w1: torch.Tensor, @@ -81,6 +114,7 @@ def batched_moe( max_num_tokens=max_num_tokens, num_dispatchers=1, quant_config=quant_config, + moe_config=make_dummy_moe_config(), ), ) @@ -121,6 +155,7 @@ def naive_batched_moe( max_num_tokens=max_num_tokens, num_dispatchers=1, quant_config=quant_config, + moe_config=make_dummy_moe_config(), ), ) diff --git a/tools/vllm-rocm/pin_rocm_dependencies.py b/tools/vllm-rocm/pin_rocm_dependencies.py index ba11fd934..b9387069d 100644 --- a/tools/vllm-rocm/pin_rocm_dependencies.py +++ b/tools/vllm-rocm/pin_rocm_dependencies.py @@ -11,10 +11,11 @@ This ensures that 'pip install vllm' automatically installs the correct custom w instead of allowing pip to download different versions from PyPI. 
""" -import re import sys from pathlib import Path +import regex as re + def extract_version_from_wheel(wheel_name: str) -> str: """ diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py index f0ce5b3db..8530c0dad 100644 --- a/vllm/compilation/activation_quant_fusion.py +++ b/vllm/compilation/activation_quant_fusion.py @@ -18,7 +18,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8StaticTensorSym, - kNvfp4Quant, + kNvfp4Dynamic, ) from vllm.platforms import current_platform @@ -41,7 +41,7 @@ silu_and_mul_nvfp4_quant_supported = current_platform.is_cuda() and hasattr( torch.ops._C, "silu_and_mul_nvfp4_quant" ) if silu_and_mul_nvfp4_quant_supported: - FUSED_OPS[kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501 + FUSED_OPS[kNvfp4Dynamic] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501 class ActivationQuantPattern(ABC): @@ -129,7 +129,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern): """ def __init__(self) -> None: - super().__init__(kNvfp4Quant) + super().__init__(kNvfp4Dynamic) def get_inputs(self) -> list[torch.Tensor]: result = self.empty_quant(5, 32) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index e3c6c2f20..667828cc6 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8DynamicTensorSym, kFp8DynamicTokenSym, kFp8StaticTensorSym, - kNvfp4Quant, + kNvfp4Dynamic, kStaticTensorScale, ) from vllm.platforms import current_platform @@ -63,7 +63,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = { kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 } if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): - QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default + QUANT_OPS[kNvfp4Dynamic] = 
torch.ops._C.scaled_fp4_quant.default if current_platform.is_cuda(): QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py index 57448aa0b..69dc2e3a6 100644 --- a/vllm/compilation/fusion_attn.py +++ b/vllm/compilation/fusion_attn.py @@ -16,7 +16,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, - kNvfp4Quant, + kNvfp4Dynamic, kStaticTensorScale, ) from vllm.platforms import current_platform @@ -217,7 +217,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): """ def __init__(self, layer: Attention, dtype: torch.dtype) -> None: - super().__init__(layer, kNvfp4Quant, dtype) + super().__init__(layer, kNvfp4Dynamic, dtype) def _register(self, pm_pass: PatternMatcherPass) -> None: def pattern( diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index eda12180d..7bb98db5e 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -21,7 +21,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8DynamicTensorSym, kFp8DynamicTokenSym, kFp8StaticTensorSym, - kNvfp4Quant, + kNvfp4Dynamic, ) from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.platforms import current_platform @@ -38,7 +38,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = { } if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): - QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 + QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 if current_platform.is_cuda(): QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 diff --git 
a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index e598ec3ac..ac37cff93 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -7,11 +7,20 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.forward_context import get_forward_context, is_forward_context_available from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate, ) from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8Dynamic128Sym, + kFp8Static128BlockSym, +) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils.deep_gemm import ( @@ -19,6 +28,7 @@ from vllm.utils.deep_gemm import ( fp8_m_grouped_gemm_nt_masked, get_mk_alignment_for_contiguous_layout, is_deep_gemm_e8m0_used, + is_deep_gemm_supported, ) from vllm.utils.math_utils import cdiv, round_up @@ -253,29 +263,52 @@ def persistent_masked_m_silu_mul_quant( class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, + moe_config: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, max_num_tokens: int, num_dispatchers: int, - quant_config: FusedMoEQuantConfig, ): """ max_num_tokens: Maximum number of tokens from a DP Rank num_dispatchers: The number of DP dispatchers. 
quant_config: Quantization configuration """ - super().__init__(quant_config) + super().__init__( + moe_config=moe_config, + quant_config=quant_config, + max_num_tokens=max_num_tokens, + num_dispatchers=num_dispatchers, + ) assert self.block_shape == get_mk_alignment_for_contiguous_layout() assert self.quant_config.use_fp8_w8a8 - self.max_num_tokens = max_num_tokens - self.num_dispatchers = num_dispatchers - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.BatchedExperts + + @staticmethod + def _supports_current_device() -> bool: + return is_deep_gemm_supported() + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + SUPPORTED_W_A = [(kFp8Static128BlockSym, kFp8Dynamic128Sym)] + return (weight_key, activation_key) in SUPPORTED_W_A + + @staticmethod + def _supports_activation(activation: str) -> bool: + return activation in ["silu"] + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + return True def supports_chunking(self) -> bool: return False @@ -310,6 +343,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): # FIXME (varun): We should be able to dispatch only from the leader # DP ranks in the case of TP > 1. At the moment, all the Ranks # end up sending their tokens. This needs to be fixed. 
+ assert self.num_dispatchers is not None + assert self.max_num_tokens is not None num_dispatchers = self.num_dispatchers num_experts = local_num_experts max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index c8baefbd5..12e9918e0 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -862,6 +862,7 @@ class FusedMoEParallelConfig: use_ep: bool # whether to use EP or not all2all_backend: str # all2all backend for MoE communication + enable_eplb: bool # whether to enable expert load balancing @property def use_all2all_kernels(self): @@ -882,6 +883,16 @@ class FusedMoEParallelConfig: def use_deepep_ll_kernels(self): return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency" + @property + def use_batched_activation_format(self): + return self.use_deepep_ll_kernels or self.use_pplx_kernels + + @property + def use_naive_all2all_kernels(self): + return self.use_all2all_kernels and ( + self.all2all_backend in ["naive", "allgather_reducescatter"] + ) + @staticmethod def flatten_tp_across_dp_and_pcp( tp_size: int, dp_size: int, dp_rank: int, pcp_size: int, pcp_rank: int @@ -999,6 +1010,7 @@ class FusedMoEParallelConfig: ep_rank=0, use_ep=False, all2all_backend=vllm_parallel_config.all2all_backend, + enable_eplb=vllm_parallel_config.enable_eplb, ) # DP + EP / TP + EP / DP + TP + EP assert use_ep @@ -1017,6 +1029,24 @@ class FusedMoEParallelConfig: ep_rank=ep_rank, use_ep=True, all2all_backend=vllm_parallel_config.all2all_backend, + enable_eplb=vllm_parallel_config.enable_eplb, + ) + + @classmethod + def make_no_parallel(cls) -> "FusedMoEParallelConfig": + """For usage in CI/CD and testing.""" + return FusedMoEParallelConfig( + tp_size=1, + tp_rank=0, + pcp_size=1, + pcp_rank=0, + dp_size=1, + dp_rank=0, + ep_size=1, + ep_rank=0, + use_ep=False, + 
all2all_backend="naive", + enable_eplb=False, ) @@ -1026,8 +1056,11 @@ class FusedMoEConfig: num_experts: int experts_per_token: int hidden_dim: int - + intermediate_size_per_partition: int num_local_experts: int + activation: str + device: torch.device | str + routing_method: RoutingMethodType moe_parallel_config: FusedMoEParallelConfig # The activation type. diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index acc12d0da..0d8690638 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -7,7 +7,11 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( moe_permute, moe_unpermute, @@ -23,6 +27,19 @@ from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, apply_moe_activation, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8DynamicTensorSym, + kFp8DynamicTokenSym, + kFp8StaticChannelSym, + kFp8StaticTensorSym, + kNvfp4Dynamic, + kNvfp4Static, +) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + cutlass_group_gemm_supported, +) +from vllm.platforms import current_platform from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -238,29 +255,57 @@ def run_cutlass_moe_fp8( class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - e: int, - n: int, - k: int, - out_dtype: torch.dtype | None, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, - device: torch.dtype, + max_num_tokens: int | None = None, + num_dispatchers: 
int | None = None, ): + super().__init__( + moe_config=moe_config, + quant_config=quant_config, + max_num_tokens=max_num_tokens, + num_dispatchers=num_dispatchers, + ) assert quant_config.use_fp8_w8a8 - super().__init__(quant_config) - # E: num_experts - # N: intermediate size per partition - # K: hidden dim + e = moe_config.num_local_experts + n = moe_config.intermediate_size_per_partition + k = moe_config.hidden_dim + device = moe_config.device ab_strides1_c_strides2 = torch.full((e,), k, device=device, dtype=torch.int64) ab_strides2 = torch.full((e,), n, device=device, dtype=torch.int64) c_strides1 = torch.full((e,), 2 * n, device=device, dtype=torch.int64) - self.out_dtype = out_dtype + self.out_dtype = moe_config.in_dtype self.ab_strides1 = ab_strides1_c_strides2 self.ab_strides2 = ab_strides2 self.c_strides1 = c_strides1 self.c_strides2 = ab_strides1_c_strides2 + @staticmethod + def _supports_current_device() -> bool: + return cutlass_group_gemm_supported() + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + SUPPORTED_W_A = [ + (kFp8StaticChannelSym, kFp8DynamicTokenSym), + (kFp8StaticTensorSym, kFp8DynamicTensorSym), + (kFp8StaticTensorSym, kFp8StaticTensorSym), + ] + return (weight_key, activation_key) in SUPPORTED_W_A + + @staticmethod + def _supports_activation(activation: str) -> bool: + return activation in ["silu", "gelu", "swigluoai"] + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: # Let PrepareAndFinalize::finalize() decide the impl. 
return TopKWeightAndReduceDelegate() @@ -291,7 +336,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): expert_num_tokens = expert_tokens_meta.expert_num_tokens use_batched_format = ( - self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts + self.activation_format() == mk.FusedMoEActivationFormat.BatchedExperts ) in_dtype = hidden_states.dtype @@ -324,20 +369,23 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): class CutlassExpertsFp8(CutlassExpertsFp8Base): - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + # CutlassExpertsFp8 does not support expert map, which is + # needed for STANDARD activation format kernels in DP/EP mode. + # Note that the BATCHED activation format does not use + # the expert map for identifying experts. 
+ return not moe_parallel_config.use_all2all_kernels def supports_chunking(self) -> bool: return True def supports_expert_map(self) -> bool: - return True + return False def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: # topk weights and reduction are fused in moe_unpermute cuda kernel @@ -365,26 +413,16 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base): class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): - def __init__( - self, - max_experts_per_worker: int, - num_dispatchers: int, - *args, - **kwargs, - ): - super().__init__(*args, **kwargs) - assert max_experts_per_worker > 0 - self.max_experts_per_worker = max_experts_per_worker - self.num_dispatchers = num_dispatchers + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + # BATCHED activation format works with EP because + # expert_map is not used to identify experts (the + # info is encoded/managed by the P/F logic). + return True - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.BatchedExperts def supports_chunking(self) -> bool: return False @@ -408,14 +446,15 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: num_dp = self.num_dispatchers assert num_dp is not None + experts_per_worker = self.moe_config.num_local_experts activation_out_dim = self.adjust_N_for_activation(N, activation) - workspace1 = (self.max_experts_per_worker, M * num_dp, max(N, K)) + workspace1 = (experts_per_worker, M * num_dp, max(N, K)) workspace2 = ( - self.max_experts_per_worker, + experts_per_worker, M * num_dp, max(activation_out_dim, K), ) - output = (self.max_experts_per_worker, M, K) + output = 
(experts_per_worker, M, K) return (workspace1, workspace2, output) @@ -601,34 +640,41 @@ def run_cutlass_moe_fp4( return -# Split into batched and non-batched class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): - def __init__( - self, - max_experts_per_worker: int, - out_dtype: torch.dtype, - quant_config: FusedMoEQuantConfig, - use_batched_format: bool = False, - ): - super().__init__(quant_config) - self.max_experts_per_worker = max_experts_per_worker - self.out_dtype = out_dtype - self.use_batched_format = use_batched_format + @staticmethod + def expects_unquantized_inputs( + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig + ) -> bool: + return True - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - if self.use_batched_format: - return ( - mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts, - ) - else: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, - ) + @staticmethod + def _supports_current_device() -> bool: + return current_platform.has_device_capability((10, 0)) + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + return (weight_key, activation_key) == (kNvfp4Static, kNvfp4Dynamic) + + @staticmethod + def _supports_activation(activation: str) -> bool: + return activation in ["silu", "gelu", "swigluoai"] + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + # CutlassExpertsFp4 does not support expert map, which is + # needed for STANDARD activation format kernels in EP mode. 
+ return moe_parallel_config.ep_size == 1 + + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard def supports_expert_map(self) -> bool: return False @@ -640,7 +686,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): return TopKWeightAndReduceNoOP() def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype: - return self.out_dtype if self.out_dtype is not None else act_dtype + return act_dtype def workspace_shapes( self, @@ -653,18 +699,9 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): expert_tokens_meta: mk.ExpertTokensMetadata | None, activation: str, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: - activation_out_dim = self.adjust_N_for_activation(N, activation) - workspace1: tuple[int, ...] = () - workspace2: tuple[int, ...] = () - output: tuple[int, ...] = () - if self.use_batched_format: - workspace1 = (self.max_experts_per_worker, M, max(N, K)) - workspace2 = (self.max_experts_per_worker, M, activation_out_dim) - output = (self.max_experts_per_worker, M, K) - else: - workspace1 = (M * topk, max(2 * N, K)) - workspace2 = (M * topk, N) - output = (M, K) + workspace1 = (M * topk, max(2 * N, K)) + workspace2 = (M * topk, N) + output = (M, K) return (workspace1, workspace2, output) def apply( @@ -869,10 +906,11 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute): c_strides2: torch.Tensor, s_strides1: torch.Tensor, s_strides2: torch.Tensor, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, group_size: int, ): - super().__init__(quant_config) + super().__init__(moe_config=moe_config, quant_config=quant_config) self.out_dtype = out_dtype self.a_strides1 = a_strides1 self.a_strides2 = a_strides2 @@ -884,13 +922,46 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute): self.s_strides2 = s_strides2 self.group_size = group_size - @property - def activation_formats( - self, - ) -> 
tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + @staticmethod + def _supports_current_device() -> bool: + raise NotImplementedError( + "CutlassExpertsW4A8Fp8 is not yet used by an Oracle. " + "This method should not be called." + ) + + @staticmethod + def _supports_no_act_and_mul() -> bool: + raise NotImplementedError( + "CutlassExpertsW4A8Fp8 is not yet used by an Oracle. " + "This method should not be called." + ) + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + raise NotImplementedError( + "CutlassExpertsW4A8Fp8 is not yet used by an Oracle. " + "This method should not be called." + ) + + @staticmethod + def _supports_activation(activation: str) -> bool: + raise NotImplementedError( + "CutlassExpertsW4A8Fp8 is not yet used by an Oracle. " + "This method should not be called." + ) + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + raise NotImplementedError( + "CutlassExpertsW4A8Fp8 is not yet used by an Oracle. " + "This method should not be called." 
) def supports_chunking(self) -> bool: @@ -947,7 +1018,7 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute): expert_num_tokens = None use_batched_format = ( - self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts + self.activation_format() == mk.FusedMoEActivationFormat.BatchedExperts ) assert not use_batched_format, "batched format not supported" @@ -1003,6 +1074,7 @@ def cutlass_moe_w4a8_fp8( s_strides1: torch.Tensor, s_strides2: torch.Tensor, quant_config: FusedMoEQuantConfig, + moe_config: FusedMoEConfig, activation: str = "silu", expert_map: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, @@ -1076,6 +1148,7 @@ def cutlass_moe_w4a8_fp8( c_strides2=c_strides2, s_strides1=s_strides1, s_strides2=s_strides2, + moe_config=moe_config, quant_config=quant_config, group_size=group_size, ), diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index a2e5a07fb..1d5d039b6 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -6,17 +6,15 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, FusedMoEQuantConfig, - fp8_w8a8_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( compute_aligned_M, deepgemm_moe_permute, deepgemm_unpermute_and_reduce, ) -from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, -) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) @@ -26,9 +24,15 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8_packed_for_deepgemm, silu_mul_per_token_group_quant_fp8_colmajor, ) +from 
vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8Dynamic128Sym, + kFp8Static128BlockSym, +) from vllm.utils.deep_gemm import ( DeepGemmQuantScaleFMT, get_mk_alignment_for_contiguous_layout, + is_deep_gemm_supported, m_grouped_fp8_gemm_nt_contiguous, ) from vllm.utils.import_utils import has_deep_gemm @@ -109,21 +113,42 @@ def _valid_deep_gemm( class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__(self, quant_config: FusedMoEQuantConfig): - super().__init__(quant_config) + def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig): + super().__init__(moe_config=moe_config, quant_config=quant_config) assert quant_config.block_shape == get_mk_alignment_for_contiguous_layout() assert quant_config.quant_dtype == torch.float8_e4m3fn assert not quant_config.per_act_token_quant assert not quant_config.per_out_ch_quant - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + @staticmethod + def _supports_current_device() -> bool: + return is_deep_gemm_supported() + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + SUPPORTED_W_A = [ + (kFp8Static128BlockSym, kFp8Dynamic128Sym), + ] + return (weight_key, activation_key) in SUPPORTED_W_A + + @staticmethod + def _supports_activation(activation: str) -> bool: + return activation in ["silu"] + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + return True def supports_chunking(self) -> bool: return True @@ -283,82 +308,3 @@ class 
DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): expert_map=expert_map, output=output, ) - - -def deep_gemm_moe_fp8( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool = False, - activation: str = "silu", - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - a1_scale: torch.Tensor | None = None, - a2_scale: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, -) -> torch.Tensor: - """ - This function computes a a8w8-quantized Mixture of Experts (MoE) layer - using two sets of quantized weights, w1_q and w2_q, and top-k gating - mechanism. The matrix multiplications are implemented with DeepGemm - grouped gemm. - - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - Shape: [M, K] - - w1 (torch.Tensor): The first set of fp8 quantized expert weights. - Shape: [num_experts, K, 2N] (the weights are passed transposed) - - w2 (torch.Tensor): The second set of fp8 quantized expert weights. - Shape: [num_experts, N, K] (the weights are passed transposed) - - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q. - Shape: [num_experts] or [num_experts, 2N] - - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. - Shape: [num_experts] or [num_experts, K] - - topk_weights (torch.Tensor): The weights of each token->expert mapping. - - topk_ids (torch.Tensor): The token->expert mapping for topk_weights. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - - activation (str): The activation function to apply after the first - MoE layer. - - global_num_experts (int): The total number of experts in the global - expert space. - - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert - parallel shard. 
- - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. - Shape: scalar or [M] - - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to - quantize the intermediate result between the gemms. - Shape: scalar or [M] - - Returns: - - torch.Tensor: The bfloat16 output tensor after applying the MoE layer. - """ - quant_config = fp8_w8a8_moe_quant_config( - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - block_shape=get_mk_alignment_for_contiguous_layout(), - ) - - fn = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), - DeepGemmExperts(quant_config), - ) - return fn( - hidden_states, - w1, - w2, - topk_weights, - topk_ids, - inplace=inplace, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) diff --git a/vllm/model_executor/layers/fused_moe/fallback.py b/vllm/model_executor/layers/fused_moe/fallback.py index 455639214..07e5b8005 100644 --- a/vllm/model_executor/layers/fused_moe/fallback.py +++ b/vllm/model_executor/layers/fused_moe/fallback.py @@ -6,6 +6,8 @@ from abc import ABC, abstractmethod import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig +from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC): @@ -16,18 +18,78 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC): experts: mk.FusedMoEPermuteExpertsUnpermute, fallback_experts: mk.FusedMoEPermuteExpertsUnpermute, ): - super().__init__(experts.quant_config) + super().__init__( + moe_config=experts.moe_config, quant_config=experts.quant_config + ) self.fallback_experts = fallback_experts self.experts = experts - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - assert ( - 
self.fallback_experts.activation_formats == self.experts.activation_formats + @staticmethod + def get_clses() -> tuple[ + type[mk.FusedMoEPermuteExpertsUnpermute], + type[mk.FusedMoEPermuteExpertsUnpermute], + ]: + """ + Get the cls for the experts and fallback experts. + + Subclasses should implement this method, so that + we have a consistent way to call the _supports_* + class methods below. + """ + raise NotImplementedError( + "Subclasses must return the cls for the experts and fallback experts." ) - return self.fallback_experts.activation_formats + + @classmethod + def activation_format( + cls: type["FallbackExperts"], + ) -> mk.FusedMoEActivationFormat: + experts_cls, fallback_cls = cls.get_clses() + assert experts_cls.activation_format() == fallback_cls.activation_format() + return experts_cls.activation_format() + + @classmethod + def _supports_current_device(cls) -> bool: + experts_cls, fallback_cls = cls.get_clses() + return ( + experts_cls._supports_current_device() + and fallback_cls._supports_current_device() + ) + + @classmethod + def _supports_no_act_and_mul(cls) -> bool: + experts_cls, fallback_cls = cls.get_clses() + return ( + experts_cls._supports_no_act_and_mul() + and fallback_cls._supports_no_act_and_mul() + ) + + @classmethod + def _supports_quant_scheme( + cls, + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + experts_cls, fallback_cls = cls.get_clses() + return experts_cls._supports_quant_scheme( + weight_key, activation_key + ) and fallback_cls._supports_quant_scheme(weight_key, activation_key) + + @classmethod + def _supports_activation(cls, activation: str) -> bool: + experts_cls, fallback_cls = cls.get_clses() + return experts_cls._supports_activation( + activation + ) and fallback_cls._supports_activation(activation) + + @classmethod + def _supports_parallel_config( + cls, moe_parallel_config: FusedMoEParallelConfig + ) -> bool: + experts_cls, fallback_cls = cls.get_clses() + return 
experts_cls._supports_parallel_config( + moe_parallel_config + ) and fallback_cls._supports_parallel_config(moe_parallel_config) def supports_chunking(self) -> bool: assert ( diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py index 1651f3530..036ee2a2e 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py @@ -6,13 +6,22 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import envs from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kNvfp4Dynamic, + kNvfp4Static, +) +from vllm.platforms import current_platform from vllm.utils.flashinfer import ( flashinfer_cutedsl_grouped_gemm_nt_masked, - has_flashinfer_cutedsl_grouped_gemm_nt_masked, scaled_fp4_grouped_quantize, silu_and_mul_scaled_nvfp4_experts_quantize, ) @@ -20,54 +29,54 @@ from vllm.utils.flashinfer import ( logger = init_logger(__name__) -def is_valid_flashinfer_cutedsl_fused_moe( - hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor -) -> bool: - """ - Check if the given problem size is supported by the FlashInfer CuteDSL MoE - kernel. - """ - if not has_flashinfer_cutedsl_grouped_gemm_nt_masked(): - logger.debug_once( - "FlashInferCuteDSLExperts disabled: " - "flashinfer_cutedsl_fused_moe not available." 
- ) - return False - # Data type checks - if ( - w1.dtype != torch.uint8 - or w2.dtype != torch.uint8 - or hidden_states.dtype not in [torch.float32, torch.float16, torch.bfloat16] - ): - logger.debug_once( - "FlashInferCuteDSLExperts disabled: w1/w2 must be torch.uint8 " - f"(got w1={w1.dtype}, w2={w2.dtype}), hidden_states must be " - f"float32, float16, or bfloat16 (got {hidden_states.dtype})." - ) - return False - return True - - class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - out_dtype: torch.dtype, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, + max_num_tokens: int, + num_dispatchers: int, ): - super().__init__(quant_config) + super().__init__( + moe_config=moe_config, + quant_config=quant_config, + max_num_tokens=max_num_tokens, + num_dispatchers=num_dispatchers, + ) assert quant_config.quant_dtype == "nvfp4", ( "Only nvfp4 quantization are currently supported." ) - self.out_dtype = out_dtype + self.out_dtype = moe_config.in_dtype - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.BatchedExperts + + @staticmethod + def _supports_current_device() -> bool: + return current_platform.is_device_capability_family(100) + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + SUPPORTED_W_A = [ + (kNvfp4Static, kNvfp4Dynamic), + ] + return (weight_key, activation_key) in SUPPORTED_W_A + + @staticmethod + def _supports_activation(activation: str) -> bool: + return activation in ["silu"] + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> 
bool: + return True def supports_expert_map(self) -> bool: return False diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index c08894b81..8d5985875 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -5,13 +5,22 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 - create_flashinfer_prepare_finalize, +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEParallelConfig, + FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8Dynamic128Sym, + kFp8Static128BlockSym, + kFp8StaticTensorSym, + kNvfp4Dynamic, + kNvfp4Static, +) +from vllm.platforms import current_platform from vllm.utils.flashinfer import ( flashinfer_cutlass_fused_moe, has_flashinfer_cutlass_fused_moe, @@ -50,40 +59,100 @@ def is_valid_flashinfer_cutlass_fused_moe( class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - out_dtype: torch.dtype, + moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig, - ep_rank: int = 0, - ep_size: int = 1, - tp_rank: int = 0, - tp_size: int = 1, - use_dp: bool = False, - use_deepseek_fp8_block_scale: bool = False, ): - super().__init__(quant_config) + super().__init__(moe_config, quant_config) assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), ( "Only nvfp4, fp8, bfloat16 and" " float16 quantization are currently supported." 
) - self.ep_rank = ep_rank - self.ep_size = ep_size - self.tp_rank = tp_rank - self.tp_size = tp_size - self.out_dtype = out_dtype - self.use_dp = use_dp + self.ep_rank = moe_config.moe_parallel_config.ep_rank + self.ep_size = moe_config.moe_parallel_config.ep_size + self.tp_rank = moe_config.moe_parallel_config.tp_rank + self.tp_size = moe_config.moe_parallel_config.tp_size + self.out_dtype = moe_config.in_dtype + self.use_dp = moe_config.moe_parallel_config.dp_size > 1 # Enables DeepSeek-style FP8 block-scale path: # - pass per-block weight scales to the kernel # - skip input activation quantization (kernel applies scaling) - self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale + self.use_deepseek_fp8_block_scale = quant_config.is_block_quantized - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + @staticmethod + def expects_unquantized_inputs( + moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig + ) -> bool: + # NVFP4 TP kernels and FP8 block-quantized kernels apply + # input quantization inside FusedMoEPermuteExpertsUnpermute. 
return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, + quant_config.use_nvfp4_w4a4 + and not moe_config.moe_parallel_config.use_all2all_kernels + ) or (quant_config.use_fp8_w8a8 and quant_config.is_block_quantized) + + @staticmethod + def _supports_current_device() -> bool: + return ( + current_platform.is_cuda() + and ( + current_platform.is_device_capability((9, 0)) + or current_platform.is_device_capability_family(100) + ) + and has_flashinfer_cutlass_fused_moe() ) + @staticmethod + def _supports_no_act_and_mul() -> bool: + return True + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + # The following are supported by FlashInferExperts: + # * unquantized + # * fp8 static per-tensor on 9.0+ + # * fp8 block on 9.0 + # * nvfp4 on 10.0+ + + p = current_platform + scheme = (weight_key, activation_key) + return ( + ( + scheme + in [ + (None, None), + (kFp8StaticTensorSym, kFp8StaticTensorSym), + ] + ) + or ( + (scheme == (kFp8Static128BlockSym, kFp8Dynamic128Sym)) + and (p.is_device_capability((9, 0))) + ) + or ( + (scheme == (kNvfp4Static, kNvfp4Dynamic)) + and (p.is_device_capability_family(100)) + ) + ) + + @staticmethod + def _supports_activation(activation: str) -> bool: + return activation in ["silu", "relu2_no_mul"] + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + # FLASHINFER_CUTLASS currently uses its own P/F, which does not + # work with SP. This will be removed in a follow-up after we get + # rid of the FlashInfer specific P/F function. 
+ return ( + moe_parallel_config.dp_size == 1 + or moe_parallel_config.dp_size == moe_parallel_config.ep_size + ) + + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + def supports_expert_map(self) -> bool: return False @@ -231,85 +300,3 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): # No support for LoRA in flashinfer_cutlass_fused_moe. # See TODOs in flashinfer functions runMoe and runMoeMinLantency. raise NotImplementedError("LoRA is not supported for flashinfer_cutlass_moe") - - -def flashinfer_cutlass_moe_fp4( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - quant_config: FusedMoEQuantConfig, - inplace: bool = False, - activation: str = "silu", - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, -) -> torch.Tensor: - fused_experts = mk.FusedMoEModularKernel( - create_flashinfer_prepare_finalize( - use_dp=False, use_nvfp4=True, enable_alltoallv=False - ), - FlashInferExperts( - out_dtype=hidden_states.dtype, - quant_config=quant_config, - use_dp=False, - ), - ) - - return fused_experts( - hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=inplace, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - - -def flashinfer_cutlass_moe( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - quant_config: FusedMoEQuantConfig, - inplace: bool = False, - activation: str = "silu", - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - tp_rank: int = 0, - tp_size: int = 1, - ep_rank: int = 0, - ep_size: int = 1, - use_dp: bool = False, -) -> 
torch.Tensor: - fused_experts = mk.FusedMoEModularKernel( - create_flashinfer_prepare_finalize(use_dp=use_dp), - FlashInferExperts( - out_dtype=hidden_states.dtype, - quant_config=quant_config, - tp_rank=tp_rank, - tp_size=tp_size, - ep_rank=ep_rank, - ep_size=ep_size, - ), - ) - - return fused_experts( - hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=inplace, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index 3bb5a23ab..e5d8a7ace 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -3,7 +3,12 @@ import torch -from vllm.model_executor.layers.fused_moe.config import RoutingMethodType +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + RoutingMethodType, +) from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( calculate_tile_tokens_dim, @@ -11,8 +16,107 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8Dynamic128Sym, + kFp8Static128BlockSym, + kFp8StaticTensorSym, +) +from vllm.platforms import current_platform from vllm.utils.torch_utils import direct_register_custom_op +# +# Methods used by the oracle for kernel selection. 
+# + + +def _supports_current_device() -> bool: + """Supports only Blackwell-family GPUs.""" + p = current_platform + # TODO: also check that flashinfer trtllm is available + return p.is_cuda() and p.is_device_capability_family(100) + + +def _supports_no_act_and_mul() -> bool: + """Does not support non-gated MoE (e.g. Nemotron-Nano).""" + return False + + +def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, +) -> bool: + """Supports Fp8 per-tensor and Fp8 block.""" + SUPPORTED_W_A = [ + (kFp8Static128BlockSym, kFp8Dynamic128Sym), + (kFp8StaticTensorSym, kFp8StaticTensorSym), + ] + return (weight_key, activation_key) in SUPPORTED_W_A + + +def _supports_activation(activation: str) -> bool: + """Supports silu activation only.""" + return activation in ["silu"] + + +def _supports_routing_method( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + routing_method: RoutingMethodType, +) -> bool: + """Monolithic kernels need to express router support.""" + if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym): + # NOTE(rob): potentially allow others here. This is a conservative list. + return routing_method in [ + RoutingMethodType.DeepSeekV3, + RoutingMethodType.Renormalize, + RoutingMethodType.RenormalizeNaive, + ] + elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym): + # NOTE(rob): kernel requires Llama4. 
+ return routing_method == RoutingMethodType.Llama4 + + else: + raise ValueError("Unsupported quantization scheme.") + + +def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + """The TRTLLM kernel does not support EPLB.""" + return not moe_parallel_config.enable_eplb + + +def is_supported_config_trtllm( + moe_config: FusedMoEConfig, + weight_key: QuantKey | None, + activation_key: QuantKey | None, + activation_format: mk.FusedMoEActivationFormat, +) -> tuple[bool, str | None]: + """ + This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config + """ + + def _make_reason(reason: str) -> str: + return f"kernel does not support {reason}" + + if not _supports_current_device(): + return False, _make_reason("current device") + elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()): + return False, _make_reason("no act_and_mul MLP layer") + elif not _supports_activation(moe_config.activation): + return False, _make_reason(f"{moe_config.activation} activation") + elif not _supports_quant_scheme(weight_key, activation_key): + return False, _make_reason("quantization scheme") + elif not _supports_parallel_config(moe_config.moe_parallel_config): + return False, _make_reason("parallel config") + elif not _supports_routing_method( + weight_key, activation_key, moe_config.routing_method + ): + return False, _make_reason("routing method") + elif activation_format != mk.FusedMoEActivationFormat.Standard: + return False, _make_reason("activation format") + + return True, None + def flashinfer_fused_moe_blockscale_fp8( routing_logits: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index fb9346439..8e45c0e41 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -5,7 +5,11 @@ import torch import 
vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.fused_moe import try_get_optimal_moe_config from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate, @@ -17,7 +21,17 @@ from vllm.model_executor.layers.fused_moe.utils import ( normalize_batched_scales_shape, normalize_scales_shape, ) -from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + group_broadcast, + kFp8Dynamic128Sym, + kFp8DynamicTensorSym, + kFp8DynamicTokenSym, + kFp8Static128BlockSym, + kFp8StaticChannelSym, + kFp8StaticTensorSym, +) +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton @@ -633,25 +647,62 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, + moe_config: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, max_num_tokens: int, num_dispatchers: int, - quant_config: FusedMoEQuantConfig, ): - super().__init__(quant_config) + super().__init__( + moe_config=moe_config, + quant_config=quant_config, + max_num_tokens=max_num_tokens, + num_dispatchers=num_dispatchers, + ) assert not self.quant_config.use_int8_w8a8, "NYI" assert not self.quant_config.use_int8_w8a16, "NYI" assert not self.quant_config.use_int4_w4a16, "NYI" assert self.quant_config.ocp_mx_scheme is None, "NYI" - self.max_num_tokens = max_num_tokens - self.num_dispatchers = num_dispatchers - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts, + @staticmethod + def activation_format() -> 
mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.BatchedExperts + + @staticmethod + def _supports_current_device() -> bool: + raise NotImplementedError( + "NaiveBatchedExperts is not yet used by an Oracle. " + "This method should not be called." + ) + + @staticmethod + def _supports_no_act_and_mul() -> bool: + raise NotImplementedError( + "NaiveBatchedExperts is not yet used by an Oracle. " + "This method should not be called." + ) + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + raise NotImplementedError( + "NaiveBatchedExperts is not yet used by an Oracle. " + "This method should not be called." + ) + + @staticmethod + def _supports_activation(activation: str) -> bool: + raise NotImplementedError( + "NaiveBatchedExperts is not yet used by an Oracle. " + "This method should not be called." + ) + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + raise NotImplementedError( + "NaiveBatchedExperts is not yet used by an Oracle. " + "This method should not be called." 
) def supports_chunking(self) -> bool: @@ -675,6 +726,8 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): expert_tokens_meta: mk.ExpertTokensMetadata | None, activation: str, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + assert self.num_dispatchers is not None + assert self.max_num_tokens is not None num_dp = self.num_dispatchers num_experts = local_num_experts workspace13 = (num_experts, self.max_num_tokens * num_dp, K) @@ -826,29 +879,69 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, + moe_config: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, max_num_tokens: int, num_dispatchers: int, - quant_config: FusedMoEQuantConfig, ): - super().__init__(quant_config) + super().__init__( + moe_config=moe_config, + quant_config=quant_config, + max_num_tokens=max_num_tokens, + num_dispatchers=num_dispatchers, + ) assert not self.quant_config.use_int8_w8a8, "NYI" assert not self.quant_config.use_int8_w8a16, "NYI" assert not self.quant_config.use_int4_w4a16, "NYI" assert self.quant_config.ocp_mx_scheme is None, "NYI" - assert max_num_tokens > 0 - assert num_dispatchers > 0 - self.max_num_tokens = max_num_tokens - self.num_dispatchers = num_dispatchers - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts, + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.BatchedExperts + + @staticmethod + def _supports_current_device() -> bool: + return current_platform.is_cuda_alike() + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + p = current_platform + device_supports_fp8 = (p.is_rocm() and p.rocm.on_gfx9()) or ( + p.is_cuda() and 
p.has_device_capability((8, 9)) ) + SUPPORTED_W_A_FP8 = [ + (kFp8Static128BlockSym, kFp8Dynamic128Sym), + (kFp8StaticChannelSym, kFp8DynamicTokenSym), + (kFp8StaticTensorSym, kFp8StaticTensorSym), + (kFp8StaticTensorSym, kFp8DynamicTensorSym), + ] + return (weight_key, activation_key) == (None, None) or ( + device_supports_fp8 and (weight_key, activation_key) in SUPPORTED_W_A_FP8 + ) + + @staticmethod + def _supports_activation(activation: str) -> bool: + return activation in [ + "silu", + "gelu", + "swigluoai", + "silu_no_mul", + "gelu_no_mul", + "relu2_no_mul", + ] + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + return True + def supports_chunking(self) -> bool: return False @@ -870,6 +963,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): expert_tokens_meta: mk.ExpertTokensMetadata | None, activation: str, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + assert self.num_dispatchers is not None + assert self.max_num_tokens is not None num_dp = self.num_dispatchers num_experts = local_num_experts max_num_tokens = self.max_num_tokens diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index be9dddb87..603c8fc96 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -8,7 +8,11 @@ import torch import vllm._custom_ops as ops import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( batched_moe_align_block_size, moe_align_block_size, @@ -27,6 +31,13 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( 
marlin_moe_intermediate_size, marlin_quant_input, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8Static128BlockSym, + kFp8StaticChannelSym, + kFp8StaticTensorSym, + kNvfp4Static, +) from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types @@ -522,7 +533,10 @@ def batched_fused_marlin_moe( class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, + max_num_tokens: int | None = None, + num_dispatchers: int | None = None, w13_g_idx: torch.Tensor | None = None, w2_g_idx: torch.Tensor | None = None, w13_g_idx_sort_indices: torch.Tensor | None = None, @@ -541,7 +555,51 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute): self.w13_g_idx_sort_indices = w13_g_idx_sort_indices self.w2_g_idx_sort_indices = w2_g_idx_sort_indices self.is_k_full = is_k_full - super().__init__(quant_config) + super().__init__( + moe_config=moe_config, + quant_config=quant_config, + max_num_tokens=max_num_tokens, + num_dispatchers=num_dispatchers, + ) + + @staticmethod + def _supports_current_device() -> bool: + p = current_platform + return p.is_cuda() and p.has_device_capability((7, 5)) + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + # TODO(rob): add int4, mxfp4, int8 as integrations + # are migrated to use the oracle one-by-one. 
+ SUPPORTED_W = [ + kFp8Static128BlockSym, + kFp8StaticChannelSym, + kFp8StaticTensorSym, + kNvfp4Static, + ] + return weight_key in SUPPORTED_W + + @staticmethod + def _supports_activation(activation: str) -> bool: + return activation in [ + "silu", + "gelu", + "swigluoai", + "silu_no_mul", + "gelu_no_mul", + "relu2_no_mul", + ] + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + return True @property def quant_type_id(self) -> int: @@ -587,38 +645,15 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute): class MarlinExperts(MarlinExpertsBase): - def __init__( - self, - quant_config: FusedMoEQuantConfig, - w13_g_idx: torch.Tensor | None = None, - w2_g_idx: torch.Tensor | None = None, - w13_g_idx_sort_indices: torch.Tensor | None = None, - w2_g_idx_sort_indices: torch.Tensor | None = None, - is_k_full: bool = True, - ): - super().__init__( - quant_config, - w13_g_idx, - w2_g_idx, - w13_g_idx_sort_indices, - w2_g_idx_sort_indices, - is_k_full, - ) - def supports_expert_map(self) -> bool: return True def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: return TopKWeightAndReduceNoOP() - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard def supports_chunking(self) -> bool: return True @@ -714,9 +749,10 @@ class MarlinExperts(MarlinExpertsBase): class BatchedMarlinExperts(MarlinExpertsBase): def __init__( self, + moe_config: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, max_num_tokens: int, num_dispatchers: int, - quant_config: FusedMoEQuantConfig, w13_g_idx: torch.Tensor | None = None, w2_g_idx: torch.Tensor | None = None, w13_g_idx_sort_indices: torch.Tensor | None = None, @@ -724,15 +760,16 @@ class 
BatchedMarlinExperts(MarlinExpertsBase): is_k_full: bool = True, ): super().__init__( - quant_config, - w13_g_idx, - w2_g_idx, - w13_g_idx_sort_indices, - w2_g_idx_sort_indices, - is_k_full, + moe_config=moe_config, + quant_config=quant_config, + max_num_tokens=max_num_tokens, + num_dispatchers=num_dispatchers, + w13_g_idx=w13_g_idx, + w2_g_idx=w2_g_idx, + w13_g_idx_sort_indices=w13_g_idx_sort_indices, + w2_g_idx_sort_indices=w2_g_idx_sort_indices, + is_k_full=is_k_full, ) - self.max_num_tokens = max_num_tokens - self.num_dispatchers = num_dispatchers def supports_expert_map(self) -> bool: return True @@ -740,14 +777,9 @@ class BatchedMarlinExperts(MarlinExpertsBase): def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: return TopKWeightAndReduceDelegate() - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.BatchedExperts def supports_chunking(self) -> bool: return False @@ -763,9 +795,11 @@ class BatchedMarlinExperts(MarlinExpertsBase): expert_tokens_meta: mk.ExpertTokensMetadata | None, activation: str, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + assert self.num_dispatchers is not None + assert self.max_num_tokens is not None num_dispatchers = self.num_dispatchers num_experts = local_num_experts - max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens + max_num_tokens = self.max_num_tokens workspace13 = (num_experts * max_num_tokens * num_dispatchers, max(K, N * 2)) workspace2 = (num_experts * max_num_tokens * num_dispatchers, N) output = (num_experts, max_num_tokens * num_dispatchers, K) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index d069d81f5..7e7d59fb9 
100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -19,13 +19,11 @@ from vllm.model_executor.layers.batch_invariant import ( ) from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, + FusedMoEConfig, + FusedMoEParallelConfig, FusedMoEQuantConfig, _get_config_dtype_str, ) -from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - _valid_deep_gemm, - deep_gemm_moe_fp8, -) from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size, ) @@ -44,9 +42,16 @@ from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4 from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6 from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8Dynamic128Sym, + kFp8DynamicTokenSym, + kFp8Static128BlockSym, + kFp8StaticChannelSym, + kFp8StaticTensorSym, +) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer logger = init_logger(__name__) @@ -1534,66 +1539,36 @@ def fused_experts( global_num_experts: int = -1, expert_map: torch.Tensor | None = None, quant_config: FusedMoEQuantConfig | None = None, - allow_deep_gemm: bool = False, ) -> torch.Tensor: if quant_config is None: quant_config = FUSED_MOE_UNQUANTIZED_CONFIG - # For now, disable DeepGemm for small N (<= 512) until better - # permute/unpermute ops are available. - # However, on B200, we use DeepGemm for all cases because they only support - # E8M0 scale, which means we requantize the weight and input to the specific - # scale. 
Fallen back to cutlass or triton for some cases would cause - # accuracy issue. - if ( - allow_deep_gemm - and quant_config.use_fp8_w8a8 - and (is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2)) - ): - assert quant_config is not None - return deep_gemm_moe_fp8( - hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=inplace, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=quant_config.w1_scale, - w2_scale=quant_config.w2_scale, - a1_scale=quant_config.a1_scale, - a2_scale=quant_config.a2_scale, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - else: - return dispatch_fused_experts_func(inplace)( - hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - use_fp8_w8a8=quant_config.use_fp8_w8a8, - use_int8_w8a8=quant_config.use_int8_w8a8, - use_int8_w8a16=quant_config.use_int8_w8a16, - use_int4_w4a16=quant_config.use_int4_w4a16, - ocp_mx_scheme=quant_config.ocp_mx_scheme, - per_channel_quant=quant_config.per_act_token_quant, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=quant_config.w1_scale, - w2_scale=quant_config.w2_scale, - w1_zp=quant_config.w1_zp, - w2_zp=quant_config.w2_zp, - a1_scale=quant_config.a1_scale, - a2_scale=quant_config.a2_scale, - block_shape=quant_config.block_shape, - w1_bias=quant_config.w1_bias, - w2_bias=quant_config.w2_bias, - ) + return dispatch_fused_experts_func(inplace)( + hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + use_fp8_w8a8=quant_config.use_fp8_w8a8, + use_int8_w8a8=quant_config.use_int8_w8a8, + use_int8_w8a16=quant_config.use_int8_w8a16, + use_int4_w4a16=quant_config.use_int4_w4a16, + 
ocp_mx_scheme=quant_config.ocp_mx_scheme, + per_channel_quant=quant_config.per_act_token_quant, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=quant_config.w1_scale, + w2_scale=quant_config.w2_scale, + w1_zp=quant_config.w1_zp, + w2_zp=quant_config.w2_zp, + a1_scale=quant_config.a1_scale, + a2_scale=quant_config.a2_scale, + block_shape=quant_config.block_shape, + w1_bias=quant_config.w1_bias, + w2_bias=quant_config.w2_bias, + ) def _get_config_quant_dtype( @@ -1924,19 +1899,53 @@ def fused_experts_impl( class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, ): - super().__init__(quant_config) + super().__init__(moe_config, quant_config) - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + @staticmethod + def _supports_current_device() -> bool: + return current_platform.is_cuda_alike() + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + p = current_platform + device_supports_fp8 = (p.is_rocm() and p.rocm.on_gfx9()) or ( + p.is_cuda() and p.has_device_capability((8, 9)) ) + if not device_supports_fp8: + return (weight_key, activation_key) == (None, None) + + SUPPORTED_W_A = [ + (None, None), + (kFp8Static128BlockSym, kFp8Dynamic128Sym), + (kFp8StaticChannelSym, kFp8DynamicTokenSym), + (kFp8StaticTensorSym, kFp8DynamicTokenSym), + (kFp8StaticTensorSym, kFp8StaticTensorSym), + ] + return (weight_key, activation_key) in SUPPORTED_W_A + + @staticmethod + def _supports_activation(activation: str) -> bool: + return activation in 
["silu", "gelu", "swigluoai"] + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + return True + def supports_chunking(self) -> bool: return True @@ -2111,11 +2120,43 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): class TritonWNA16Experts(TritonExperts): - def __init__( - self, - quant_config: FusedMoEQuantConfig, - ): - super().__init__(quant_config) + @staticmethod + def _supports_current_device() -> bool: + raise NotImplementedError( + "TritonWNA16Experts is not yet used by an Oracle. " + "This method should not be called." + ) + + @staticmethod + def _supports_no_act_and_mul() -> bool: + raise NotImplementedError( + "TritonWNA16Experts is not yet used by an Oracle. " + "This method should not be called." + ) + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + raise NotImplementedError( + "TritonWNA16Experts is not yet used by an Oracle. " + "This method should not be called." + ) + + @staticmethod + def _supports_activation(activation: str) -> bool: + raise NotImplementedError( + "TritonWNA16Experts is not yet used by an Oracle. " + "This method should not be called." + ) + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + raise NotImplementedError( + "TritonWNA16Experts is not yet used by an Oracle. " + "This method should not be called." 
+ ) def apply( self, @@ -2254,10 +2295,12 @@ class TritonWNA16Experts(TritonExperts): def modular_triton_fused_moe( - quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None + moe_config: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, + shared_experts: torch.nn.Module | None = None, ) -> mk.FusedMoEModularKernel: return mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), - TritonExperts(quant_config), + TritonExperts(moe_config, quant_config), shared_experts, ) diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index c4bc1824a..b209820cd 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -9,12 +9,16 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, + FusedMoEParallelConfig, FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, +) from vllm.triton_utils import tl, triton from vllm.utils.import_utils import has_triton_kernels @@ -241,8 +245,43 @@ def make_routing_data( class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__(self, quant_config: FusedMoEQuantConfig): - super().__init__(quant_config) + @staticmethod + def _supports_current_device() -> bool: + raise NotImplementedError( + "OAITritonExperts is not yet used by an Oracle. " + "This method should not be called." + ) + + @staticmethod + def _supports_no_act_and_mul() -> bool: + raise NotImplementedError( + "OAITritonExperts is not yet used by an Oracle. " + "This method should not be called." 
+ ) + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + raise NotImplementedError( + "OAITritonExperts is not yet used by an Oracle. " + "This method should not be called." + ) + + @staticmethod + def _supports_activation(activation: str) -> bool: + raise NotImplementedError( + "OAITritonExperts is not yet used by an Oracle. " + "This method should not be called." + ) + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + raise NotImplementedError( + "OAITritonExperts is not yet used by an Oracle. " + "This method should not be called." + ) def supports_expert_map(self) -> bool: return True @@ -297,19 +336,9 @@ class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): class OAITritonExperts(BaseOAITritonExperts): - def __init__(self, quant_config: FusedMoEQuantConfig): - # TODO (varun) : Enable activation quantization - assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16" - super().__init__(quant_config) - - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard def supports_chunking(self) -> bool: return True @@ -391,19 +420,9 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts): One use case for it is to inject LoRA modules on the activation and moe_sum. 
""" - def __init__(self, quant_config: FusedMoEQuantConfig): - # TODO (varun) : Enable activation quantization - assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16" - super().__init__(quant_config) - - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, - ) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard def supports_chunking(self) -> bool: return True diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 4f6604530..2583a3a11 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -330,7 +330,6 @@ class FusedMoE(CustomOp): is_sequence_parallel=False, expert_mapping: list[tuple[str, str, int, str]] | None = None, n_shared_experts: int | None = None, - routing_method_type: RoutingMethodType | None = None, router_logits_dtype: torch.dtype | None = None, ): super().__init__() @@ -519,10 +518,43 @@ class FusedMoE(CustomOp): self.apply_router_weight_on_input = apply_router_weight_on_input self.activation = activation + # TODO(bnell): in next PR move capture back to layer + capture: Callable[[torch.Tensor], None] | None = None + if ( + self.vllm_config.model_config is not None + and self.vllm_config.model_config.enable_return_routed_experts + ): + # In dummy runs, the capturer is not initialized. 
+ capturer = RoutedExpertsCapturer.get_instance() + if capturer is not None: + capture = lambda topk_ids: capturer.capture(self.layer_id, topk_ids) + + self.router = create_fused_moe_router( + top_k=top_k, + global_num_experts=self.global_num_experts, + eplb_state=self.eplb_state, + renormalize=renormalize, + use_grouped_topk=use_grouped_topk, + num_expert_group=num_expert_group, + topk_group=topk_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + num_fused_shared_experts=self.num_fused_shared_experts, + enable_eplb=enable_eplb, + # TODO(bnell): once we can construct the MK at init time, we + # can make this a value. + indices_type_getter=lambda: self.quant_method.topk_indices_dtype, + capture=capture, + ) + self.routing_method_type: RoutingMethodType = self.router.routing_method_type + self.moe_config: FusedMoEConfig = FusedMoEConfig( num_experts=self.global_num_experts, experts_per_token=top_k, hidden_dim=hidden_size, + intermediate_size_per_partition=self.intermediate_size_per_partition, num_local_experts=self.local_num_experts, moe_parallel_config=self.moe_parallel_config, in_dtype=moe_in_dtype, @@ -531,6 +563,9 @@ class FusedMoE(CustomOp): has_bias=has_bias, is_act_and_mul=is_act_and_mul, is_lora_enabled=vllm_config.lora_config is not None, + activation=activation, + device=vllm_config.device_config.device, + routing_method=self.routing_method_type, ) self.moe_config_use_flashinfer_cutlass_kernels = ( self.moe_config.use_flashinfer_cutlass_kernels @@ -594,39 +629,6 @@ class FusedMoE(CustomOp): self.batched_hidden_states: torch.Tensor | None = None self.batched_router_logits: torch.Tensor | None = None - # TODO(bnell): in next PR move capture back to layer - capture: Callable[[torch.Tensor], None] | None = None - if ( - self.vllm_config.model_config is not None - and self.vllm_config.model_config.enable_return_routed_experts - 
): - # In dummy runs, the capturer is not initialized. - capturer = RoutedExpertsCapturer.get_instance() - if capturer is not None: - capture = lambda topk_ids: capturer.capture(self.layer_id, topk_ids) - - self.router = create_fused_moe_router( - top_k=top_k, - global_num_experts=self.global_num_experts, - eplb_state=self.eplb_state, - renormalize=renormalize, - use_grouped_topk=use_grouped_topk, - num_expert_group=num_expert_group, - topk_group=topk_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - routed_scaling_factor=routed_scaling_factor, - e_score_correction_bias=e_score_correction_bias, - num_fused_shared_experts=self.num_fused_shared_experts, - enable_eplb=enable_eplb, - # TODO(bnell): once we can construct the MK at init time, we - # can make this a value. - indices_type_getter=lambda: self.quant_method.topk_indices_dtype, - routing_method_type=routing_method_type, - capture=capture, - ) - self.routing_method_type: RoutingMethodType = self.router.routing_method_type - # Note: maybe_init_modular_kernel should only be called by # prepare_communication_buffer_for_model. 
# This is called after all weight loading and post-processing, so it diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 962d0fe78..6aff6401e 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -13,6 +13,7 @@ import vllm.envs as envs from vllm.forward_context import get_forward_context, is_forward_context_available from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, ) @@ -22,6 +23,9 @@ from vllm.model_executor.layers.fused_moe.utils import ( count_expert_num_tokens, disable_inplace, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, +) from vllm.utils.math_utils import cdiv from vllm.v1.worker.ubatching import ( dbo_enabled, @@ -374,18 +378,51 @@ class FusedMoEPermuteExpertsUnpermute(ABC): def __init__( self, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, + max_num_tokens: int | None = None, + num_dispatchers: int | None = None, ): """ + moe_config: MoE layer configuration. quant_config: Quantization parameters for this experts instance. """ - self.quant_config = quant_config + if self.activation_format() == FusedMoEActivationFormat.Standard and ( + max_num_tokens is not None or num_dispatchers is not None + ): + raise ValueError( + "max_num_tokens and num_dispatchers should only be set for " + "BatchedExperts activation format." + ) + elif self.activation_format() == FusedMoEActivationFormat.BatchedExperts and ( + max_num_tokens is None or num_dispatchers is None + ): + raise ValueError( + "max_num_tokens and num_dispatchers must be set for " + "BatchedExperts activation format." 
+ ) - @property + self.moe_config = moe_config + self.quant_config = quant_config + self.max_num_tokens = max_num_tokens + self.num_dispatchers = num_dispatchers + + @staticmethod + def expects_unquantized_inputs( + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig + ) -> bool: + """ + Whether or not the PrepareFinalize should defer input quantization + in the prepare step. If True, then the Experts kernel will + execute the input quantization itself. + + Sample subclasses that override are AITER and FlashInfer CUTLASS. + """ + return False + + @staticmethod @abstractmethod - def activation_formats( - self, - ) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]: + def activation_format() -> FusedMoEActivationFormat: """ A property which is a tuple of the input and output activation formats for the 'apply' method. @@ -435,6 +472,78 @@ class FusedMoEPermuteExpertsUnpermute(ABC): return E, M, N, K, topk + # + # Various helpers for registering support for various features. + # Used by the oracle to select a particular kernel for a deployment. 
+ # + + @staticmethod + def is_supported_config( + cls: type["FusedMoEPermuteExpertsUnpermute"], + moe_config: FusedMoEConfig, + weight_key: QuantKey | None, + activation_key: QuantKey | None, + activation_format: FusedMoEActivationFormat, + ) -> tuple[bool, str | None]: + def _make_reason(reason: str) -> str: + return f"kernel does not support {reason}" + + if not cls._supports_current_device(): + return False, _make_reason("current device") + elif not (moe_config.is_act_and_mul or cls._supports_no_act_and_mul()): + return False, _make_reason("no act_and_mul MLP layer") + elif not cls._supports_activation(moe_config.activation): + return False, _make_reason(f"{moe_config.activation} activation") + elif not cls._supports_quant_scheme(weight_key, activation_key): + return False, _make_reason("quantization scheme") + elif not cls._supports_parallel_config(moe_config.moe_parallel_config): + return False, _make_reason("parallel config") + elif activation_format != cls.activation_format(): + return False, _make_reason(f"{activation_format.value} activation format") + return True, None + + @staticmethod + @abstractmethod + def _supports_current_device() -> bool: + """ + Whether the kernel supports the current device type + (compute capability and current platform). + """ + raise NotImplementedError + + @staticmethod + @abstractmethod + def _supports_no_act_and_mul() -> bool: + """ + Whether the kernel supports act_and_mul=False, i.e. + non-gated MoE models like Nemotron-Nano. + """ + raise NotImplementedError + + @staticmethod + @abstractmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + raise NotImplementedError + + @staticmethod + @abstractmethod + def _supports_activation(activation: str) -> bool: + """ + Whether the kernel supports a particular act function. 
+ """ + raise NotImplementedError + + @staticmethod + @abstractmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + """ + Whether the kernel supports deployment in expert parallel. + """ + raise NotImplementedError + # # Various helpers for accessing quantization parameters from the # quant_config. @@ -715,12 +824,12 @@ class FusedMoEModularKernel(torch.nn.Module): self._post_init_setup() assert ( - prepare_finalize.activation_format == fused_experts.activation_formats[0] + prepare_finalize.activation_format == fused_experts.activation_format() ), ( f"{prepare_finalize.__class__.__name__}." f"{prepare_finalize.activation_format} == " f"{fused_experts.__class__.__name__}." - f"{fused_experts.activation_formats[0]}" + f"{fused_experts.activation_format()}" ) def _post_init_setup(self): diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index 6872b542f..cdf2d291b 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -14,6 +14,12 @@ from vllm.model_executor.layers.fused_moe.config import ( fp8_w8a8_moe_quant_config, fp8_w8a16_moe_quant_config, ) +from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import ( + is_supported_config_trtllm, +) +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP, +) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( FlashinferMoeBackend, get_flashinfer_moe_backend, @@ -26,133 +32,307 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( prepare_fp8_moe_layer_for_marlin, ) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - cutlass_group_gemm_supported, +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, ) -from vllm.platforms import 
current_platform -from vllm.utils.deep_gemm import is_deep_gemm_supported -from vllm.utils.flashinfer import has_flashinfer_moe -from vllm.utils.import_utils import has_deep_gemm logger = init_logger(__name__) class Fp8MoeBackend(Enum): - NONE = 0 - FLASHINFER_TRTLLM = 1 - FLASHINFER_CUTLASS = 2 - DEEPGEMM = 3 - MARLIN = 4 - TRITON = 5 - AITER = 6 - VLLM_CUTLASS = 7 + NONE = "NONE" + FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM" + FLASHINFER_CUTLASS = "FLASHINFER_CUTLASS" + DEEPGEMM = "DEEPGEMM" + BATCHED_DEEPGEMM = "BATCHED_DEEPGEMM" + MARLIN = "MARLIN" + TRITON = "TRITON" + BATCHED_TRITON = "BATCHED_TRITON" + AITER = "AITER" + VLLM_CUTLASS = "VLLM_CUTLASS" + BATCHED_VLLM_CUTLASS = "BATCHED_VLLM_CUTLASS" + + +def backend_to_kernel_cls( + backend: Fp8MoeBackend, +) -> type[mk.FusedMoEPermuteExpertsUnpermute]: + if backend == Fp8MoeBackend.FLASHINFER_TRTLLM: + raise NotImplementedError + + elif backend == Fp8MoeBackend.FLASHINFER_CUTLASS: + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + FlashInferExperts, + ) + + return FlashInferExperts + + elif backend == Fp8MoeBackend.DEEPGEMM: + from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( + TritonOrDeepGemmExperts, + ) + + return TritonOrDeepGemmExperts + + elif backend == Fp8MoeBackend.BATCHED_DEEPGEMM: + from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( + BatchedDeepGemmExperts, + ) + + return BatchedDeepGemmExperts + + elif backend == Fp8MoeBackend.MARLIN: + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + MarlinExperts, + ) + + return MarlinExperts + + elif backend == Fp8MoeBackend.TRITON: + from vllm.model_executor.layers.fused_moe.fused_moe import ( + TritonExperts, + ) + + return TritonExperts + + elif backend == Fp8MoeBackend.BATCHED_TRITON: + from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedTritonExperts, + ) + + return BatchedTritonExperts + + elif backend == Fp8MoeBackend.AITER: + from 
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + AiterExperts, + ) + + return AiterExperts + + elif backend == Fp8MoeBackend.VLLM_CUTLASS: + from vllm.model_executor.layers.fused_moe.triton_cutlass_moe import ( + TritonOrCutlassExperts, + ) + + return TritonOrCutlassExperts + + elif backend == Fp8MoeBackend.BATCHED_VLLM_CUTLASS: + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + CutlassBatchedExpertsFp8, + ) + + return CutlassBatchedExpertsFp8 + + else: + raise ValueError(f"Unknown FP8 MoE backend: {backend.value}") def select_fp8_moe_backend( - block_quant: bool, - tp_size: int, - with_lora_support: bool, - is_act_and_mul: bool, + config: FusedMoEConfig, + weight_key: QuantKey | None, + activation_key: QuantKey | None, allow_vllm_cutlass: bool = False, -) -> Fp8MoeBackend: +) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]: """ Select the primary FP8 MoE backend Note: Shape-specific fallbacks may still occur at runtime. """ - # TODO(rob): in a future PR, we will query each mk for - # supported features and return the mk directly, just like - # we do for the Attention Backend. + k_cls: type[mk.FusedMoEPermuteExpertsUnpermute] | None = None - if with_lora_support: - return Fp8MoeBackend.TRITON + if config.is_lora_enabled: + return Fp8MoeBackend.TRITON, backend_to_kernel_cls(Fp8MoeBackend.TRITON) - def _make_log_backend(backend_name: str): - return f"Using {backend_name} backend for FP8 MoE" + # NOTE: the kernels are selected in the following order. + AVAILABLE_BACKENDS = [ + Fp8MoeBackend.AITER, + Fp8MoeBackend.FLASHINFER_TRTLLM, + Fp8MoeBackend.FLASHINFER_CUTLASS, + Fp8MoeBackend.DEEPGEMM, + Fp8MoeBackend.BATCHED_DEEPGEMM, + Fp8MoeBackend.VLLM_CUTLASS, + Fp8MoeBackend.BATCHED_VLLM_CUTLASS, + Fp8MoeBackend.TRITON, + Fp8MoeBackend.BATCHED_TRITON, + Fp8MoeBackend.MARLIN, + ] - # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100. 
- if ( - current_platform.is_cuda() - and ( - current_platform.is_device_capability_family(100) - or current_platform.is_device_capability(90) - ) - and envs.VLLM_USE_FLASHINFER_MOE_FP8 - and has_flashinfer_moe() - ): - backend = get_flashinfer_moe_backend() - if backend == FlashinferMoeBackend.TENSORRT_LLM: - logger.info_once(_make_log_backend("FlashInfer TRTLLM")) - if not is_act_and_mul: - raise ValueError( - "FlashInfer TRTLLM FP8 MoE backend only supports " - "act_and_mul gate_up_project fusion. Please set " - "VLLM_USE_FLASHINFER_MOE_FP8=throughput to use the " - "FlashInfer CUTLASS backend instead." - ) - return Fp8MoeBackend.FLASHINFER_TRTLLM - else: - if block_quant and current_platform.is_device_capability_family(100): - raise ValueError( - "FlashInfer FP8 MoE throughput backend does not " - "support block quantization on SM100. Please use " - "VLLM_FLASHINFER_MOE_BACKEND=latency to use the " - "FlashInfer TRTLLM backend instead." - ) - logger.info_once(_make_log_backend("FlashInfer CUTLASS")) - return Fp8MoeBackend.FLASHINFER_CUTLASS + # NOTE(rob): We need to peek into the P/F selection to determine + # if we are using the batched or standard expert format, which + # is not ideal. Once we unify TP + DP/EP, we can select P/F first. 
+ activation_format = ( + mk.FusedMoEActivationFormat.BatchedExperts + if config.moe_parallel_config.use_batched_activation_format + else mk.FusedMoEActivationFormat.Standard + ) - # weight-only path for older GPUs without native FP8 - if ( - current_platform.is_cuda() and not current_platform.has_device_capability(89) - ) or envs.VLLM_TEST_FORCE_FP8_MARLIN: - logger.info_once(_make_log_backend("Marlin"), scope="local") - return Fp8MoeBackend.MARLIN - - # Determine if we should use DeepGEMM with block-quantized weights: - # - If explicitly set by user, respect their choice - # - If not explicitly set (default), disable when TP size is >= 8 - moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM - if not envs.is_set("VLLM_MOE_USE_DEEP_GEMM") and tp_size >= 8: - moe_use_deep_gemm = False - logger.info_once( - "DeepGEMM MoE is disabled by default when TP size is >= 8. " - "Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.", - scope="local", + def _make_log_backend(backend: Fp8MoeBackend): + available_backend_strs = [b.value for b in AVAILABLE_BACKENDS] + return ( + f"Using {backend.value} Fp8 MoE backend out " + f"of potential backends: {available_backend_strs}." ) - use_deep_gemm = envs.VLLM_USE_DEEP_GEMM - if not is_deep_gemm_supported(): - use_deep_gemm = False - logger.info_once( - "DeepGEMM is disabled because the platform does not support it.", - scope="local", - ) - - if use_deep_gemm and moe_use_deep_gemm and block_quant and is_act_and_mul: - if not has_deep_gemm(): - logger.warning_once( - "DeepGEMM backend requested but not available.", scope="local" + def _make_log_unsupported(backend: Fp8MoeBackend, reason: str | None) -> str: + if reason: + return ( + f"FP8 MoE backend {backend.value} does not support the " + f"deployment configuration since {reason}." + ) + else: + return ( + f"FP8 MoE backend '{backend.value}' does not support the " + "deployment configuration." 
) - elif is_deep_gemm_supported(): - logger.info_once(_make_log_backend("DeepGEMM"), scope="local") - return Fp8MoeBackend.DEEPGEMM - if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE: - logger.info_once(_make_log_backend("ROCm AITER"), scope="local") - return Fp8MoeBackend.AITER + def _return_or_raise( + backend: Fp8MoeBackend, + config: FusedMoEConfig, + weight_key: QuantKey | None, + activation_key: QuantKey | None, + activation_format: mk.FusedMoEActivationFormat, + ) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]: + k_cls = backend_to_kernel_cls(backend) + supported, reason = k_cls.is_supported_config( + k_cls, config, weight_key, activation_key, activation_format + ) + if supported: + logger.info_once(_make_log_backend(backend), scope="local") + return backend, k_cls + raise ValueError(_make_log_unsupported(backend, reason)) - if ( - allow_vllm_cutlass - and not block_quant - and cutlass_group_gemm_supported() - and is_act_and_mul - ): - logger.info_once(_make_log_backend("vLLM CUTLASS"), scope="local") - return Fp8MoeBackend.VLLM_CUTLASS + # Handle explicit FlashInfer FP8 configuration. + if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP8"): + if not envs.VLLM_USE_FLASHINFER_MOE_FP8: + # If the user rejects FlashInfer remove those backends. + AVAILABLE_BACKENDS.remove(Fp8MoeBackend.FLASHINFER_TRTLLM) + AVAILABLE_BACKENDS.remove(Fp8MoeBackend.FLASHINFER_CUTLASS) - # default to Triton - logger.info_once(_make_log_backend("Triton"), scope="local") - return Fp8MoeBackend.TRITON + elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"): + # If user is explicit about backend, validate it. 
+ fi_backend = get_flashinfer_moe_backend() + + if fi_backend == FlashinferMoeBackend.TENSORRT_LLM: + backend = Fp8MoeBackend.FLASHINFER_TRTLLM + supported, reason = is_supported_config_trtllm( + config, weight_key, activation_key, activation_format + ) + if supported: + logger.info_once(_make_log_backend(backend)) + return backend, None + else: + raise ValueError(_make_log_unsupported(backend, reason)) + + elif fi_backend == FlashinferMoeBackend.CUTLASS: + backend = Fp8MoeBackend.FLASHINFER_CUTLASS + return _return_or_raise( + backend, config, weight_key, activation_key, activation_format + ) + + else: + assert fi_backend == FlashinferMoeBackend.CUTEDSL + raise ValueError("FlashInfer MaskedGEMM not supported for FP8") + + else: + # If the user is not explicit about the backend, try both. + for backend in [ + Fp8MoeBackend.FLASHINFER_TRTLLM, + Fp8MoeBackend.FLASHINFER_CUTLASS, + ]: + if backend == Fp8MoeBackend.FLASHINFER_TRTLLM: + k_cls = None + supported, reason = is_supported_config_trtllm( + config, + weight_key, + activation_key, + activation_format, + ) + else: + k_cls = backend_to_kernel_cls(backend) + supported, reason = k_cls.is_supported_config( + k_cls, + config, + weight_key, + activation_key, + activation_format, + ) + + if supported: + logger.info_once(_make_log_backend(backend), scope="local") + return backend, k_cls + else: + logger.debug_once( + _make_log_unsupported(backend, reason), scope="local" + ) + + raise NotImplementedError( + "Found VLLM_USE_FLASHINFER_MOE_FP8=1, but no " + "FlashInfer FP8 MoE backend supports the configuration." + ) + + # Handle explicit DeepGEMM FP8 configuration. 
+ if envs.is_set("VLLM_USE_DEEP_GEMM") or envs.is_set("VLLM_MOE_USE_DEEP_GEMM"): + if not envs.VLLM_USE_DEEP_GEMM or not envs.VLLM_MOE_USE_DEEP_GEMM: + AVAILABLE_BACKENDS.remove(Fp8MoeBackend.DEEPGEMM) + AVAILABLE_BACKENDS.remove(Fp8MoeBackend.BATCHED_DEEPGEMM) + else: + backend = ( + Fp8MoeBackend.DEEPGEMM + if activation_format == mk.FusedMoEActivationFormat.Standard + else Fp8MoeBackend.BATCHED_DEEPGEMM + ) + return _return_or_raise( + backend, config, weight_key, activation_key, activation_format + ) + + # Handle explicit MARLIN FP8 configuration. + if envs.VLLM_TEST_FORCE_FP8_MARLIN: + backend = Fp8MoeBackend.MARLIN + return _return_or_raise( + backend, config, weight_key, activation_key, activation_format + ) + + # Handle explicit AITER FP8 configuration. + if envs.is_set("VLLM_ROCM_USE_AITER") or envs.is_set("VLLM_ROCM_USE_AITER_MOE"): + if not envs.VLLM_ROCM_USE_AITER or not envs.VLLM_ROCM_USE_AITER_MOE: + AVAILABLE_BACKENDS.remove(Fp8MoeBackend.AITER) + else: + backend = Fp8MoeBackend.AITER + return _return_or_raise( + backend, config, weight_key, activation_key, activation_format + ) + + if not allow_vllm_cutlass: + AVAILABLE_BACKENDS.remove(Fp8MoeBackend.VLLM_CUTLASS) + AVAILABLE_BACKENDS.remove(Fp8MoeBackend.BATCHED_VLLM_CUTLASS) + + # Select kernels in order of backend. + for backend in AVAILABLE_BACKENDS: + if backend == Fp8MoeBackend.FLASHINFER_TRTLLM: + k_cls = None + supported, reason = is_supported_config_trtllm( + config, + weight_key, + activation_key, + activation_format, + ) + else: + k_cls = backend_to_kernel_cls(backend) + supported, reason = k_cls.is_supported_config( + k_cls, + config, + weight_key, + activation_key, + activation_format, + ) + + if supported: + logger.info_once(_make_log_backend(backend), scope="local") + return backend, k_cls + else: + logger.debug_once(_make_log_unsupported(backend, reason), scope="local") + + raise NotImplementedError( + "No FP8 MoE backend supports the deployment configuration." 
+ ) def convert_to_fp8_moe_kernel_format( @@ -166,7 +346,7 @@ def convert_to_fp8_moe_kernel_format( w2_input_scale: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: block_quant = hasattr(layer, "weight_block_size") - if fp8_backend == Fp8MoeBackend.DEEPGEMM: + if fp8_backend in [Fp8MoeBackend.DEEPGEMM, Fp8MoeBackend.BATCHED_DEEPGEMM]: assert block_quant w13, w2, w13_scale, w2_scale = prepare_fp8_moe_layer_for_deepgemm( w13, @@ -199,6 +379,14 @@ def convert_to_fp8_moe_kernel_format( w2_input_scale=w2_input_scale, is_trtllm=(fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM), ) + else: + if fp8_backend not in [ + Fp8MoeBackend.TRITON, + Fp8MoeBackend.BATCHED_TRITON, + Fp8MoeBackend.VLLM_CUTLASS, + Fp8MoeBackend.BATCHED_VLLM_CUTLASS, + ]: + raise ValueError(f"Unsupported FP8 MoE backend: {fp8_backend.value}") return w13, w2, w13_scale, w2_scale @@ -210,6 +398,8 @@ def make_fp8_moe_quant_config( a1_scale: torch.Tensor | None, a2_scale: torch.Tensor | None, block_shape: list[int] | None = None, + per_act_token_quant: bool = False, + per_out_ch_quant: bool = False, ) -> FusedMoEQuantConfig | None: """ Create FusedMoEQuantConfig for the specifed FP8 Backend. 
@@ -262,102 +452,76 @@ def make_fp8_moe_quant_config( a1_scale=a1_scale, a2_scale=a2_scale, block_shape=block_shape, + per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_out_ch_quant, ) +def make_fp8_moe_kernel_for_mkm( + moe_config: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, + experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], + prepare_finalize: mk.FusedMoEPrepareAndFinalize, +) -> mk.FusedMoEPermuteExpertsUnpermute: + if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts: + max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() + assert max_num_tokens_per_rank is not None + experts = experts_cls( + moe_config=moe_config, + quant_config=quant_config, + max_num_tokens=max_num_tokens_per_rank, + num_dispatchers=prepare_finalize.num_dispatchers(), + ) + else: + experts = experts_cls( + moe_config=moe_config, + quant_config=quant_config, + ) + + logger.debug_once("Using %s", experts.__class__.__name__) + return experts + + def make_fp8_moe_kernel( - layer: torch.nn.Module, moe_quant_config: FusedMoEQuantConfig, moe_config: FusedMoEConfig, fp8_backend: Fp8MoeBackend, + experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], ) -> tuple[mk.FusedMoEModularKernel, bool]: - # Delayed import is required since the oracle is imported - # by CPU backends which cannot import all of these experts. - # TODO: update the experts to make this not happen. - from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, + # TODO(rob): unify after we merge tp and dp/ep. + if ( + moe_config.moe_parallel_config.use_all2all_kernels + and moe_config.moe_parallel_config.all2all_backend + not in ["allgather_reducescatter", "naive"] + ): + raise ValueError( + "Fp8 Oracle should not create non-naive A2A P/F. " + "This should happen via the ModularKernelMethod." + ) + + # Create Prepare/Finalize. 
+ prepare_finalize = MoEPrepareAndFinalizeNoEP( + defer_input_quant=experts_cls.expects_unquantized_inputs( + moe_config, moe_quant_config + ), ) - # NOTE(rob): this is a WIP refactor. We are first migrating - # all of the kernels in the TP case to use mk. Once this is - # done, then we will initialzie the TP case and DP/EP case - # via the same code path (i.e. via maybe_init_modular_kernel). - # NOTE(rob): in progress migrating all into this format. - use_inplace = True - if fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: - from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - FlashInferExperts, - ) + # Create Experts. + experts = experts_cls( + moe_config=moe_config, + quant_config=moe_quant_config, + ) - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP( - defer_input_quant=moe_quant_config.is_block_quantized - ), - FlashInferExperts( - out_dtype=layer.orig_dtype, - quant_config=moe_quant_config, - ep_rank=moe_config.ep_rank, - ep_size=moe_config.ep_size, - tp_rank=moe_config.tp_rank, - tp_size=moe_config.tp_size, - use_dp=(moe_config.dp_size > 1), - use_deepseek_fp8_block_scale=moe_quant_config.is_block_quantized, - ), - ) - use_inplace = False + # NOTE(rob): we only want the mk to control the shared_expert + # if using all2all (for SBO). bnell is making this explicit in + # the new MoE runner class. 
+ kernel = mk.FusedMoEModularKernel( + prepare_finalize, + experts, + shared_experts=None, + moe_parallel_config=moe_config.moe_parallel_config, + ) - elif fp8_backend == Fp8MoeBackend.AITER: - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - AiterExperts, - ) - - kernel = mk.FusedMoEModularKernel( - # TODO: make defer_input_quant an attr of the AiterExperts - MoEPrepareAndFinalizeNoEP(defer_input_quant=True), - AiterExperts(quant_config=moe_quant_config), - ) - elif fp8_backend == Fp8MoeBackend.MARLIN: - from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( - MarlinExperts, - ) - - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), - MarlinExperts(quant_config=moe_quant_config), - ) - elif fp8_backend == Fp8MoeBackend.VLLM_CUTLASS: - from vllm.model_executor.layers.fused_moe.triton_cutlass_moe import ( - TritonOrCutlassExperts, - ) - - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), - TritonOrCutlassExperts( - out_dtype=moe_config.in_dtype, - e=layer.local_num_experts, - n=layer.intermediate_size_per_partition, - k=layer.hidden_size, - device=layer.w13_weight.device, - quant_config=moe_quant_config, - ), - ) - elif fp8_backend == Fp8MoeBackend.DEEPGEMM: - from vllm.model_executor.layers.fused_moe import ( - TritonOrDeepGemmExperts, - ) - - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), - TritonOrDeepGemmExperts(quant_config=moe_quant_config), - ) - else: - from vllm.model_executor.layers.fused_moe.fused_moe import ( - TritonExperts, - ) - - assert fp8_backend == Fp8MoeBackend.TRITON - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), - TritonExperts(quant_config=moe_quant_config), - ) - return kernel, use_inplace + # TODO(rob): update inplace logic to be part of the kernel. 
+ inplace = fp8_backend != Fp8MoeBackend.FLASHINFER_CUTLASS + return kernel, inplace diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py index f2d69cf09..897631e23 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py @@ -14,21 +14,11 @@ from vllm.model_executor.layers.fused_moe.config import ( nvfp4_moe_quant_config, nvfp4_w4a16_moe_quant_config, ) -from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - CutlassExpertsFp4, -) -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - FlashInferExperts, -) -from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( - MarlinExperts, -) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP, ) from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - is_flashinfer_fp4_cutedsl_moe_available, - is_flashinfer_fp4_cutlass_moe_available, + is_supported_config_trtllm, prepare_nvfp4_moe_layer_for_fi_or_cutlass, ) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( @@ -36,27 +26,26 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( get_flashinfer_moe_backend, ) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - is_fp4_marlin_supported, prepare_nvfp4_moe_layer_for_marlin, ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - cutlass_fp4_supported, + QuantKey, ) logger = init_logger(__name__) class NvFp4MoeBackend(Enum): - FLASHINFER_CUTLASS = "FlashInfer CUTLASS" - FLASHINFER_TRTLLM = "FlashInfer TRTLLM" - FLASHINFER_CUTEDSL = "FlashInfer CUTEDSL" - VLLM_CUTLASS = "vLLM CUTASS" - MARLIN = "vLLM MARLIN" + FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM" + FLASHINFER_CUTLASS = "FLASHINFER_CUTLASS" + FLASHINFER_CUTEDSL = "FLASHINFER_CUTEDSL" + VLLM_CUTLASS = "VLLM_CUTLASS" + MARLIN = "MARLIN" 
FLASHINFER_NVFP4_MOE_BACKENDS = [ - NvFp4MoeBackend.FLASHINFER_CUTLASS, NvFp4MoeBackend.FLASHINFER_TRTLLM, + NvFp4MoeBackend.FLASHINFER_CUTLASS, NvFp4MoeBackend.FLASHINFER_CUTEDSL, ] @@ -72,44 +61,208 @@ def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool: # of all experts in Expert Parallel Mode when all experts are not # on the same rank. - return backend in [ - NvFp4MoeBackend.FLASHINFER_CUTLASS, + return backend in FLASHINFER_NVFP4_MOE_BACKENDS + + +def backend_to_kernel_cls( + backend: NvFp4MoeBackend, +) -> type[mk.FusedMoEPermuteExpertsUnpermute]: + if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: + raise NotImplementedError( + "FLASHINFER_TRTLLM doesn't support Modular Kernel Interface" + ) + + elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS: + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + FlashInferExperts, + ) + + return FlashInferExperts + + elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL: + from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import ( + FlashInferCuteDSLExperts, + ) + + return FlashInferCuteDSLExperts + + elif backend == NvFp4MoeBackend.VLLM_CUTLASS: + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + CutlassExpertsFp4, + ) + + return CutlassExpertsFp4 + + elif backend == NvFp4MoeBackend.MARLIN: + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + MarlinExperts, + ) + + return MarlinExperts + else: + raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}") + + +def select_nvfp4_moe_backend( + config: FusedMoEConfig, + weight_key: QuantKey | None, + activation_key: QuantKey | None, +) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]: + """ + Select the primary NvFP4 MoE backend + Note: Shape-specific fallbacks may still occur at runtime. + """ + + # NOTE: the kernels are selected in the following order. 
+ AVAILABLE_BACKENDS = [ NvFp4MoeBackend.FLASHINFER_TRTLLM, + NvFp4MoeBackend.FLASHINFER_CUTEDSL, + NvFp4MoeBackend.FLASHINFER_CUTLASS, + NvFp4MoeBackend.VLLM_CUTLASS, + NvFp4MoeBackend.MARLIN, ] + # NOTE(rob): this is kind of a hack. We need to peek into + # the prepare-finalize selection to determine if we are using + # the batched or standard expert format. + use_batched = ( + config.moe_parallel_config.use_deepep_ll_kernels + or config.moe_parallel_config.use_pplx_kernels + ) + activation_format = ( + mk.FusedMoEActivationFormat.BatchedExperts + if use_batched + else mk.FusedMoEActivationFormat.Standard + ) -def select_nvfp4_moe_backend() -> NvFp4MoeBackend: def _make_log_backend(backend: NvFp4MoeBackend): - return f"Using {backend.value} backend for NvFp4 MoE" - - if cutlass_fp4_supported() and not envs.VLLM_TEST_FORCE_FP8_MARLIN: - allow_flashinfer = ( - is_flashinfer_fp4_cutlass_moe_available() - or is_flashinfer_fp4_cutedsl_moe_available() + available_backend_strs = [b.value for b in AVAILABLE_BACKENDS] + return ( + f"Using '{backend.value}' NvFp4 MoE backend out " + f"of potential backends: {available_backend_strs}." ) - if allow_flashinfer and envs.VLLM_USE_FLASHINFER_MOE_FP4: - backend = fi_2_vllm_backend_map[get_flashinfer_moe_backend()] + + def _make_log_unsupported(backend: NvFp4MoeBackend, reason: str | None) -> str: + if reason: + return ( + f"NvFp4 MoE backend '{backend.value}' does not support the " + f"deployment configuration since {reason}." + ) else: - backend = NvFp4MoeBackend.VLLM_CUTLASS - elif is_fp4_marlin_supported(): - backend = NvFp4MoeBackend.MARLIN - else: - raise ValueError("No NvFp4 kernel backend available for NvFp4 MoE.") + return ( + f"NvFp4 MoE backend '{backend.value}' does not support the " + "deployment configuration." + ) - # Log warning if FI backend requested but not available. 
- if ( - backend not in FLASHINFER_NVFP4_MOE_BACKENDS - and envs.VLLM_USE_FLASHINFER_MOE_FP4 - ): - logger.warning_once( - "Requested FlashInfer backend for NvFp4 MoE, but it's not available. " - "Falling back to %s for NvFp4 MoE", - backend.value, - scope="local", + def _return_or_raise( + backend: NvFp4MoeBackend, + config: FusedMoEConfig, + weight_key: QuantKey | None, + activation_key: QuantKey | None, + activation_format: mk.FusedMoEActivationFormat, + ) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]: + k_cls = backend_to_kernel_cls(backend) + supported, reason = k_cls.is_supported_config( + k_cls, config, weight_key, activation_key, activation_format ) - else: - logger.info_once(_make_log_backend(backend), scope="local") - return backend + if supported: + logger.info_once(_make_log_backend(backend)) + return backend, k_cls + raise ValueError(_make_log_unsupported(backend, reason)) + + if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP4"): + if not envs.VLLM_USE_FLASHINFER_MOE_FP4: + # If the user rejects FlashInfer remove those backends. + for b in FLASHINFER_NVFP4_MOE_BACKENDS: + AVAILABLE_BACKENDS.remove(b) + + elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"): + # If user is explicit about backend, validate it. + fi_backend = get_flashinfer_moe_backend() + + if fi_backend == FlashinferMoeBackend.TENSORRT_LLM: + backend = NvFp4MoeBackend.FLASHINFER_TRTLLM + supported, reason = is_supported_config_trtllm( + config, weight_key, activation_key, activation_format + ) + if supported: + logger.info_once(_make_log_backend(backend)) + return backend, None + else: + raise ValueError(_make_log_unsupported(backend, reason)) + else: + backend = fi_2_vllm_backend_map[fi_backend] + return _return_or_raise( + backend, config, weight_key, activation_key, activation_format + ) + else: + # If the user is not explicit about the backend, try each. 
+ for backend in FLASHINFER_NVFP4_MOE_BACKENDS: + if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: + k_cls = None + supported, reason = is_supported_config_trtllm( + config, + weight_key, + activation_key, + activation_format, + ) + else: + k_cls = backend_to_kernel_cls(backend) + supported, reason = k_cls.is_supported_config( + k_cls, + config, + weight_key, + activation_key, + activation_format, + ) + if supported: + logger.info_once(_make_log_backend(backend), scope="local") + return backend, None + else: + logger.debug_once( + _make_log_unsupported(backend, reason), scope="local" + ) + + raise NotImplementedError( + "Found VLLM_USE_FLASHINFER_MOE_FP4=1, but no " + "FlashInfer NVFP4 MoE backend supports the configuration." + ) + + if envs.VLLM_TEST_FORCE_FP8_MARLIN: + backend = NvFp4MoeBackend.MARLIN + return _return_or_raise( + backend, config, weight_key, activation_key, activation_format + ) + + # Select kernels in order of backend. + for backend in AVAILABLE_BACKENDS: + if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: + k_cls = None # type: ignore[assignment] + supported, reason = is_supported_config_trtllm( + config, + weight_key, + activation_key, + activation_format, + ) + else: + k_cls = backend_to_kernel_cls(backend) + supported, reason = k_cls.is_supported_config( + k_cls, + config, + weight_key, + activation_key, + activation_format, + ) + + if supported: + logger.info_once(_make_log_backend(backend), scope="local") + return backend, k_cls + else: + logger.debug_once(_make_log_unsupported(backend, reason), scope="local") + + raise NotImplementedError( + "No NvFp4 MoE backend supports the deployment configuration." 
+ ) def convert_to_nvfp4_moe_kernel_format( @@ -238,55 +391,69 @@ def make_nvfp4_moe_quant_config( ) -def make_nvfp4_moe_kernel( - backend: NvFp4MoeBackend, - quant_config: FusedMoEQuantConfig, +def make_nvfp4_moe_kernel_for_mkm( moe_config: FusedMoEConfig, -) -> mk.FusedMoEModularKernel | None: - assert moe_config.dp_size == 1 - - UNSUPPORTED_BACKENDS = [ - # TRTLLM does not use the modular kernl abstraction. - NvFp4MoeBackend.FLASHINFER_TRTLLM, - # CUTEDSL is used with BATCHED (masked) format only. - # TODO: add here once we support dp/ep via the oracle. - NvFp4MoeBackend.FLASHINFER_CUTEDSL, - ] - - if backend in UNSUPPORTED_BACKENDS: - return None - - elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS: - return mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(defer_input_quant=True), - FlashInferExperts( - out_dtype=moe_config.in_dtype, - quant_config=quant_config, - ep_rank=moe_config.ep_rank, - ep_size=moe_config.ep_size, - tp_rank=moe_config.tp_rank, - tp_size=moe_config.tp_size, - use_dp=False, - use_deepseek_fp8_block_scale=False, - ), + quant_config: FusedMoEQuantConfig, + experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], + prepare_finalize: mk.FusedMoEPrepareAndFinalize, +) -> mk.FusedMoEPermuteExpertsUnpermute: + if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts: + max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() + assert max_num_tokens_per_rank is not None + experts = experts_cls( + moe_config=moe_config, + quant_config=quant_config, + max_num_tokens=max_num_tokens_per_rank, + num_dispatchers=prepare_finalize.num_dispatchers(), ) - - elif backend == NvFp4MoeBackend.VLLM_CUTLASS: - return mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(defer_input_quant=True), - CutlassExpertsFp4( - out_dtype=moe_config.in_dtype, - # TODO(rob): see what impact this has on expert map? 
- max_experts_per_worker=moe_config.num_experts, - quant_config=quant_config, - ), - ) - - elif backend == NvFp4MoeBackend.MARLIN: - return mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), - MarlinExperts(quant_config=quant_config), - ) - else: - raise ValueError(f"Unknown NvFp4 MoE backend: {backend}") + experts = experts_cls( + moe_config=moe_config, + quant_config=quant_config, + ) + + logger.debug_once("Using %s", experts.__class__.__name__) + return experts + + +def make_nvfp4_moe_kernel( + moe_quant_config: FusedMoEQuantConfig, + moe_config: FusedMoEConfig, + experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], +) -> mk.FusedMoEModularKernel: + # TODO(rob): unify after we merge tp and dp/ep. + if ( + moe_config.moe_parallel_config.use_all2all_kernels + and moe_config.moe_parallel_config.all2all_backend + not in ["allgather_reducescatter", "naive"] + ): + raise ValueError( + "NvFP4 Oracle should not create non-naive A2A P/F. " + "This should happen via the ModularKernelMethod." + ) + + # Create Prepare/Finalize. + prepare_finalize = MoEPrepareAndFinalizeNoEP( + defer_input_quant=experts_cls.expects_unquantized_inputs( + moe_config, moe_quant_config + ), + ) + + # Create Experts. + experts = experts_cls( + moe_config=moe_config, + quant_config=moe_quant_config, + ) + + # NOTE(rob): we only want the mk to control the shared_expert + # if using all2all (for SBO). bnell is making this explict in + # the new MoE runner class. + kernel = mk.FusedMoEModularKernel( + prepare_finalize, + experts, + shared_experts=None, + moe_parallel_config=moe_config.moe_parallel_config, + ) + + # TODO(rob): update inplace logic to be part of the kernel. 
+ return kernel diff --git a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py index 138bfac28..14c3f84e6 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py +++ b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py @@ -123,7 +123,6 @@ def convert_to_unquantized_kernel_format( def make_unquantized_moe_kernel( - layer: torch.nn.Module, backend: UnquantizedMoeBackend, quant_config: FusedMoEQuantConfig, moe_config: FusedMoEConfig, @@ -141,12 +140,8 @@ def make_unquantized_moe_kernel( kernel = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), FlashInferExperts( - out_dtype=layer.params_dtype, + moe_config=moe_config, quant_config=quant_config, - tp_rank=moe_config.moe_parallel_config.tp_rank, - tp_size=moe_config.moe_parallel_config.tp_size, - ep_rank=moe_config.moe_parallel_config.ep_rank, - ep_size=moe_config.moe_parallel_config.ep_size, ), ) use_inplace = False @@ -157,13 +152,19 @@ def make_unquantized_moe_kernel( kernel = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), - AiterExperts(quant_config), + AiterExperts( + moe_config=moe_config, + quant_config=quant_config, + ), ) elif backend == UnquantizedMoeBackend.TRITON: from vllm.model_executor.layers.fused_moe import TritonExperts kernel = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), - TritonExperts(quant_config), + TritonExperts( + moe_config=moe_config, + quant_config=quant_config, + ), ) return kernel, use_inplace diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index b78794c6b..6d8e5ca79 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -9,11 +9,21 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.layers.fused_moe.config import ( 
FUSED_MOE_UNQUANTIZED_CONFIG, + FusedMoEParallelConfig, FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8Dynamic128Sym, + kFp8DynamicTensorSym, + kFp8DynamicTokenSym, + kFp8Static128BlockSym, + kFp8StaticChannelSym, + kFp8StaticTensorSym, +) class QuantMethod(IntEnum): @@ -269,17 +279,49 @@ def rocm_aiter_fused_experts( class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__(self, quant_config): - super().__init__(quant_config) + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, - ) + @staticmethod + def expects_unquantized_inputs( + fused_moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig + ) -> bool: + # AITER fused MoE kernels handle input quantization internally. + return True + + @staticmethod + def _supports_current_device() -> bool: + return rocm_aiter_ops.is_fused_moe_enabled() + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + # TODO(rob): AITER also supports MXFP4, which is not + # yet supported via an Oracle. Once it is, we will add + # MXFP4 to this list. 
+ SUPPORTED_W_A = [ + (None, None), + (kFp8Static128BlockSym, kFp8Dynamic128Sym), + (kFp8StaticTensorSym, kFp8StaticTensorSym), + (kFp8StaticTensorSym, kFp8DynamicTensorSym), + (kFp8StaticChannelSym, kFp8DynamicTokenSym), + ] + return (weight_key, activation_key) in SUPPORTED_W_A + + @staticmethod + def _supports_activation(activation: str) -> bool: + return activation in ["silu", "gelu"] + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + return True def supports_expert_map(self): return True diff --git a/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py b/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py index a19dfb62b..17c59352f 100644 --- a/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py +++ b/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py @@ -34,6 +34,11 @@ class CustomRoutingRouter(BaseRouter): @property def routing_method_type(self) -> RoutingMethodType: + from vllm.model_executor.models.llama4 import Llama4MoE + + # NOTE: FLASHINFER_TRTLLM support the Llama4 router. 
+ if self.custom_routing_function == Llama4MoE.custom_routing_function: + return RoutingMethodType.Llama4 return RoutingMethodType.Custom def _compute_routing( diff --git a/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py b/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py index e5b6de02f..1c908a2b4 100644 --- a/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py +++ b/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py @@ -261,7 +261,6 @@ class GroupedTopKRouter(BaseRouter): num_fused_shared_experts: int = 0, enable_eplb: bool = False, indices_type_getter: Callable[[], torch.dtype | None] | None = None, - routing_method_type: RoutingMethodType | None = None, ): super().__init__( top_k=top_k, @@ -278,13 +277,12 @@ class GroupedTopKRouter(BaseRouter): self.e_score_correction_bias = e_score_correction_bias self.num_fused_shared_experts = num_fused_shared_experts - # Determine routing method type - if routing_method_type is not None: - self._routing_method_type = routing_method_type - elif scoring_func == "sigmoid": + if scoring_func == "sigmoid": self._routing_method_type = RoutingMethodType.DeepSeekV3 else: - self._routing_method_type = RoutingMethodType.TopK + # NOTE: this prohibits the FLASHINFER_TRTLLM kernels from + # being selected, since they only support DeepSeek-style. 
+ self._routing_method_type = RoutingMethodType.Unspecified @property def routing_method_type(self) -> RoutingMethodType: diff --git a/vllm/model_executor/layers/fused_moe/router/router_factory.py b/vllm/model_executor/layers/fused_moe/router/router_factory.py index 8818373d8..cbe294e6b 100644 --- a/vllm/model_executor/layers/fused_moe/router/router_factory.py +++ b/vllm/model_executor/layers/fused_moe/router/router_factory.py @@ -6,7 +6,6 @@ import torch import vllm.envs as envs from vllm.distributed.eplb.eplb_state import EplbLayerState -from vllm.model_executor.layers.fused_moe.config import RoutingMethodType from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter from vllm.model_executor.layers.fused_moe.router.custom_routing_router import ( CustomRoutingRouter, @@ -36,7 +35,6 @@ def create_fused_moe_router( global_num_experts: int, renormalize: bool = True, indices_type_getter: Callable[[], torch.dtype | None] | None = None, - routing_method_type: RoutingMethodType | None = None, # grouped topk parameters use_grouped_topk: bool = False, num_expert_group: int | None = None, @@ -128,7 +126,6 @@ def create_fused_moe_router( num_fused_shared_experts=num_fused_shared_experts, enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, - routing_method_type=routing_method_type, ) router.capture = capture return router diff --git a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py index 09d5e45c1..f537f2f99 100644 --- a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py @@ -5,7 +5,10 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.cutlass_moe 
import CutlassExpertsFp8 from vllm.model_executor.layers.fused_moe.fallback import FallbackExperts from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts @@ -17,19 +20,22 @@ class TritonOrCutlassExperts(FallbackExperts): def __init__( self, - e: int, - n: int, - k: int, - out_dtype: torch.dtype | None, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, - device: torch.dtype, ): self.is_sm100 = current_platform.has_device_capability(100) super().__init__( - experts=CutlassExpertsFp8(e, n, k, out_dtype, quant_config, device), - fallback_experts=TritonExperts(quant_config), + experts=CutlassExpertsFp8(moe_config, quant_config), + fallback_experts=TritonExperts(moe_config, quant_config), ) + @staticmethod + def get_clses() -> tuple[ + type[mk.FusedMoEPermuteExpertsUnpermute], + type[mk.FusedMoEPermuteExpertsUnpermute], + ]: + return (CutlassExpertsFp8, TritonExperts) + def workspace_shapes( self, M: int, diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 55b1e1211..7e41269dc 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -4,7 +4,10 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( DeepGemmExperts, _valid_deep_gemm, @@ -20,12 +23,19 @@ from vllm.utils.deep_gemm import ( class TritonOrDeepGemmExperts(FallbackExperts): """DeepGemm with fallback to Triton for low latency shapes.""" - def __init__(self, quant_config: FusedMoEQuantConfig): + def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig): super().__init__( - experts=DeepGemmExperts(quant_config), - 
fallback_experts=TritonExperts(quant_config), + experts=DeepGemmExperts(moe_config, quant_config), + fallback_experts=TritonExperts(moe_config, quant_config), ) + @staticmethod + def get_clses() -> tuple[ + type[mk.FusedMoEPermuteExpertsUnpermute], + type[mk.FusedMoEPermuteExpertsUnpermute], + ]: + return (DeepGemmExperts, TritonExperts) + def workspace_shapes( self, M: int, diff --git a/vllm/model_executor/layers/fused_moe/trtllm_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_moe.py index c46f59564..29a3e9003 100644 --- a/vllm/model_executor/layers/fused_moe/trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/trtllm_moe.py @@ -6,37 +6,73 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, + FusedMoEParallelConfig, FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, +) class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - moe: FusedMoEConfig, + moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, max_capture_size, ): - super().__init__(quant_config) - self.moe = moe + super().__init__(moe_config, quant_config) self.gemm1_alpha = gemm1_alpha self.gemm1_beta = gemm1_beta self.gemm1_clamp_limit = gemm1_clamp_limit self.max_capture_size = max_capture_size - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return ( - mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard, + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + @staticmethod + def _supports_current_device() -> bool: + raise NotImplementedError( + "TrtLlmGenExperts is not yet used by an Oracle. 
" + "This method should not be called." + ) + + @staticmethod + def _supports_no_act_and_mul() -> bool: + raise NotImplementedError( + "TrtLlmGenExperts is not yet used by an Oracle. " + "This method should not be called." + ) + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + raise NotImplementedError( + "TrtLlmGenExperts is not yet used by an Oracle. " + "This method should not be called." + ) + + @staticmethod + def _supports_activation(activation: str) -> bool: + raise NotImplementedError( + "TrtLlmGenExperts is not yet used by an Oracle. " + "This method should not be called." + ) + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + raise NotImplementedError( + "TrtLlmGenExperts is not yet used by an Oracle. " + "This method should not be called." ) def supports_chunking(self) -> bool: @@ -86,7 +122,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): topk = topk_ids.size(-1) local_num_experts = w1.size(0) intermediate_size = w2.size(1) - local_expert_offset = self.moe.ep_rank * local_num_experts + local_expert_offset = self.moe_config.ep_rank * local_num_experts x_quant = hidden_states x_scale = a1q_scale diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index f8489ab06..2581a1e56 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -96,13 +96,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ): logger.debug("BatchedTritonExperts %s", self.moe) return BatchedTritonExperts( + moe_config=self.moe, + quant_config=self.moe_quant_config, max_num_tokens=self.moe.max_num_tokens, num_dispatchers=prepare_finalize.num_dispatchers(), - quant_config=self.moe_quant_config, ) else: 
logger.debug("TritonExperts %s", self.moe) - return TritonExperts(self.moe_quant_config) + return TritonExperts( + moe_config=self.moe, + quant_config=self.moe_quant_config, + ) def create_weights( self, @@ -192,7 +196,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): assert self.moe_quant_config is not None self.kernel, self.use_inplace = make_unquantized_moe_kernel( - layer=layer, backend=self.unquantized_backend, quant_config=self.moe_quant_config, moe_config=self.moe, diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index b1fb67208..dfdfc7ea2 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -739,6 +739,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase): return BatchedMarlinExperts( max_num_tokens=max_num_tokens_per_rank, num_dispatchers=prepare_finalize.num_dispatchers(), + moe_config=self.moe, quant_config=self.moe_quant_config, w13_g_idx=w13_g_idx, w2_g_idx=w2_g_idx, @@ -749,6 +750,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase): else: # Standard Marlin experts for AWQ return MarlinExperts( + moe_config=self.moe, quant_config=self.moe_quant_config, w13_g_idx=w13_g_idx, w2_g_idx=w2_g_idx, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 13a123ba6..b48acdcf3 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -19,7 +19,6 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( FusedMoE, FusedMoEActivationFormat, - FusedMoEConfig, FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute, FusedMoERouter, @@ -27,9 +26,9 @@ from vllm.model_executor.layers.fused_moe import ( 
UnquantizedFusedMoEMethod, ) from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, FusedMoEQuantConfig, - fp8_w8a8_moe_quant_config, - fp8_w8a16_moe_quant_config, + RoutingMethodType, int4_w4a16_moe_quant_config, int4_w4afp8_moe_quant_config, int8_w8a8_moe_quant_config, @@ -45,15 +44,17 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( Fp8MoeBackend, convert_to_fp8_moe_kernel_format, make_fp8_moe_kernel, + make_fp8_moe_kernel_for_mkm, + make_fp8_moe_quant_config, select_fp8_moe_backend, ) from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( - FLASHINFER_NVFP4_MOE_BACKENDS, NvFp4MoeBackend, convert_to_nvfp4_moe_kernel_format, is_global_sf_supported_for_nvfp4_backend, make_mxfp4_moe_quant_config, make_nvfp4_moe_kernel, + make_nvfp4_moe_kernel_for_mkm, make_nvfp4_moe_quant_config, select_nvfp4_moe_backend, ) @@ -62,10 +63,12 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compress WNA16_SUPPORTED_TYPES_MAP, ) from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - build_flashinfer_fp4_cutlass_moe_prepare_finalize, flashinfer_trtllm_fp4_moe, flashinfer_trtllm_fp4_routed_moe, - select_nvfp4_gemm_impl, +) +from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + apply_fi_trtllm_fp8_per_tensor_moe, + build_flashinfer_fp8_cutlass_moe_prepare_finalize, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( process_fp8_input_tensor_strategy_moe, @@ -79,12 +82,18 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( marlin_moe_permute_scales, ) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - is_fp4_marlin_supported, prepare_moe_fp4_layer_for_marlin, ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( convert_bf16_scales_to_fp8, convert_packed_uint4b8_to_signed_int4_inplace, + kFp8Dynamic128Sym, + kFp8DynamicTokenSym, + kFp8Static128BlockSym, + kFp8StaticChannelSym, 
+ kFp8StaticTensorSym, + kNvfp4Dynamic, + kNvfp4Static, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( normalize_e4m3fn_to_e4m3fnuz, @@ -200,7 +209,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): f"or None for NVFP4A16, found {input_quant}", ) return CompressedTensorsW4A4Nvfp4MoEMethod( - layer.moe_config, layer_name, use_marlin=input_quant is None + layer.moe_config, layer_name, use_a16=(input_quant is None) ) elif ( quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) @@ -234,6 +243,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): super().__init__(moe) self.group_size = 32 self.mxfp4_backend = NvFp4MoeBackend.MARLIN + self.experts_cls = MarlinExperts self.kernel: mk.FusedMoEModularKernel | None = None def create_weights( @@ -327,9 +337,9 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): self.moe_quant_config = self.get_fused_moe_quant_config(layer) if self.moe_quant_config is not None: self.kernel = make_nvfp4_moe_kernel( - backend=self.mxfp4_backend, - quant_config=self.moe_quant_config, + moe_quant_config=self.moe_quant_config, moe_config=self.moe, + experts_cls=self.experts_cls, ) def apply( @@ -368,34 +378,30 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): self, moe: FusedMoEConfig, layer_name: str | None = None, - use_marlin: bool = False, + use_a16: bool = False, ): super().__init__(moe) self.group_size = 16 - if use_marlin: - if is_fp4_marlin_supported(): - self.nvfp4_backend = NvFp4MoeBackend.MARLIN - else: - raise ValueError( - "Marlin FP4 MoE kernel requested but not ", - "supported on current platform.", - ) - else: - self.nvfp4_backend = select_nvfp4_moe_backend() - # TODO: move this type of check into the oracle. 
- if not self.moe.is_act_and_mul and self.nvfp4_backend not in [ - NvFp4MoeBackend.FLASHINFER_CUTLASS, - NvFp4MoeBackend.MARLIN, - ]: - raise NotImplementedError( - "Non-gated activations are only supported by FlashInfer " - f"CUTLASS and Marlin NvFP4 MoE backends, not {self.nvfp4_backend}." - ) + # Select experts implementation. + self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend( + config=self.moe, + weight_key=kNvfp4Static, + activation_key=None if use_a16 else kNvfp4Dynamic, + ) + + # Delay creation of the kernel until after process-weights. + self.kernel: mk.FusedMoEModularKernel | None = None + self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( self.nvfp4_backend ) - self.kernel: mk.FusedMoEModularKernel | None = None + + @property + def topk_indices_dtype(self) -> torch.dtype | None: + if self.kernel is not None: + return self.kernel.prepare_finalize.topk_indices_dtype() + return None def create_weights( self, @@ -571,35 +577,40 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): layer.w13_input_scale = a13_scale layer.w2_input_scale = a2_scale - # Initialize the kernel that will be called in apply(). + # Setup modular kernel for TP case and naive DP/EP case. + # In non-naive DP/EP case, we will create a ModularKernelMethod. + # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel + # in both cases. 
self.moe_quant_config = self.get_fused_moe_quant_config(layer) - use_dp = self.moe.dp_size > 1 - if self.moe_quant_config is not None and not use_dp: + if self.moe_quant_config and ( + (not self.moe.moe_parallel_config.use_all2all_kernels) + or self.moe.moe_parallel_config.use_naive_all2all_kernels + ): + assert self.experts_cls is not None self.kernel = make_nvfp4_moe_kernel( - backend=self.nvfp4_backend, - quant_config=self.moe_quant_config, + moe_quant_config=self.moe_quant_config, moe_config=self.moe, + experts_cls=self.experts_cls, ) def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: - UNSUPPORTED = [NvFp4MoeBackend.MARLIN, NvFp4MoeBackend.FLASHINFER_TRTLLM] - if self.nvfp4_backend in UNSUPPORTED: + if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: return None elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS: - # TP case: avoid convert to ModularKernelMethod - to be refactored. - if self.moe.dp_size == 1: + # For no-EP case, don't use the MKM framework. + if not self.moe.moe_parallel_config.use_all2all_kernels: return None - # For now, fp4 moe only works with the flashinfer dispatcher. 
- prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize( - self.moe + + prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( + self.moe, + use_deepseek_fp8_block_scale=False, ) logger.debug_once("%s", prepare_finalize.__class__.__name__) return prepare_finalize - else: - return super().maybe_make_prepare_finalize(routing_tables) + return super().maybe_make_prepare_finalize(routing_tables) def select_gemm_impl( self, @@ -607,14 +618,13 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): layer: torch.nn.Module, ) -> mk.FusedMoEPermuteExpertsUnpermute: assert self.moe_quant_config is not None - """Return the appropriate GEMM experts implementation.""" - experts = select_nvfp4_gemm_impl( - self.moe, - self.moe_quant_config, - allow_flashinfer=(self.nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS), + assert self.experts_cls is not None + return make_nvfp4_moe_kernel_for_mkm( + moe_config=self.moe, + quant_config=self.moe_quant_config, + experts_cls=self.experts_cls, + prepare_finalize=prepare_finalize, ) - logger.debug_once("Using %s", experts.__class__.__name__) - return experts def get_fused_moe_quant_config( self, layer: torch.nn.Module @@ -727,33 +737,41 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): "For FP8 Fused MoE layer, we require either per tensor or " "channelwise, dynamic per token quantization." ) - self.fp8_backend = select_fp8_moe_backend( - block_quant=self.block_quant, - tp_size=moe.tp_size, - with_lora_support=moe.is_lora_enabled, - is_act_and_mul=moe.is_act_and_mul, - # TODO(rob): enable selecting this externally. 
+ + ct2vllm_weight = { + QuantizationStrategy.CHANNEL: kFp8StaticChannelSym, + QuantizationStrategy.TENSOR: kFp8StaticTensorSym, + QuantizationStrategy.BLOCK: kFp8Static128BlockSym, + } + ct2vllm_act = { + QuantizationStrategy.TOKEN: kFp8DynamicTokenSym, + QuantizationStrategy.TENSOR: ( + kFp8StaticTensorSym if self.static_input_scales else kFp8Dynamic128Sym + ), + } + weight_key = ct2vllm_weight[self.weight_quant.strategy] + if weight_key == kFp8Static128BlockSym: + activation_key = kFp8Dynamic128Sym + else: + activation_key = ct2vllm_act[self.input_quant.strategy] + + # Select Fp8 MoE backend + self.fp8_backend, self.experts_cls = select_fp8_moe_backend( + config=self.moe, + weight_key=weight_key, + activation_key=activation_key, allow_vllm_cutlass=True, ) - if self.fp8_backend != Fp8MoeBackend.MARLIN: - per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN - per_channel_quant = ( - self.weight_quant.strategy == QuantizationStrategy.CHANNEL - ) - if per_act_token != per_channel_quant: - raise NotImplementedError( - "For FP8 Fused MoE layers, per-token and per-channel must be " - "used together." - ) - # TODO(rob): hook this up in a follow up PR. - if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: - raise NotImplementedError( - "FlashInfer TRTLLM backend not supported for compressed-tensors yet." - ) - self.disable_expert_map = False + # Delay creation of the kernel until after process-weights. self.kernel: mk.FusedMoEModularKernel | None = None + @property + def topk_indices_dtype(self) -> torch.dtype | None: + if self.kernel is not None: + return self.kernel.prepare_finalize.topk_indices_dtype() + return None + def create_weights( self, layer: torch.nn.Module, @@ -970,140 +988,75 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): replace_parameter(layer, "w13_weight_scale", w13_scale) replace_parameter(layer, "w2_weight_scale", w2_scale) + # Setup modular kernel for TP case and naive DP/EP case. 
+ # In non-naive DP/EP case, we will create a ModularKernelMethod. + # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel + # in both cases. self.moe_quant_config = self.get_fused_moe_quant_config(layer) - if self.moe_quant_config: + if self.moe_quant_config and ( + (not self.moe.moe_parallel_config.use_all2all_kernels) + or self.moe.moe_parallel_config.use_naive_all2all_kernels + ): + assert self.experts_cls is not None self.kernel, self.use_inplace = make_fp8_moe_kernel( - layer=layer, moe_quant_config=self.moe_quant_config, moe_config=self.moe, fp8_backend=self.fp8_backend, + experts_cls=self.experts_cls, ) def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: - if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]: + if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: return None - else: - return super().maybe_make_prepare_finalize(routing_tables) + elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: + # For no-EP case, don't use the MKM framework. 
+ if not self.moe.moe_parallel_config.use_all2all_kernels: + return None + + prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( + self.moe, + use_deepseek_fp8_block_scale=self.block_quant, + ) + logger.debug_once("%s", prepare_finalize.__class__.__name__) + return prepare_finalize + return super().maybe_make_prepare_finalize(routing_tables) def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: - # cutlass path assert self.moe_quant_config is not None - if self.fp8_backend == Fp8MoeBackend.VLLM_CUTLASS: - from vllm.model_executor.layers.fused_moe import ( - CutlassBatchedExpertsFp8, - CutlassExpertsFp8, - ) - - experts: FusedMoEPermuteExpertsUnpermute - - num_dispatchers = prepare_finalize.num_dispatchers() - - if ( - prepare_finalize.activation_format - == FusedMoEActivationFormat.BatchedExperts - ): - logger.debug("CutlassBatchedExpertsFp8(%s)", self.__class__.__name__) - experts = CutlassBatchedExpertsFp8( - max_experts_per_worker=self.moe.num_local_experts, - num_dispatchers=num_dispatchers, - out_dtype=self.moe.in_dtype, - e=layer.local_num_experts, - n=layer.intermediate_size_per_partition, - k=layer.hidden_size, - device=layer.w13_weight.device, - quant_config=self.moe_quant_config, - ) - else: - logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__) - experts = CutlassExpertsFp8( - out_dtype=self.moe.in_dtype, - e=layer.local_num_experts, - n=layer.intermediate_size_per_partition, - k=layer.hidden_size, - device=layer.w13_weight.device, - quant_config=self.moe_quant_config, - ) - - # TODO(rob): investigate disable_expert_map - self.disable_expert_map = ( - num_dispatchers > 1 or not experts.supports_expert_map() - ) - - return experts - - from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - BatchedDeepGemmExperts, + assert self.experts_cls is not None + return make_fp8_moe_kernel_for_mkm( + moe_config=self.moe, + 
quant_config=self.moe_quant_config, + experts_cls=self.experts_cls, + prepare_finalize=prepare_finalize, ) - from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts, - ) - from vllm.model_executor.layers.fused_moe.fused_moe import ( - TritonExperts, - ) - from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( - TritonOrDeepGemmExperts, - ) - - assert self.fp8_backend not in [Fp8MoeBackend.AITER, Fp8MoeBackend.MARLIN] - - if ( - prepare_finalize.activation_format - == FusedMoEActivationFormat.BatchedExperts - ): - max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() - assert max_num_tokens_per_rank is not None - - if self.fp8_backend == Fp8MoeBackend.DEEPGEMM: - logger.debug("BatchedDeepGemmExperts(%s)", self.__class__.__name__) - return BatchedDeepGemmExperts( - max_num_tokens=max_num_tokens_per_rank, - num_dispatchers=prepare_finalize.num_dispatchers(), - quant_config=self.moe_quant_config, - ) - else: - logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__) - return BatchedTritonExperts( - max_num_tokens=max_num_tokens_per_rank, - num_dispatchers=prepare_finalize.num_dispatchers(), - quant_config=self.moe_quant_config, - ) - - else: - if self.fp8_backend == Fp8MoeBackend.DEEPGEMM: - logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__) - return TritonOrDeepGemmExperts(self.moe_quant_config) - else: - logger.debug("TritonExperts(%s)", self.__class__.__name__) - return TritonExperts(self.moe_quant_config) def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: - if self.fp8_backend == Fp8MoeBackend.MARLIN: - return fp8_w8a16_moe_quant_config( - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - block_shape=self.weight_block_size, - ) + w1_scale = layer.w13_weight_scale + w2_scale = layer.w2_weight_scale + a1_scale = layer.w13_input_scale + a2_scale = layer.w2_input_scale - per_act_token = self.input_quant.strategy == 
QuantizationStrategy.TOKEN - per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL - - return fp8_w8a8_moe_quant_config( - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - per_act_token_quant=per_act_token, - per_out_ch_quant=per_channel_quant, - block_shape=layer.weight_block_size, + return make_fp8_moe_quant_config( + fp8_backend=self.fp8_backend, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + per_act_token_quant=( + self.input_quant.strategy == QuantizationStrategy.TOKEN + ), + per_out_ch_quant=(self.weight_quant.strategy == QuantizationStrategy.CHANNEL), + block_shape=self.weight_block_size, ) def apply( @@ -1113,6 +1066,56 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: + if layer.enable_eplb: + raise NotImplementedError( + "EPLB not supported for `FlashInfer TRTLLM FP8 MoE`." 
+ ) + assert layer.activation == "silu" + + if self.block_quant: + import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 + + e_score_correction_bias = ( + layer.e_score_correction_bias.to(x.dtype) + if layer.e_score_correction_bias is not None + else None + ) + routing_method_type = layer.routing_method_type + return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( + routing_logits=router_logits.to(torch.float32) + if routing_method_type == RoutingMethodType.DeepSeekV3 + else router_logits, + routing_bias=e_score_correction_bias, + x=x, + w13_weight=layer.w13_weight, + w13_weight_scale_inv=layer.w13_weight_scale, + w2_weight=layer.w2_weight, + w2_weight_scale_inv=layer.w2_weight_scale, + global_num_experts=layer.global_num_experts, + top_k=layer.top_k, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + intermediate_size=layer.intermediate_size_per_partition, + expert_offset=layer.ep_rank * layer.local_num_experts, + local_num_experts=layer.local_num_experts, + block_shape=self.weight_block_size, + routing_method_type=routing_method_type, + routed_scaling=layer.routed_scaling_factor, + ) + else: + return apply_fi_trtllm_fp8_per_tensor_moe( + layer=layer, + hidden_states=x, + router_logits=router_logits, + routing_bias=layer.e_score_correction_bias, + global_num_experts=layer.global_num_experts, + top_k=layer.top_k, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + ) + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, @@ -1130,7 +1133,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): global_num_experts=layer.global_num_experts, # TODO(rob): investigate the disable_expert_map introduced by: # https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501 - expert_map=None if self.disable_expert_map else 
layer.expert_map, + expert_map=layer.expert_map, apply_router_weight_on_input=layer.apply_router_weight_on_input, ) @@ -1596,6 +1599,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): return BatchedMarlinExperts( max_num_tokens=max_num_tokens_per_rank, num_dispatchers=prepare_finalize.num_dispatchers(), + moe_config=self.moe, quant_config=self.moe_quant_config, w13_g_idx=layer.w13_weight_g_idx, w2_g_idx=layer.w2_weight_g_idx, @@ -1605,6 +1609,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ) else: return MarlinExperts( + moe_config=self.moe, quant_config=self.moe_quant_config, w13_g_idx=layer.w13_weight_g_idx, w2_g_idx=layer.w2_weight_g_idx, @@ -1854,7 +1859,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): layer.w13_weight = layer.w13_weight_packed layer.w2_weight = layer.w2_weight_packed - return TritonWNA16Experts(quant_config=self.moe_quant_config) + return TritonWNA16Experts( + moe_config=self.moe, quant_config=self.moe_quant_config + ) else: raise NotImplementedError( "TritonExperts requires Triton. 
" @@ -2467,6 +2474,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod): c_strides2=self.a_strides1_c_strides2, s_strides1=self.s_strides1, s_strides2=self.s_strides2, + moe_config=self.moe, quant_config=self.moe_quant_config, group_size=self.group_size, ) @@ -2505,6 +2513,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod): layer.w2_weight_packed, topk_weights, topk_ids, + moe_config=self.moe, quant_config=self.moe_quant_config, activation=layer.activation, global_num_experts=layer.global_num_experts, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 5adcd09b0..d9fa02c6b 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -19,7 +19,6 @@ from vllm.model_executor.layers.batch_invariant import ( ) from vllm.model_executor.layers.fused_moe import ( FusedMoE, - FusedMoEActivationFormat, FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, @@ -35,6 +34,7 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( Fp8MoeBackend, convert_to_fp8_moe_kernel_format, make_fp8_moe_kernel, + make_fp8_moe_kernel_for_mkm, make_fp8_moe_quant_config, select_fp8_moe_backend, ) @@ -55,7 +55,6 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( apply_fi_trtllm_fp8_per_tensor_moe, build_flashinfer_fp8_cutlass_moe_prepare_finalize, - select_cutlass_fp8_gemm_impl, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, @@ -79,8 +78,10 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, is_layer_skipped, + kFp8Dynamic128Sym, kFp8DynamicTensorSym, kFp8DynamicTokenSym, + kFp8Static128BlockSym, kFp8StaticTensorSym, ) from 
vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -658,38 +659,36 @@ class Fp8MoEMethod(FusedMoEMethodBase): self.weight_scale_name = ( "weight_scale_inv" if self.block_quant else "weight_scale" ) - self.fp8_backend = select_fp8_moe_backend( - block_quant=self.block_quant, - tp_size=layer.moe_parallel_config.tp_size, - with_lora_support=self.moe.is_lora_enabled, - is_act_and_mul=self.moe.is_act_and_mul, - ) - if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: - if self.block_quant and self.weight_block_size != [128, 128]: - raise NotImplementedError( - "FlashInfer CUTLASS FP8 MoE backend only supports block " - "size [128, 128]." - ) - if layer.activation != "silu": - raise NotImplementedError( - "FlashInfer CUTLASS FP8 MoE backend only supports SiLU " - "activation function, but got {layer.activation}." - ) - dynamic_per_token = ( - not self.block_quant and self.quant_config.activation_scheme != "static" - ) - if dynamic_per_token and self.fp8_backend in [ - Fp8MoeBackend.FLASHINFER_TRTLLM, - Fp8MoeBackend.FLASHINFER_CUTLASS, - ]: - raise NotImplementedError( - "FlashInfer FP8 MoE backend does not support dynamic per token " - "activation quantization." + # Set weight key and activation key for kernel compatibility + if self.block_quant: + weight_key = kFp8Static128BlockSym + activation_key = kFp8Dynamic128Sym + else: + weight_key = kFp8StaticTensorSym + activation_key = ( + kFp8StaticTensorSym + if self.quant_config.activation_scheme == "static" + else kFp8DynamicTensorSym ) + # Select Fp8 MoE backend + self.fp8_backend, self.experts_cls = select_fp8_moe_backend( + config=self.moe, + weight_key=weight_key, + activation_key=activation_key, + allow_vllm_cutlass=False, + ) + + # Delay creation of the kernel until after process-weights. 
self.kernel: mk.FusedMoEModularKernel | None = None + @property + def topk_indices_dtype(self) -> torch.dtype | None: + if self.kernel is not None: + return self.kernel.prepare_finalize.topk_indices_dtype() + return None + def create_weights( self, layer: Module, @@ -842,14 +841,21 @@ class Fp8MoEMethod(FusedMoEMethodBase): replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale) replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale) - # Setup modular kernel for TP case. + # Setup modular kernel for TP case and naive DP/EP case. + # In non-naive DP/EP case, we will create a ModularKernelMethod. + # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel + # in both cases. self.moe_quant_config = self.get_fused_moe_quant_config(layer) - if self.moe_quant_config: + if self.moe_quant_config and ( + (not self.moe.moe_parallel_config.use_all2all_kernels) + or self.moe.moe_parallel_config.use_naive_all2all_kernels + ): + assert self.experts_cls is not None self.kernel, self.use_inplace = make_fp8_moe_kernel( - layer=layer, moe_quant_config=self.moe_quant_config, moe_config=self.moe, fp8_backend=self.fp8_backend, + experts_cls=self.experts_cls, ) def process_weights_after_loading(self, layer: Module) -> None: @@ -904,13 +910,13 @@ class Fp8MoEMethod(FusedMoEMethodBase): self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: - if self.fp8_backend in [ - Fp8MoeBackend.AITER, - Fp8MoeBackend.MARLIN, - Fp8MoeBackend.FLASHINFER_TRTLLM, - ]: + if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: return None elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: + # For no-EP case, don't use the MKM framework. 
+ if not self.moe.moe_parallel_config.use_all2all_kernels: + return None + prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( self.moe, use_deepseek_fp8_block_scale=self.block_quant, @@ -924,73 +930,14 @@ class Fp8MoEMethod(FusedMoEMethodBase): prepare_finalize: FusedMoEPrepareAndFinalize, layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: - from vllm.model_executor.layers.fused_moe import ( - BatchedDeepGemmExperts, - BatchedTritonExperts, - TritonExperts, - TritonOrDeepGemmExperts, - ) - - if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]: - raise NotImplementedError( - "Marlin and ROCm AITER are not supported with all2all yet." - ) - assert self.moe_quant_config is not None - - if ( - prepare_finalize.activation_format - == FusedMoEActivationFormat.BatchedExperts - ): - max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() - assert max_num_tokens_per_rank is not None - - experts_impl = ( - BatchedDeepGemmExperts - if self.fp8_backend == Fp8MoeBackend.DEEPGEMM - else BatchedTritonExperts - ) - logger.debug( - "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s", - experts_impl.__name__, - self.__class__.__name__, - max_num_tokens_per_rank, - self.weight_block_size, - False, - ) - return experts_impl( - max_num_tokens=max_num_tokens_per_rank, - num_dispatchers=prepare_finalize.num_dispatchers(), - quant_config=self.moe_quant_config, - ) - elif self.moe.is_lora_enabled: - return TritonExperts(quant_config=self.moe_quant_config) - elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: - # Select GEMM experts with block-scale when weights are block-quantized - experts = select_cutlass_fp8_gemm_impl( - self.moe, - self.moe_quant_config, - use_deepseek_fp8_block_scale=self.block_quant, - ) - logger.debug_once("Using %s", experts.__class__.__name__) - return experts - elif self.fp8_backend == Fp8MoeBackend.DEEPGEMM: - logger.debug( - "TritonOrDeepGemmExperts(%s): block_size=%s, 
per_act_token=%s", - self.__class__.__name__, - self.weight_block_size, - False, - ) - return TritonOrDeepGemmExperts(self.moe_quant_config) - else: - assert self.fp8_backend == Fp8MoeBackend.TRITON - logger.debug( - "TritonExperts(%s): block_size=%s, per_act_token=%s", - self.__class__.__name__, - self.weight_block_size, - False, - ) - return TritonExperts(self.moe_quant_config) + assert self.experts_cls is not None + return make_fp8_moe_kernel_for_mkm( + moe_config=self.moe, + quant_config=self.moe_quant_config, + experts_cls=self.experts_cls, + prepare_finalize=prepare_finalize, + ) def get_fused_moe_quant_config( self, layer: torch.nn.Module @@ -1067,7 +1014,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): routed_scaling=layer.routed_scaling_factor, ) else: - result = apply_fi_trtllm_fp8_per_tensor_moe( + return apply_fi_trtllm_fp8_per_tensor_moe( layer=layer, hidden_states=x, router_logits=router_logits, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 00cd635b2..ec295005b 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -875,6 +875,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): return BatchedMarlinExperts( max_num_tokens=max_num_tokens_per_rank, num_dispatchers=prepare_finalize.num_dispatchers(), + moe_config=self.moe, quant_config=self.moe_quant_config, w13_g_idx=w13_g_idx, w2_g_idx=w2_g_idx, @@ -885,6 +886,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): else: # Standard Marlin experts for GPTQ return MarlinExperts( + moe_config=self.moe, quant_config=self.moe_quant_config, w13_g_idx=w13_g_idx, w2_g_idx=w2_g_idx, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 91dfa03b8..0de9cb88d 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -27,15 
+27,16 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( Fp8MoeBackend, convert_to_fp8_moe_kernel_format, make_fp8_moe_kernel, + make_fp8_moe_kernel_for_mkm, make_fp8_moe_quant_config, select_fp8_moe_backend, ) from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( - FLASHINFER_NVFP4_MOE_BACKENDS, NvFp4MoeBackend, convert_to_nvfp4_moe_kernel_format, is_global_sf_supported_for_nvfp4_backend, make_nvfp4_moe_kernel, + make_nvfp4_moe_kernel_for_mkm, make_nvfp4_moe_quant_config, select_nvfp4_moe_backend, ) @@ -57,12 +58,10 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( build_flashinfer_fp4_cutlass_moe_prepare_finalize, flashinfer_trtllm_fp4_moe, flashinfer_trtllm_fp4_routed_moe, - select_nvfp4_gemm_impl, ) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( apply_fi_trtllm_fp8_per_tensor_moe, build_flashinfer_fp8_cutlass_moe_prepare_finalize, - select_cutlass_fp8_gemm_impl, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, @@ -84,6 +83,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8DynamicTokenSym, kFp8StaticTensorSym, kFp8StaticTokenSym, + kNvfp4Dynamic, + kNvfp4Static, swizzle_blockscale, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -728,14 +729,23 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): super().__init__(moe_config) self.quant_config = quant_config assert self.quant_config.is_checkpoint_fp8_serialized - self.fp8_backend = select_fp8_moe_backend( - block_quant=False, - tp_size=moe_config.moe_parallel_config.tp_size, - with_lora_support=self.moe.is_lora_enabled, - is_act_and_mul=self.moe.is_act_and_mul, + + # Select Fp8 MoE backend + self.fp8_backend, self.experts_cls = select_fp8_moe_backend( + config=self.moe, + weight_key=kFp8StaticTensorSym, + activation_key=kFp8StaticTensorSym, ) + + # Delay creation of the kernel until after process-weights. 
self.kernel: mk.FusedMoEModularKernel | None = None + @property + def topk_indices_dtype(self) -> torch.dtype | None: + if self.kernel is not None: + return self.kernel.prepare_finalize.topk_indices_dtype() + return None + def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, @@ -744,8 +754,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: return None elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: - # TP case: avoid convert to ModularKernelMethod - to be refactored. - if self.moe.dp_size == 1: + # For no-EP case, don't use the MKM framework. + if not self.moe.moe_parallel_config.use_all2all_kernels: return None prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( @@ -762,12 +772,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): layer: torch.nn.Module, ) -> mk.FusedMoEPermuteExpertsUnpermute: assert self.moe_quant_config is not None - experts = select_cutlass_fp8_gemm_impl( - self.moe, - self.moe_quant_config, + assert self.experts_cls is not None + return make_fp8_moe_kernel_for_mkm( + moe_config=self.moe, + quant_config=self.moe_quant_config, + experts_cls=self.experts_cls, + prepare_finalize=prepare_finalize, ) - logger.debug_once("Using %s", experts.__class__.__name__) - return experts def create_weights( self, @@ -876,14 +887,15 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): replace_parameter(layer, "w13_weight_scale", w13_scale) replace_parameter(layer, "w2_weight_scale", w2_scale) - # Setup modular kernel for TP case. + # Setup modular kernel. 
self.moe_quant_config = self.get_fused_moe_quant_config(layer) if self.moe_quant_config: + assert self.experts_cls is not None self.kernel, self.use_inplace = make_fp8_moe_kernel( - layer=layer, moe_quant_config=self.moe_quant_config, moe_config=self.moe, fp8_backend=self.fp8_backend, + experts_cls=self.experts_cls, ) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: @@ -1335,32 +1347,35 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ) -> None: super().__init__(moe_config) self.quant_config = quant_config - self.nvfp4_backend = select_nvfp4_moe_backend() - # TODO: move this type of check into the oracle. - if not self.moe.is_act_and_mul and self.nvfp4_backend not in [ - NvFp4MoeBackend.FLASHINFER_CUTLASS, - NvFp4MoeBackend.MARLIN, - ]: - raise NotImplementedError( - "Non-gated activations are only supported by FlashInfer " - f"CUTLASS and Marlin NvFP4 MoE backends, not {self.nvfp4_backend}." - ) + # Select experts implementation. + self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend( + config=self.moe, + weight_key=kNvfp4Static, + activation_key=kNvfp4Dynamic, + ) + + # Delay creation of the kernel until after process-weights. 
+ self.kernel: mk.FusedMoEModularKernel | None = None self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( self.nvfp4_backend ) - self.kernel: mk.FusedMoEModularKernel | None = None + + @property + def topk_indices_dtype(self) -> torch.dtype | None: + if self.kernel is not None: + return self.kernel.prepare_finalize.topk_indices_dtype() + return None def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: - UNSUPPORTED = [NvFp4MoeBackend.MARLIN, NvFp4MoeBackend.FLASHINFER_TRTLLM] - if self.nvfp4_backend in UNSUPPORTED: + if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: return None elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS: - # TP case: avoid convert to ModularKernelMethod - to be refactored. - if self.moe.dp_size == 1: + # For no-EP case, don't use the MKM framework. + if not self.moe.moe_parallel_config.use_all2all_kernels: return None # For now, fp4 moe only works with the flashinfer dispatcher. 
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize( @@ -1377,13 +1392,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): layer: torch.nn.Module, ) -> mk.FusedMoEPermuteExpertsUnpermute: assert self.moe_quant_config is not None - experts = select_nvfp4_gemm_impl( - self.moe, - self.moe_quant_config, - allow_flashinfer=self.nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS, + assert self.experts_cls is not None + return make_nvfp4_moe_kernel_for_mkm( + moe_config=self.moe, + quant_config=self.moe_quant_config, + experts_cls=self.experts_cls, + prepare_finalize=prepare_finalize, ) - logger.debug_once("Using %s", experts.__class__.__name__) - return experts def uses_weight_scale_2_pattern(self) -> bool: """ @@ -1554,13 +1569,20 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): replace_parameter(layer, "w2_weight_scale_2", w2_scale_2) replace_parameter(layer, "w2_input_scale", a2_scale) + # Setup modular kernel for TP case and naive DP/EP case. + # In non-naive DP/EP case, we will create a ModularKernelMethod. + # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel + # in both cases. 
self.moe_quant_config = self.get_fused_moe_quant_config(layer) - use_dp = self.moe.dp_size > 1 - if self.moe_quant_config is not None and not use_dp: + if self.moe_quant_config and ( + (not self.moe.moe_parallel_config.use_all2all_kernels) + or self.moe.moe_parallel_config.use_naive_all2all_kernels + ): + assert self.experts_cls is not None self.kernel = make_nvfp4_moe_kernel( - backend=self.nvfp4_backend, - quant_config=self.moe_quant_config, + moe_quant_config=self.moe_quant_config, moe_config=self.moe, + experts_cls=self.experts_cls, ) @property diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index ecd13e5c7..18dd3e40b 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -853,6 +853,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): max_num_tokens=max_num_tokens_per_rank, num_dispatchers=prepare_finalize.num_dispatchers(), quant_config=self.moe_quant_config, + moe_config=self.moe, ) else: raise NotImplementedError( @@ -875,11 +876,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): } return TrtLlmGenExperts(self.moe, self.moe_quant_config, **kwargs) elif self.mxfp4_backend == Mxfp4Backend.MARLIN: - return MarlinExperts(self.moe_quant_config) + return MarlinExperts(self.moe, self.moe_quant_config) elif self.mxfp4_backend == Mxfp4Backend.TRITON: if self.moe.is_lora_enabled: - return UnfusedOAITritonExperts(self.moe_quant_config) - return OAITritonExperts(self.moe_quant_config) + return UnfusedOAITritonExperts(self.moe, self.moe_quant_config) + return OAITritonExperts(self.moe, self.moe_quant_config) else: raise NotImplementedError( f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP" diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index 272b13861..ea5884e0f 100644 --- 
a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -11,19 +11,16 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, - FusedMoEQuantConfig, + FusedMoEParallelConfig, RoutingMethodType, ) -from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import ( - FlashInferCuteDSLExperts, -) -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - FlashInferExperts, -) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 create_flashinfer_prepare_finalize, ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kNvfp4Dynamic, + kNvfp4Static, swizzle_blockscale, ) from vllm.platforms import current_platform @@ -47,6 +44,86 @@ __all__ = [ "build_flashinfer_fp4_cutlass_moe_prepare_finalize", ] +# +# Methods used by the oracle for kernel selection. +# + + +def _supports_current_device() -> bool: + """Supports only Blackwell-family GPUs.""" + p = current_platform + return p.is_cuda() and p.is_device_capability_family(100) + + +def _supports_no_act_and_mul() -> bool: + """Does not support non-gated MoE (i.e. Nemotron-Nano).""" + return False + + +def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, +) -> bool: + """Supports Nvfp4 quantization.""" + SUPPORTED_W_A = [ + (kNvfp4Static, kNvfp4Dynamic), + ] + return (weight_key, activation_key) in SUPPORTED_W_A + + +def _supports_activation(activation: str) -> bool: + """Supports silu activation only.""" + return activation in ["silu"] + + +def _supports_routing_method( + routing_method: RoutingMethodType, +) -> bool: + """Monolithic kernels need to express router support.""" + # NOTE(rob): potentially allow others here. This is a conservative list. 
+ return routing_method in [ + RoutingMethodType.DeepSeekV3, + RoutingMethodType.Renormalize, + RoutingMethodType.RenormalizeNaive, + RoutingMethodType.Llama4, + ] + + +def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + """Supports EP.""" + return True + + +def is_supported_config_trtllm( + moe_config: FusedMoEConfig, + weight_key: QuantKey | None, + activation_key: QuantKey | None, + activation_format: mk.FusedMoEActivationFormat, +) -> tuple[bool, str | None]: + """ + This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config + """ + + def _make_reason(reason: str) -> str: + return f"kernel does not support {reason}" + + if not _supports_current_device(): + return False, _make_reason("current device") + elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()): + return False, _make_reason("no act_and_mul MLP layer") + elif not _supports_activation(moe_config.activation): + return False, _make_reason(f"{moe_config.activation} activation") + elif not _supports_quant_scheme(weight_key, activation_key): + return False, _make_reason("quantization scheme") + elif not _supports_parallel_config(moe_config.moe_parallel_config): + return False, _make_reason("parallel config") + elif not _supports_routing_method(moe_config.routing_method): + return False, _make_reason("routing method") + elif activation_format != mk.FusedMoEActivationFormat.Standard: + return False, _make_reason("activation format") + + return True, None + def is_flashinfer_fp4_cutlass_moe_available() -> bool: """Return `True` when FlashInfer CUTLASS NV-FP4 kernels can be used.""" @@ -96,37 +173,6 @@ def build_flashinfer_fp4_cutlass_moe_prepare_finalize( ) -def select_nvfp4_gemm_impl( - moe: FusedMoEConfig, - moe_quant_config: FusedMoEQuantConfig, - allow_flashinfer: bool, -) -> mk.FusedMoEPermuteExpertsUnpermute: - """Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers""" - - if allow_flashinfer: - if 
envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm": - return FlashInferCuteDSLExperts( - out_dtype=moe.in_dtype, - quant_config=moe_quant_config, - ) - elif envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput": - return FlashInferExperts( - out_dtype=moe.in_dtype, - quant_config=moe_quant_config, - ep_rank=moe.moe_parallel_config.ep_rank, - ep_size=moe.moe_parallel_config.ep_size, - tp_rank=moe.moe_parallel_config.tp_rank, - tp_size=moe.moe_parallel_config.tp_size, - use_dp=moe.moe_parallel_config.dp_size > 1, - ) - - # native cutlass experts currently don't support DP; TP case won't call this - raise ValueError( - "CutlassExpertsFp4 doesn't support DP. Use flashinfer CUTLASS " - "Fused MoE backend instead (set VLLM_USE_FLASHINFER_MOE_FP4=1)" - ) - - def prepare_static_weights_for_trtllm_fp4_moe( # args_dequant, # args, diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 799854479..2cc17b12f 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -9,10 +9,6 @@ from vllm import envs from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, - FusedMoEQuantConfig, -) -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - FlashInferExperts, ) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 create_flashinfer_prepare_finalize, @@ -203,33 +199,6 @@ def build_flashinfer_fp8_cutlass_moe_prepare_finalize( ) -def select_cutlass_fp8_gemm_impl( - moe: FusedMoEConfig | None, - quant_config: FusedMoEQuantConfig, - out_dtype: torch.dtype | None = None, - use_deepseek_fp8_block_scale: bool = False, -) -> mk.FusedMoEPermuteExpertsUnpermute: - """Return a GEMM *experts* implementation for fused-MoE layers""" - - if moe is not None: - return FlashInferExperts( - 
out_dtype=moe.in_dtype, - quant_config=quant_config, - ep_rank=moe.moe_parallel_config.ep_rank, - ep_size=moe.moe_parallel_config.ep_size, - tp_rank=moe.moe_parallel_config.tp_rank, - tp_size=moe.moe_parallel_config.tp_size, - use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, - ) - - assert out_dtype is not None, "If moe config is None, out_dtype must be passed" - return FlashInferExperts( - out_dtype=out_dtype, - quant_config=quant_config, - use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, - ) - - def get_flashinfer_moe_backend() -> FlashinferMoeBackend: backend_map = { "throughput": FlashinferMoeBackend.CUTLASS, diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 91fc8760b..bc7458444 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -48,6 +48,7 @@ class GroupShape(_GroupShape): # Aliases for common quantization group shapes PER_TENSOR: ClassVar["GroupShape"] PER_TOKEN: ClassVar["GroupShape"] + PER_CHANNEL: ClassVar["GroupShape"] def is_per_tensor(self) -> bool: return self.row == -1 and self.col == -1 @@ -55,12 +56,16 @@ class GroupShape(_GroupShape): def is_per_token(self) -> bool: return self.row == 1 and self.col == -1 + def is_per_channel(self) -> bool: + return self.row == -1 and self.col == 1 + def is_per_group(self) -> bool: return self.row == 1 and self.col >= 1 GroupShape.PER_TENSOR = GroupShape(-1, -1) GroupShape.PER_TOKEN = GroupShape(1, -1) +GroupShape.PER_CHANNEL = GroupShape(-1, 1) @dataclass(frozen=True) @@ -77,16 +82,12 @@ class ScaleDesc: group_shape: GroupShape def __str__(self): - group_shape = ( - "per_tensor" - if self.group_shape == GroupShape.PER_TENSOR - else ( - "per_token" - if self.group_shape == GroupShape.PER_TOKEN - else str(self.group_shape) - ) - ) - + d = { + GroupShape.PER_TENSOR: "per_tensor", + GroupShape.PER_TOKEN: 
"per_token", + GroupShape.PER_CHANNEL: "per_channel", + } + group_shape = d.get(self.group_shape, str(self.group_shape)) return ( f"{fx.graph.dtype_abbrs[self.dtype]}," f"{'static' if self.static else 'dynamic'},{group_shape}" @@ -126,15 +127,28 @@ kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True) kStaticTokenScale = ScaleDesc(torch.float32, True, GroupShape.PER_TOKEN) kFp8StaticTokenSym = QuantKey(FP8_DTYPE, kStaticTokenScale, symmetric=True) +kStaticChannelScale = ScaleDesc(torch.float32, True, GroupShape.PER_CHANNEL) +kFp8StaticChannelSym = QuantKey(FP8_DTYPE, kStaticChannelScale, symmetric=True) + kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN) kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True) -kNvfp4GroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16)) -kNvfp4Quant = QuantKey(FP4_DTYPE, scale=kNvfp4GroupScale, scale2=kStaticTensorScale) +kNvfp4DynamicGroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16)) +kNvfp4Dynamic = QuantKey( + FP4_DTYPE, scale=kNvfp4DynamicGroupScale, scale2=kStaticTensorScale +) + +kNvfp4StaticGroupScale = ScaleDesc(FP8_DTYPE, True, GroupShape(1, 16)) +kNvfp4Static = QuantKey( + FP4_DTYPE, scale=kNvfp4StaticGroupScale, scale2=kStaticTensorScale +) kDynamic128Scale = ScaleDesc(torch.float32, False, GroupShape(1, 128)) kFp8Dynamic128Sym = QuantKey(FP8_DTYPE, kDynamic128Scale, symmetric=True) +kStatic128BlockScale = ScaleDesc(torch.float32, True, GroupShape(128, 128)) +kFp8Static128BlockSym = QuantKey(FP8_DTYPE, kStatic128BlockScale, symmetric=True) + kDynamic64Scale = ScaleDesc(torch.float32, False, GroupShape(1, 64)) kFp8Dynamic64Sym = QuantKey(FP8_DTYPE, kDynamic64Scale, symmetric=True) diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index f2f354604..8e49ccea5 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -43,7 +43,6 @@ from vllm.distributed 
import ( from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.fused_moe.config import RoutingMethodType from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, @@ -172,7 +171,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module): enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, is_sequence_parallel=self.is_sequence_parallel, - routing_method_type=RoutingMethodType.Renormalize, ) self.gate = ReplicatedLinear( diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 0f73a7746..e244e6474 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -34,7 +34,6 @@ from vllm.model_executor.layers.fla.ops import ( fused_recurrent_gated_delta_rule, ) from vllm.model_executor.layers.fused_moe import SharedFusedMoE -from vllm.model_executor.layers.fused_moe.config import RoutingMethodType from vllm.model_executor.layers.layernorm import ( GemmaRMSNorm as Qwen3NextRMSNorm, ) @@ -181,7 +180,6 @@ class Qwen3NextSparseMoeBlock(nn.Module): enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, is_sequence_parallel=self.is_sequence_parallel, - routing_method_type=RoutingMethodType.Renormalize, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py index 212a725d4..21409de86 100644 --- a/vllm/model_executor/warmup/deep_gemm_warmup.py +++ b/vllm/model_executor/warmup/deep_gemm_warmup.py @@ -128,11 +128,15 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool: """ Return True if the input module/layer could be processed with DeepGEMM. 
""" + + # FIXME: this logic is brittle and incorrect - since we + # could use DeepGEMM with more than just Fp8LinearMethod block_size = get_mk_alignment_for_contiguous_layout()[0] if not ( isinstance(module, LinearBase) and isinstance(module.quant_method, Fp8LinearMethod) and module.quant_method.block_quant + and not module.quant_method.use_marlin ): return False diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 7a0aff80e..a76d01c5b 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -29,7 +29,7 @@ from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8StaticTensorSym, - kNvfp4Quant, + kNvfp4Dynamic, ) from vllm.platforms import current_platform from vllm.platforms.interface import DeviceCapability @@ -1184,7 +1184,7 @@ class FlashInferImpl(AttentionImpl): return ( self.support_trtllm_attn and self.kv_cache_dtype.startswith("fp8") - and quant_key in (kFp8StaticTensorSym, kNvfp4Quant) + and quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic) ) # FlashInfer requires attention sinks to be float32