diff --git a/tests/evals/gsm8k/configs/Qwen3-Next-80B-A3B-NVFP4-EP2.yaml b/tests/evals/gsm8k/configs/Qwen3-Next-80B-A3B-NVFP4-EP2.yaml
index 673b473f8..7f2f096fd 100644
--- a/tests/evals/gsm8k/configs/Qwen3-Next-80B-A3B-NVFP4-EP2.yaml
+++ b/tests/evals/gsm8k/configs/Qwen3-Next-80B-A3B-NVFP4-EP2.yaml
@@ -8,5 +8,4 @@ server_args: >-
   --tensor-parallel-size 2
   --enable-expert-parallel
   --speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":1}'
-env:
-  VLLM_USE_FLASHINFER_MOE_FP4: "1"
+  --moe-backend=flashinfer_trtllm
diff --git a/tests/evals/gsm8k/configs/Qwen3-Next-FP8-EP2.yaml b/tests/evals/gsm8k/configs/Qwen3-Next-FP8-EP2.yaml
index 9fae32734..abcb784a7 100644
--- a/tests/evals/gsm8k/configs/Qwen3-Next-FP8-EP2.yaml
+++ b/tests/evals/gsm8k/configs/Qwen3-Next-FP8-EP2.yaml
@@ -7,5 +7,4 @@ server_args: >-
   --tensor-parallel-size 2
   --enable-expert-parallel
   --async-scheduling
-env:
-  VLLM_USE_FLASHINFER_MOE_FP8: "1"
+  --moe-backend=flashinfer_trtllm
diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Llama-4-Scout-Fp8-ModelOpt-triton.yaml b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Llama-4-Scout-Fp8-ModelOpt-triton.yaml
index 9e13797bb..fda02c367 100644
--- a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Llama-4-Scout-Fp8-ModelOpt-triton.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Llama-4-Scout-Fp8-ModelOpt-triton.yaml
@@ -2,7 +2,6 @@ model_name: "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
 accuracy_threshold: 0.92
 num_questions: 1319
 num_fewshot: 5
-server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel"
+server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel --moe-backend=triton"
 env:
-  VLLM_USE_FLASHINFER_MOE_FP8: "0"
   VLLM_USE_DEEP_GEMM: "0"
diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-NvFp4-CT-fi-cutedsl-deepep-ll.yaml b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-NvFp4-CT-fi-cutedsl-deepep-ll.yaml
index 1328fdedf..6624cea1e 100644
--- a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-NvFp4-CT-fi-cutedsl-deepep-ll.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-NvFp4-CT-fi-cutedsl-deepep-ll.yaml
@@ -2,7 +2,4 @@ model_name: "RedHatAI/Qwen3-30B-A3B-NVFP4"
 accuracy_threshold: 0.88
 num_questions: 1319
 num_fewshot: 5
-server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel --all2all-backend deepep_low_latency"
-env:
-  VLLM_USE_FLASHINFER_MOE_FP4: "1"
-  VLLM_FLASHINFER_MOE_BACKEND: "masked_gemm"
+server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel --all2all-backend deepep_low_latency --moe-backend=flashinfer_cutedsl"
diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-NvFp4-CT-fi-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-NvFp4-CT-fi-cutlass.yaml
index 53fd62bac..90265a12a 100644
--- a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-NvFp4-CT-fi-cutlass.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-NvFp4-CT-fi-cutlass.yaml
@@ -2,7 +2,4 @@ model_name: "RedHatAI/Qwen3-30B-A3B-NVFP4"
 accuracy_threshold: 0.88
 num_questions: 1319
 num_fewshot: 5
-server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel"
-env:
-  VLLM_USE_FLASHINFER_MOE_FP4: "1"
-  VLLM_FLASHINFER_MOE_BACKEND: "throughput"
+server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel --moe-backend=flashinfer_cutlass"
diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-NvFp4-ModelOpt-fi-cutedsl-deepep-ll.yaml b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-NvFp4-ModelOpt-fi-cutedsl-deepep-ll.yaml
index 87fac0e70..f2d4588e3 100644
--- a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-NvFp4-ModelOpt-fi-cutedsl-deepep-ll.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-NvFp4-ModelOpt-fi-cutedsl-deepep-ll.yaml
@@ -2,7 +2,4 @@ model_name: "nvidia/Qwen3-30B-A3B-NVFP4"
 accuracy_threshold: 0.88
 num_questions: 1319
 num_fewshot: 5
-server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel --all2all-backend deepep_low_latency"
-env:
-  VLLM_USE_FLASHINFER_MOE_FP4: "1"
-  VLLM_FLASHINFER_MOE_BACKEND: "masked_gemm"
+server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel --all2all-backend deepep_low_latency --moe-backend=flashinfer_cutedsl"
diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-NvFp4-ModelOpt-fi-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-NvFp4-ModelOpt-fi-cutlass.yaml
index 44f8700e4..49be54e26 100644
--- a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-NvFp4-ModelOpt-fi-cutlass.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-NvFp4-ModelOpt-fi-cutlass.yaml
@@ -2,7 +2,4 @@ model_name: "nvidia/Qwen3-30B-A3B-NVFP4"
 accuracy_threshold: 0.88
 num_questions: 1319
 num_fewshot: 5
-server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel"
-env:
-  VLLM_USE_FLASHINFER_MOE_FP4: "1"
-  VLLM_FLASHINFER_MOE_BACKEND: "throughput"
+server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel --moe-backend=flashinfer_cutlass"
diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-NvFp4-ModelOpt-fi-trtllm.yaml b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-NvFp4-ModelOpt-fi-trtllm.yaml
index 91a220c4f..23d29e06f 100644
--- a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-NvFp4-ModelOpt-fi-trtllm.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-NvFp4-ModelOpt-fi-trtllm.yaml
@@ -2,7 +2,4 @@ model_name: "nvidia/Qwen3-30B-A3B-NVFP4"
 accuracy_threshold: 0.88
 num_questions: 1319
 num_fewshot: 5
-server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel"
-env:
-  VLLM_USE_FLASHINFER_MOE_FP4: "1"
-  VLLM_FLASHINFER_MOE_BACKEND: "latency"
+server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel --moe-backend=flashinfer_trtllm"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-BF16-fi-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-BF16-fi-cutlass.yaml
index 5416d9232..e19500fd3 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-BF16-fi-cutlass.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-BF16-fi-cutlass.yaml
@@ -2,8 +2,4 @@ model_name: "meta-llama/Llama-4-Scout-17B-16E-Instruct"
 accuracy_threshold: 0.92
 num_questions: 1319
 num_fewshot: 5
-server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --enable-expert-parallel"
-env:
-  VLLM_USE_FLASHINFER_MOE_FP16: "1"
-  VLLM_FLASHINFER_MOE_BACKEND: "throughput"
-
+server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --enable-expert-parallel --moe-backend=flashinfer_cutlass"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-Fp8-ModelOpt-fi-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-Fp8-ModelOpt-fi-cutlass.yaml
index 4c9a01274..217ee5e60 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-Fp8-ModelOpt-fi-cutlass.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-Fp8-ModelOpt-fi-cutlass.yaml
@@ -2,7 +2,4 @@ model_name: "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
 accuracy_threshold: 0.92
 num_questions: 1319
 num_fewshot: 5
-server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
-env:
-  VLLM_USE_FLASHINFER_MOE_FP8: "1"
-  VLLM_FLASHINFER_MOE_BACKEND: "throughput"
+server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=flashinfer_cutlass"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-Fp8-ModelOpt-fi-trtllm.yaml b/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-Fp8-ModelOpt-fi-trtllm.yaml
index 17f067215..7e9300d9f 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-Fp8-ModelOpt-fi-trtllm.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-Fp8-ModelOpt-fi-trtllm.yaml
@@ -2,7 +2,4 @@ model_name: "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
 accuracy_threshold: 0.92
 num_questions: 1319
 num_fewshot: 5
-server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
-env:
-  VLLM_USE_FLASHINFER_MOE_FP8: "1"
-  VLLM_FLASHINFER_MOE_BACKEND: "latency"
+server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=flashinfer_trtllm"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-Fp8-ModelOpt-triton.yaml b/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-Fp8-ModelOpt-triton.yaml
index ae6bf6755..87f960afe 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-Fp8-ModelOpt-triton.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-Fp8-ModelOpt-triton.yaml
@@ -2,6 +2,4 @@ model_name: "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
 accuracy_threshold: 0.92
 num_questions: 1319
 num_fewshot: 5
-server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
-env:
-  VLLM_USE_FLASHINFER_MOE_FP8: "0"
+server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=triton"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Mixtral-8x7B-BF16-fi-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor/Mixtral-8x7B-BF16-fi-cutlass.yaml
index cc8df6292..1c5865974 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Mixtral-8x7B-BF16-fi-cutlass.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Mixtral-8x7B-BF16-fi-cutlass.yaml
@@ -2,7 +2,4 @@ model_name: "mistralai/Mixtral-8x7B-v0.1"
 accuracy_threshold: 0.58
 num_questions: 1319
 num_fewshot: 5
-server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --enable-expert-parallel"
-env:
-  VLLM_USE_FLASHINFER_MOE_FP16: "1"
-  VLLM_FLASHINFER_MOE_BACKEND: "throughput"
+server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --enable-expert-parallel --moe-backend=flashinfer_cutlass"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Mixtral-8x7B-Fp8-AutoFp8-fi-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor/Mixtral-8x7B-Fp8-AutoFp8-fi-cutlass.yaml
index b9c6a1997..f836a5038 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Mixtral-8x7B-Fp8-AutoFp8-fi-cutlass.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Mixtral-8x7B-Fp8-AutoFp8-fi-cutlass.yaml
@@ -3,7 +3,4 @@
 # accuracy_threshold: 0.62
 # num_questions: 1319
 # num_fewshot: 5
-# server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
-# env:
-#   VLLM_USE_FLASHINFER_MOE_FP8: "1"
-#   VLLM_FLASHINFER_MOE_BACKEND: "throughput"
+# server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=flashinfer_cutlass"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Nemotron-Nano-30B-Fp8-ModelOpt-fi-trtllm.yaml b/tests/evals/gsm8k/configs/moe-refactor/Nemotron-Nano-30B-Fp8-ModelOpt-fi-trtllm.yaml
index 570569def..a06c93dcc 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Nemotron-Nano-30B-Fp8-ModelOpt-fi-trtllm.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Nemotron-Nano-30B-Fp8-ModelOpt-fi-trtllm.yaml
@@ -2,7 +2,4 @@ model_name: "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8"
 accuracy_threshold: 0.29
 num_questions: 1319
 num_fewshot: 5
-server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
-env:
-  VLLM_USE_FLASHINFER_MOE_FP8: "1"
-  VLLM_FLASHINFER_MOE_BACKEND: "latency"
+server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=flashinfer_trtllm"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Nemotron-Nano-30B-NvFp4-ModelOpt-fi-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor/Nemotron-Nano-30B-NvFp4-ModelOpt-fi-cutlass.yaml
index d802ac3f3..b5a8676d7 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Nemotron-Nano-30B-NvFp4-ModelOpt-fi-cutlass.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Nemotron-Nano-30B-NvFp4-ModelOpt-fi-cutlass.yaml
@@ -2,7 +2,4 @@ model_name: "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4"
 accuracy_threshold: 0.29
 num_questions: 1319
 num_fewshot: 5
-server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
-env:
-  VLLM_USE_FLASHINFER_MOE_FP4: "1"
-  VLLM_FLASHINFER_MOE_BACKEND: "throughput"
+server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=flashinfer_cutlass"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-BF16-fi-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-BF16-fi-cutlass.yaml
index b15126a45..92b9c071e 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-BF16-fi-cutlass.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-BF16-fi-cutlass.yaml
@@ -2,6 +2,4 @@ model_name: "Qwen/Qwen3-30B-A3B"
 accuracy_threshold: 0.88
 num_questions: 1319
 num_fewshot: 5
-server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --enable-expert-parallel"
-env:
-  VLLM_USE_FLASHINFER_MOE_FP16: "1"
+server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --enable-expert-parallel --moe-backend=flashinfer_cutlass"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-fi-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-fi-cutlass.yaml
index 74820cd28..b392f9245 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-fi-cutlass.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-fi-cutlass.yaml
@@ -2,7 +2,4 @@ model_name: "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8"
 accuracy_threshold: 0.88
 num_questions: 1319
 num_fewshot: 5
-server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
-env:
-  VLLM_USE_FLASHINFER_MOE_FP8: "1"
-  VLLM_FLASHINFER_MOE_BACKEND: "throughput"
+server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=flashinfer_cutlass"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-fi-trtllm.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-fi-trtllm.yaml
index d745c9b5b..4fd2f8d26 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-fi-trtllm.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-fi-trtllm.yaml
@@ -2,7 +2,4 @@ model_name: "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8"
 accuracy_threshold: 0.88
 num_questions: 1319
 num_fewshot: 5
-server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
-env:
-  VLLM_USE_FLASHINFER_MOE_FP8: "1"
-  VLLM_FLASHINFER_MOE_BACKEND: "latency"
+server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=flashinfer_trtllm"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-triton.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-triton.yaml
index 1b2d72160..0dd401d2d 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-triton.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-AutoFp8-triton.yaml
@@ -2,7 +2,6 @@ model_name: "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8"
 accuracy_threshold: 0.88
 num_questions: 1319
 num_fewshot: 5
-server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
+server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=triton"
 env:
-  VLLM_USE_FLASHINFER_MOE_FP8: "0"
   VLLM_USE_DEEP_GEMM: "0"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-fi-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-fi-cutlass.yaml
index 48ab58c46..fb52d3600 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-fi-cutlass.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-fi-cutlass.yaml
@@ -2,7 +2,4 @@ model_name: "RedHatAI/Qwen3-30B-A3B-FP8-block"
 accuracy_threshold: 0.85
 num_questions: 1319
 num_fewshot: 5
-server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
-env:
-  VLLM_USE_FLASHINFER_MOE_FP8: "1"
-  VLLM_FLASHINFER_MOE_BACKEND: "throughput"
+server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=flashinfer_cutlass"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-triton.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-triton.yaml
index 3e30d4d15..5bd907c05 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-triton.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-triton.yaml
@@ -2,7 +2,6 @@ model_name: "RedHatAI/Qwen3-30B-A3B-FP8-block"
 accuracy_threshold: 0.85
 num_questions: 1319
 num_fewshot: 5
-server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
+server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=triton"
 env:
-  VLLM_USE_FLASHINFER_MOE_FP8: "0"
   VLLM_USE_DEEP_GEMM: "0"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-CT-fi-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-CT-fi-cutlass.yaml
index 6edacc329..3c1b20c24 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-CT-fi-cutlass.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-CT-fi-cutlass.yaml
@@ -2,7 +2,4 @@ model_name: "RedHatAI/Qwen3-30B-A3B-NVFP4"
 accuracy_threshold: 0.88
 num_questions: 1319
 num_fewshot: 5
-server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
-env:
-  VLLM_USE_FLASHINFER_MOE_FP4: "1"
-  VLLM_FLASHINFER_MOE_BACKEND: "throughput"
+server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=flashinfer_cutlass"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-CT-fi-trtllm.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-CT-fi-trtllm.yaml
index 8e0b155fa..094ec92f1 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-CT-fi-trtllm.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-CT-fi-trtllm.yaml
@@ -2,7 +2,4 @@ model_name: "RedHatAI/Qwen3-30B-A3B-NVFP4"
 accuracy_threshold: 0.88
 num_questions: 1319
 num_fewshot: 5
-server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
-env:
-  VLLM_USE_FLASHINFER_MOE_FP4: "1"
-  VLLM_FLASHINFER_MOE_BACKEND: "latency"
+server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=flashinfer_trtllm"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-CT-vllm-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-CT-vllm-cutlass.yaml
index 0d7884928..c38bc162e 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-CT-vllm-cutlass.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-CT-vllm-cutlass.yaml
@@ -2,6 +2,4 @@ model_name: "RedHatAI/Qwen3-30B-A3B-NVFP4"
 accuracy_threshold: 0.88
 num_questions: 1319
 num_fewshot: 5
-server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
-env:
-  VLLM_USE_FLASHINFER_MOE_FP4: "0"
+server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=cutlass"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-ModelOpt-fi-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-ModelOpt-fi-cutlass.yaml
index 09e76e21a..0ebc68ad3 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-ModelOpt-fi-cutlass.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-ModelOpt-fi-cutlass.yaml
@@ -2,7 +2,4 @@ model_name: "nvidia/Qwen3-30B-A3B-NVFP4"
 accuracy_threshold: 0.88
 num_questions: 1319
 num_fewshot: 5
-server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
-env:
-  VLLM_USE_FLASHINFER_MOE_FP4: "1"
-  VLLM_FLASHINFER_MOE_BACKEND: "throughput"
+server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=flashinfer_cutlass"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-ModelOpt-fi-trtllm.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-ModelOpt-fi-trtllm.yaml
index a98afafbc..491b3c82f 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-ModelOpt-fi-trtllm.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-ModelOpt-fi-trtllm.yaml
@@ -2,7 +2,4 @@ model_name: "nvidia/Qwen3-30B-A3B-NVFP4"
 accuracy_threshold: 0.88
 num_questions: 1319
 num_fewshot: 5
-server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
-env:
-  VLLM_USE_FLASHINFER_MOE_FP4: "1"
-  VLLM_FLASHINFER_MOE_BACKEND: "latency"
+server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=flashinfer_trtllm"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-ModelOpt-vllm-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-ModelOpt-vllm-cutlass.yaml
index a340b6fda..242c6ff52 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-ModelOpt-vllm-cutlass.yaml
+++ b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-NvFp4-ModelOpt-vllm-cutlass.yaml
@@ -2,6 +2,4 @@ model_name: "nvidia/Qwen3-30B-A3B-NVFP4"
 accuracy_threshold: 0.88
 num_questions: 1319
 num_fewshot: 5
-server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
-env:
-  VLLM_USE_FLASHINFER_MOE_FP4: "0"
+server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=cutlass"
diff --git a/tests/quantization/test_blackwell_moe.py b/tests/quantization/test_blackwell_moe.py
index 07da2b454..3a44ff423 100644
--- a/tests/quantization/test_blackwell_moe.py
+++ b/tests/quantization/test_blackwell_moe.py
@@ -85,34 +85,34 @@ def can_initialize(
     )
 )
 def test_llama4_fp8_tensor_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch):
-    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1")
-    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
     can_initialize(
-        "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", hf_overrides=HF_OVERRIDE_MM
+        "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
+        hf_overrides=HF_OVERRIDE_MM,
+        extra_args=["--moe-backend=flashinfer_cutlass"],
     )


 def test_llama4_fp8_tensor_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
-    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1")
-    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
     can_initialize(
-        "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", hf_overrides=HF_OVERRIDE_MM
+        "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
+        hf_overrides=HF_OVERRIDE_MM,
+        extra_args=["--moe-backend=flashinfer_trtllm"],
     )


 def test_llama4_nvfp4_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch):
-    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
-    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
     can_initialize(
-        "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", hf_overrides=HF_OVERRIDE_MM
+        "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
+        hf_overrides=HF_OVERRIDE_MM,
+        extra_args=["--moe-backend=flashinfer_cutlass"],
    )


 def test_llama4_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
-    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
-    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
     can_initialize(
-        "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", hf_overrides=HF_OVERRIDE_MM
+        "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
+        hf_overrides=HF_OVERRIDE_MM,
+        extra_args=["--moe-backend=flashinfer_trtllm"],
     )


@@ -120,8 +120,11 @@ def test_llama4_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):


 def test_deepseek_fp8_block_moe_deep_gemm(monkeypatch: pytest.MonkeyPatch):
-    monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1")
-    can_initialize("deepseek-ai/DeepSeek-V3.1", hf_overrides=HF_OVERRIDE_TEXT)
+    can_initialize(
+        "deepseek-ai/DeepSeek-V3.1",
+        hf_overrides=HF_OVERRIDE_TEXT,
+        extra_args=["--moe-backend=deep_gemm"],
+    )


 @pytest.mark.skip(
@@ -131,27 +134,35 @@ def test_deepseek_fp8_block_moe_deep_gemm(monkeypatch: pytest.MonkeyPatch):
     )
 )
 def test_deepseek_fp8_block_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch):
-    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1")
-    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
-    can_initialize("deepseek-ai/DeepSeek-V3.1", hf_overrides=HF_OVERRIDE_TEXT)
+    can_initialize(
+        "deepseek-ai/DeepSeek-V3.1",
+        hf_overrides=HF_OVERRIDE_TEXT,
+        extra_args=["--moe-backend=flashinfer_cutlass"],
+    )


 def test_deepseek_fp8_block_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
-    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1")
-    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
-    can_initialize("deepseek-ai/DeepSeek-V3.1", hf_overrides=HF_OVERRIDE_TEXT)
+    can_initialize(
+        "deepseek-ai/DeepSeek-V3.1",
+        hf_overrides=HF_OVERRIDE_TEXT,
+        extra_args=["--moe-backend=flashinfer_trtllm"],
+    )


 def test_deepseek_nvfp4_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch):
-    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
-    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
-    can_initialize("nvidia/DeepSeek-R1-0528-FP4-v2", hf_overrides=HF_OVERRIDE_TEXT)
+    can_initialize(
+        "nvidia/DeepSeek-R1-0528-FP4-v2",
+        hf_overrides=HF_OVERRIDE_TEXT,
+        extra_args=["--moe-backend=flashinfer_cutlass"],
+    )


 def test_deepseek_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
-    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
-    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
-    can_initialize("nvidia/DeepSeek-R1-0528-FP4-v2", hf_overrides=HF_OVERRIDE_TEXT)
+    can_initialize(
+        "nvidia/DeepSeek-R1-0528-FP4-v2",
+        hf_overrides=HF_OVERRIDE_TEXT,
+        extra_args=["--moe-backend=flashinfer_trtllm"],
+    )


 ## GPT-OSS ##
@@ -184,5 +195,8 @@ def test_gptoss_eager(monkeypatch: pytest.MonkeyPatch):


 def test_qwen3_next_bf16_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
-    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1")
-    can_initialize("Qwen/Qwen3-Next-80B-A3B-Instruct", hf_overrides=HF_OVERRIDE_TEXT)
+    can_initialize(
+        "Qwen/Qwen3-Next-80B-A3B-Instruct",
+        hf_overrides=HF_OVERRIDE_TEXT,
+        extra_args=["--moe-backend=flashinfer_trtllm"],
+    )
diff --git a/vllm/config/kernel.py b/vllm/config/kernel.py
index 0730e4649..3c08ef882 100644
--- a/vllm/config/kernel.py
+++ b/vllm/config/kernel.py
@@ -2,13 +2,25 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from collections.abc import Callable
-from typing import Any
+from typing import Any, Literal

 from pydantic import Field, field_validator

 from vllm.config.utils import config
 from vllm.utils.hashing import safe_hash

+MoEBackend = Literal[
+    "auto",
+    "triton",
+    "deep_gemm",
+    "cutlass",
+    "flashinfer_trtllm",
+    "flashinfer_cutlass",
+    "flashinfer_cutedsl",
+    "marlin",
+    "aiter",
+]
+

 @config
 class KernelConfig:
@@ -17,6 +29,26 @@ class KernelConfig:
     enable_flashinfer_autotune: bool = Field(default=None)
     """If True, run FlashInfer autotuning during kernel warmup."""

+    moe_backend: MoEBackend = "auto"
+    """Backend for MoE expert computation kernels. Available options:
+
+    - "auto": Automatically select the best backend based on model and hardware\n
+    - "triton": Use Triton-based fused MoE kernels\n
+    - "deep_gemm": Use DeepGEMM kernels (FP8 block-quantized only)\n
+    - "cutlass": Use vLLM CUTLASS kernels\n
+    - "flashinfer_trtllm": Use FlashInfer with TRTLLM-GEN kernels\n
+    - "flashinfer_cutlass": Use FlashInfer with CUTLASS kernels\n
+    - "flashinfer_cutedsl": Use FlashInfer with CuteDSL kernels (FP4 only)\n
+    - "marlin": Use Marlin kernels (weight-only quantization)\n
+    - "aiter": Use AMD AITer kernels (ROCm only)"""
+
+    @field_validator("moe_backend", mode="before")
+    @classmethod
+    def _normalize_moe_backend(cls, value: Any) -> Any:
+        if isinstance(value, str):
+            return value.lower().replace("-", "_")
+        return value
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index ca76454d6..036178887 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -70,6 +70,7 @@ from vllm.config.cache import (
     PrefixCachingHashAlgo,
 )
 from vllm.config.device import Device
+from vllm.config.kernel import MoEBackend
 from vllm.config.lora import MaxLoRARanks
 from vllm.config.model import (
     ConvertOption,
@@ -416,6 +417,7 @@ class EngineArgs:
     data_parallel_external_lb: bool = False
     data_parallel_backend: DataParallelBackend = ParallelConfig.data_parallel_backend
     enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
+    moe_backend: MoEBackend = KernelConfig.moe_backend
     all2all_backend: All2AllBackend = ParallelConfig.all2all_backend
     enable_dbo: bool = ParallelConfig.enable_dbo
     ubatch_size: int = ParallelConfig.ubatch_size
@@ -1227,6 +1229,9 @@
             "--enable-flashinfer-autotune",
             **kernel_kwargs["enable_flashinfer_autotune"],
         )
+        moe_backend_kwargs = kernel_kwargs["moe_backend"]
+        moe_backend_kwargs["type"] = lambda s: s.lower().replace("-", "_")
+        kernel_group.add_argument("--moe-backend", **moe_backend_kwargs)

         # vLLM arguments
         vllm_kwargs = get_kwargs(VllmConfig)
@@ -1817,6 +1822,8 @@
                 "are mutually exclusive"
             )
         kernel_config.enable_flashinfer_autotune = self.enable_flashinfer_autotune
+        if self.moe_backend != "auto":
+            kernel_config.moe_backend = self.moe_backend

         load_config = self.create_load_config()
diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index 22e71d391..87e1e244b 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -1066,7 +1066,6 @@ class FusedMoEParallelConfig:
           - Comment: There are 2 engine instances and the experts are split between
             the 4 devices.
         """
-
         use_ep = (
             dp_size_ * pcp_size_ * tp_size_ > 1
             and vllm_parallel_config.enable_expert_parallel
@@ -1155,6 +1154,7 @@ class FusedMoEConfig:
     # Defaults to in_dtype if not specified.
     router_logits_dtype: torch.dtype | None = None
+    moe_backend: str = "auto"
     max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
     has_bias: bool = False
     is_act_and_mul: bool = True
diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
index f5a3da438..a4cee76f7 100644
--- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
@@ -198,7 +198,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             x = x[0].permute(2, 0, 1)
             num_experts, max_tokens, hidden_dim_by_2 = x.shape
             hidden_dim = hidden_dim_by_2 * 2
-            assert envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm"
             logger.info_once(
                 "Quantization is fused with DeepEP nvfp4 dispatch for "
                 "FlashInfer CUTEDSL as VLLM_DEEPEPLL_NVFP4_DISPATCH==1"
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 6cb3dae26..679b79ce9 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -550,6 +550,7 @@ class FusedMoE(CustomOp):
             num_logical_experts=self.logical_num_experts,
             moe_parallel_config=self.moe_parallel_config,
             in_dtype=moe_in_dtype,
+            moe_backend=vllm_config.kernel_config.moe_backend,
             router_logits_dtype=router_logits_dtype,
             max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
             has_bias=has_bias,
diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py
index 243220989..6f961df07 100644
--- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py
+++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py
@@ -7,6 +7,7 @@ import torch
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import envs
 from vllm._aiter_ops import rocm_aiter_ops
+from vllm.config.kernel import MoEBackend
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.all2all_utils import (
     maybe_make_prepare_finalize,
@@ -180,6 +181,25 @@ def backend_to_kernel_cls(
     raise ValueError(f"Unknown FP8 MoE backend: {backend.value}")


+def map_fp8_backend(runner_backend: MoEBackend) -> Fp8MoeBackend:
+    """Map user's MoEBackend to Fp8MoeBackend."""
+    mapping = {
+        "triton": Fp8MoeBackend.TRITON,
+        "deep_gemm": Fp8MoeBackend.DEEPGEMM,
+        "cutlass": Fp8MoeBackend.VLLM_CUTLASS,
+        "flashinfer_trtllm": Fp8MoeBackend.FLASHINFER_TRTLLM,
+        "flashinfer_cutlass": Fp8MoeBackend.FLASHINFER_CUTLASS,
+        "marlin": Fp8MoeBackend.MARLIN,
+        "aiter": Fp8MoeBackend.AITER,
+    }
+    if backend := mapping.get(runner_backend):
+        return backend
+    raise ValueError(
+        f"moe_backend='{runner_backend}' is not supported for FP8 MoE. "
+        f"Expected one of {list(mapping.keys())}."
+    )
+
+
 def select_fp8_moe_backend(
     config: FusedMoEConfig,
     weight_key: QuantKey | None,
@@ -242,6 +262,45 @@ def select_fp8_moe_backend(
             return backend, k_cls
         raise ValueError(_make_log_unsupported(backend, reason))

+    # Handle explicit moe_backend from user.
+    runner_backend = config.moe_backend
+    if runner_backend != "auto":
+        requested_backend = map_fp8_backend(runner_backend)
+        # For batched activation format, use batched variants if available.
+        if activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
+            if requested_backend == Fp8MoeBackend.DEEPGEMM:
+                requested_backend = Fp8MoeBackend.BATCHED_DEEPGEMM
+            elif requested_backend == Fp8MoeBackend.TRITON:
+                requested_backend = Fp8MoeBackend.BATCHED_TRITON
+            elif requested_backend == Fp8MoeBackend.VLLM_CUTLASS:
+                requested_backend = Fp8MoeBackend.BATCHED_VLLM_CUTLASS
+
+        if (
+            requested_backend
+            in [
+                Fp8MoeBackend.VLLM_CUTLASS,
+                Fp8MoeBackend.BATCHED_VLLM_CUTLASS,
+            ]
+            and not allow_vllm_cutlass
+        ):
+            raise ValueError(
+                "vLLM CUTLASS FP8 MoE backend is disabled for this configuration."
+            )
+
+        # Handle FLASHINFER_TRTLLM specially (no kernel class).
+        if requested_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
+            supported, reason = is_supported_config_trtllm_fp8(
+                config, weight_key, activation_key, activation_format
+            )
+            if supported:
+                logger.info_once(_make_log_backend(requested_backend))
+                return requested_backend, None
+            raise ValueError(_make_log_unsupported(requested_backend, reason))
+
+        return _return_or_raise(
+            requested_backend, config, weight_key, activation_key, activation_format
+        )
+
     # Handle explicit FlashInfer FP8 configuration.
     if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP8"):
         if not envs.VLLM_USE_FLASHINFER_MOE_FP8:
diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
index dc3ac61ad..ee7db88cc 100644
--- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
+++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
@@ -6,6 +6,7 @@ import torch

 import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.config.kernel import MoEBackend
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.all2all_utils import (
     maybe_make_prepare_finalize,
@@ -103,6 +104,23 @@ def backend_to_kernel_cls(
     raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}")


+def map_nvfp4_backend(runner_backend: MoEBackend) -> NvFp4MoeBackend:
+    """Map user's MoEBackend to NvFp4MoeBackend."""
+    mapping = {
+        "cutlass": NvFp4MoeBackend.VLLM_CUTLASS,
+        "flashinfer_trtllm": NvFp4MoeBackend.FLASHINFER_TRTLLM,
+        "flashinfer_cutlass": NvFp4MoeBackend.FLASHINFER_CUTLASS,
+        "flashinfer_cutedsl": NvFp4MoeBackend.FLASHINFER_CUTEDSL,
+        "marlin": NvFp4MoeBackend.MARLIN,
+    }
+    if backend := mapping.get(runner_backend):
+        return backend
+    raise ValueError(
+        f"moe_backend='{runner_backend}' is not supported for NvFP4 MoE. "
+        f"Expected one of {list(mapping.keys())}."
+    )
+
+
 def select_nvfp4_moe_backend(
     config: FusedMoEConfig,
     weight_key: QuantKey | None,
@@ -170,6 +188,23 @@ def select_nvfp4_moe_backend(
             return backend, k_cls
         raise ValueError(_make_log_unsupported(backend, reason))

+    # Handle explicit moe_backend from user.
+    runner_backend = config.moe_backend
+    if runner_backend != "auto":
+        requested_backend = map_nvfp4_backend(runner_backend)
+        if requested_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
+            supported, reason = is_supported_config_trtllm(
+                config, weight_key, activation_key, activation_format
+            )
+            if supported:
+                logger.info_once(_make_log_backend(requested_backend))
+                return requested_backend, None
+            raise ValueError(_make_log_unsupported(requested_backend, reason))
+
+        return _return_or_raise(
+            requested_backend, config, weight_key, activation_key, activation_format
+        )
+
     if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP4"):
         if not envs.VLLM_USE_FLASHINFER_MOE_FP4:
             # If the user rejects FlashInfer remove those backends.
diff --git a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py
index 61aaa6927..1c582bcdc 100644
--- a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py
+++ b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py
@@ -9,6 +9,7 @@ from torch.nn import Module
 import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm._aiter_ops import rocm_aiter_ops
+from vllm.config.kernel import MoEBackend
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
@@ -51,6 +52,22 @@ UNSUPPORTED_BACKEND = [
 ]


+def map_unquantized_backend(runner_backend: MoEBackend) -> UnquantizedMoeBackend:
+    """Map user's MoEBackend to UnquantizedMoeBackend."""
+    mapping = {
+        "triton": UnquantizedMoeBackend.TRITON,
+        "flashinfer_trtllm": UnquantizedMoeBackend.FLASHINFER_TRTLLM,
+        "flashinfer_cutlass": UnquantizedMoeBackend.FLASHINFER_CUTLASS,
+        "aiter": UnquantizedMoeBackend.AITER,
+    }
+    if backend := mapping.get(runner_backend):
+        return backend
+    raise ValueError(
+        f"moe_backend='{runner_backend}' is not supported for unquantized MoE. "
+        f"Expected one of {list(mapping.keys())}."
+    )
+
+
 def select_unquantized_moe_backend(
     moe_config: FusedMoEConfig,
     use_ep: bool,
@@ -64,8 +81,6 @@ def select_unquantized_moe_backend(
     def _make_log_backend(backend: UnquantizedMoeBackend):
         return f"Using {backend.value} backend for Unquantized MoE"

-    rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
-
     activation_format = (
         mk.FusedMoEActivationFormat.BatchedExperts
         if moe_config.moe_parallel_config.use_batched_activation_format
@@ -77,20 +92,49 @@ def select_unquantized_moe_backend(
         moe_config=moe_config,
         activation_format=activation_format,
     )
-    flashinfer_trtllm_moe_enabled = (
-        has_flashinfer()
-        and envs.VLLM_USE_FLASHINFER_MOE_FP16
-        and trtllm_supported
-        and envs.VLLM_FLASHINFER_MOE_BACKEND == "latency"
-    )
+    flashinfer_trtllm_available = has_flashinfer() and trtllm_supported
     # FlashInfer CUTLASS MoE is only supported on Hopper and later GPUS
-    flashinfer_cutlass_moe_enabled = (
+    flashinfer_cutlass_available = (
         has_flashinfer_cutlass_fused_moe()
-        and envs.VLLM_USE_FLASHINFER_MOE_FP16
         and use_ep
         and (not use_dp)
         and current_platform.has_device_capability(90)
     )
+    flashinfer_trtllm_moe_enabled = (
+        flashinfer_trtllm_available
+        and envs.VLLM_USE_FLASHINFER_MOE_FP16
+        and envs.VLLM_FLASHINFER_MOE_BACKEND == "latency"
+    )
+    flashinfer_cutlass_moe_enabled = (
+        flashinfer_cutlass_available and envs.VLLM_USE_FLASHINFER_MOE_FP16
+    )
+    rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
+
+    # Handle explicit moe_backend from user.
+    runner_backend = moe_config.moe_backend
+    if runner_backend != "auto":
+        requested_backend = map_unquantized_backend(runner_backend)
+        if requested_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM:
+            if not flashinfer_trtllm_available:
+                raise ValueError(
+                    "FlashInfer TRTLLM MoE backend is not available for this "
+                    "configuration."
+                )
+        elif requested_backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS:
+            if not flashinfer_cutlass_available:
+                raise ValueError(
+                    "FlashInfer CUTLASS MoE backend is not available for this "
+                    "configuration."
+                )
+        elif requested_backend == UnquantizedMoeBackend.AITER and not (
+            current_platform.is_rocm() and rocm_aiter_moe_enabled
+        ):
+            raise ValueError(
+                "ROCm AITer MoE backend is not available for this configuration."
+            )
+        logger.info_once(_make_log_backend(requested_backend), scope="local")
+        return requested_backend
+
     if current_platform.is_rocm():
         if rocm_aiter_moe_enabled:
             backend = UnquantizedMoeBackend.AITER
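For reference, the --moe-backend values exercised in the configs above take the place of the removed VLLM_USE_FLASHINFER_MOE_FP4/FP8/FP16 and VLLM_FLASHINFER_MOE_BACKEND environment variables. A minimal launch sketch using the new flag (model name and parallel settings taken from the eval configs above, shown only as an illustration):

    vllm serve nvidia/Qwen3-30B-A3B-NVFP4 --tensor-parallel-size 2 --moe-backend=flashinfer_trtllm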