[MoE Refactor] MXFP4 Cutlass Experts to MK (#34542)

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
This commit is contained in:
Yongye Zhu
2026-02-25 17:32:39 -08:00
committed by GitHub
parent cbf8f7028c
commit 1976356ee6
19 changed files with 454 additions and 169 deletions

View File

@@ -73,3 +73,29 @@ steps:
num_devices: 2
commands:
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=evals/gsm8k/configs/moe-refactor-dp-ep/config-b200.txt
- label: GPQA Eval (GPT-OSS) (H100)
timeout_in_minutes: 120
device: h100
optional: true
num_devices: 2
source_file_dependencies:
- csrc/
- vllm/model_executor/layers/quantization
- tests/evals/gpt_oss/
commands:
- uv pip install --system 'gpt-oss[eval]==0.0.5'
- pytest -s -v evals/gpt_oss/test_gpqa_correctness.py --config-list-file=configs/models-h100.txt
- label: GPQA Eval (GPT-OSS) (B200)
timeout_in_minutes: 120
device: b200
optional: true
num_devices: 2
source_file_dependencies:
- csrc/
- vllm/model_executor/layers/quantization
- tests/evals/gpt_oss/
commands:
- uv pip install --system 'gpt-oss[eval]==0.0.5'
- pytest -s -v evals/gpt_oss/test_gpqa_correctness.py --config-list-file=configs/models-b200.txt

View File

@@ -153,33 +153,6 @@ steps:
- pytest -v -s transformers_utils
- pytest -v -s config
- label: GPT-OSS Eval (H100)
timeout_in_minutes: 60
working_dir: "/vllm-workspace/"
device: h100
optional: true
source_file_dependencies:
- tests/evals/gpt_oss
- vllm/model_executor/models/gpt_oss.py
- vllm/model_executor/layers/quantization/mxfp4.py
commands:
- uv pip install --system 'gpt-oss[eval]==0.0.5'
- pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58
- label: GPT-OSS Eval (B200)
timeout_in_minutes: 60
working_dir: "/vllm-workspace/"
device: b200
optional: true
source_file_dependencies:
- tests/evals/gpt_oss
- vllm/model_executor/models/gpt_oss.py
- vllm/model_executor/layers/quantization/mxfp4.py
- vllm/v1/attention/backends/flashinfer.py
commands:
- uv pip install --system 'gpt-oss[eval]==0.0.5'
- pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58
- label: Batch Invariance (H100)
timeout_in_minutes: 25
device: h100

View File

@@ -0,0 +1,49 @@
# GPQA Evaluation using GPT-OSS
This directory contains GPQA evaluation tests using the GPT-OSS evaluation package and vLLM server.
## Usage
### Run tests with pytest (like buildkite)
```bash
# H100
pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py \
--config-list-file=configs/models-h100.txt
# B200
pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py \
--config-list-file=configs/models-b200.txt
```
## Configuration Format
Model configs in `configs/` directory use this YAML format:
```yaml
model_name: "openai/gpt-oss-20b"
metric_threshold: 0.568 # Minimum expected accuracy
reasoning_effort: "low" # Reasoning effort level (default: "low")
server_args: "--tensor-parallel-size 2" # Server arguments
startup_max_wait_seconds: 1800 # Max wait for server startup (default: 1800)
env: # Environment variables (optional)
SOME_VAR: "value"
```
The `server_args` field accepts any arguments that can be passed to `vllm serve`.
The `env` field accepts a dictionary of environment variables to set for the server process.
## Adding New Models
1. Create a new YAML config file in the `configs/` directory
2. Add the filename to the appropriate `models-*.txt` file
## Tiktoken Encoding Files
The tiktoken encoding files required by the vLLM server are automatically downloaded from OpenAI's public blob storage on first run:
- `cl100k_base.tiktoken`
- `o200k_base.tiktoken`
Files are cached in the `data/` directory. The `TIKTOKEN_ENCODINGS_BASE` environment variable is automatically set to point to this directory when running evaluations.

View File

@@ -0,0 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
model_name: "openai/gpt-oss-20b"
metric_threshold: 0.568
reasoning_effort: "low"
server_args: "--tensor-parallel-size 2"

View File

@@ -0,0 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
model_name: "openai/gpt-oss-20b"
metric_threshold: 0.568
reasoning_effort: "low"
server_args: "--tensor-parallel-size 2"
env:
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: "1"

View File

@@ -0,0 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
model_name: "openai/gpt-oss-20b"
metric_threshold: 0.568
reasoning_effort: "low"
server_args: "--tensor-parallel-size 2"
env:
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: "1"

View File

@@ -0,0 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
model_name: "openai/gpt-oss-20b"
metric_threshold: 0.568
reasoning_effort: "low"
server_args: "--tensor-parallel-size 2"
env:
VLLM_MXFP4_USE_MARLIN: "1"

View File

@@ -0,0 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
model_name: "openai/gpt-oss-20b"
metric_threshold: 0.568
reasoning_effort: "low"
server_args: "--tensor-parallel-size 2"
env:
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: "1"

View File

@@ -0,0 +1,5 @@
# B200 model configurations for GPQA evaluation
# Tests different environment variable combinations
gpt-oss-20b-flashinfer-mxfp4-bf16.yaml
gpt-oss-20b-flashinfer-mxfp4-mxfp8-cutlass.yaml
gpt-oss-20b-sm100-fi-mxfp4-mxfp8-trtllm.yaml

View File

@@ -0,0 +1,5 @@
# H100 model configurations for GPQA evaluation
# Tests different environment variable combinations
gpt-oss-20b-baseline.yaml
gpt-oss-20b-flashinfer-mxfp4-bf16.yaml
gpt-oss-20b-marlin.yaml

View File

@@ -4,13 +4,61 @@
Pytest configuration for GPT-OSS evaluation tests.
"""
from pathlib import Path
def pytest_addoption(parser):
"""Add command line options for pytest."""
parser.addoption("--model", action="store", help="Model name to evaluate")
"""Add custom command line options."""
parser.addoption(
"--metric", action="store", type=float, help="Expected metric threshold"
)
parser.addoption(
"--server-args", action="store", default="", help="Additional server arguments"
"--config-list-file",
required=True,
help="File containing list of config files to test",
)
def pytest_generate_tests(metafunc):
    """Generate test parameters from config files.

    Resolves ``--config-list-file`` (absolute, test-dir-relative, or
    cwd-relative), reads the non-comment entries as config filenames in the
    same directory, and parametrizes ``config_filename`` with the configs
    that exist on disk.
    """
    if "config_filename" not in metafunc.fixturenames:
        return

    raw_opt = metafunc.config.getoption("--config-list-file")

    # Resolve the list file: absolute wins; otherwise prefer the test
    # directory, falling back to the current working directory.
    list_path = Path(raw_opt)
    if not list_path.is_absolute():
        local_candidate = Path(__file__).parent / raw_opt
        list_path = local_candidate if local_candidate.exists() else Path.cwd() / raw_opt

    print(f"Looking for config list at: {list_path}")

    found: list[Path] = []
    if not list_path.exists():
        print(f"Config list file not found: {list_path}")
    else:
        # Config files live next to the list file.
        cfg_dir = list_path.parent
        with open(list_path) as fh:
            for raw_line in fh:
                entry = raw_line.strip()
                if not entry or entry.startswith("#"):
                    continue
                candidate = cfg_dir / entry
                print(f"Checking config file: {candidate}")
                if candidate.exists():
                    found.append(candidate)
                    print(f" Found: {candidate}")
                else:
                    print(f" Missing: {candidate}")

    if found:
        metafunc.parametrize(
            "config_filename",
            found,
            ids=[cfg.stem for cfg in found],
        )
    else:
        # No parametrization performed; pytest will report the missing
        # fixture rather than silently passing.
        print("No config files found, test will be skipped")

View File

@@ -5,22 +5,48 @@ GPQA evaluation using vLLM server and GPT-OSS evaluation package.
Usage:
pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py \
--model openai/gpt-oss-20b \
--metric 0.58 \
--server-args "--tensor-parallel-size 2"
--config-list-file=configs/models-h100.txt
"""
import os
import shlex
import subprocess
import sys
import urllib.request
from pathlib import Path
import regex as re
import yaml
from tests.utils import RemoteOpenAIServer
TOL = 0.05 # Absolute tolerance for accuracy comparison
# Path to tiktoken encoding files
TIKTOKEN_DATA_DIR = Path(__file__).parent / "data"
def run_gpqa_eval(model_name: str, base_url: str) -> float:
# Tiktoken encoding files to download
TIKTOKEN_FILES = {
"cl100k_base.tiktoken": "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken",
"o200k_base.tiktoken": "https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken",
}
def ensure_tiktoken_files():
    """Download tiktoken encoding files if they don't exist.

    Creates TIKTOKEN_DATA_DIR on demand and fetches each entry of
    TIKTOKEN_FILES into it; files already present on disk are left
    untouched.
    """
    TIKTOKEN_DATA_DIR.mkdir(parents=True, exist_ok=True)
    for filename, url in TIKTOKEN_FILES.items():
        filepath = TIKTOKEN_DATA_DIR / filename
        if not filepath.exists():
            # NOTE(review): these two messages showed a garbled "(unknown)"
            # placeholder; restored to interpolate the actual filename.
            print(f"Downloading {filename} from {url}...")
            urllib.request.urlretrieve(url, filepath)
            print(f" Downloaded to {filepath}")
        else:
            print(f" {filename} already exists.")
def run_gpqa_eval(model_name: str, base_url: str, reasoning_effort: str) -> float:
"""Run GPQA evaluation using the gpt-oss evaluation package."""
# Build the command to run the evaluation
@@ -33,7 +59,7 @@ def run_gpqa_eval(model_name: str, base_url: str) -> float:
"--model",
model_name,
"--reasoning-effort",
"low",
reasoning_effort,
"--base-url",
base_url,
"--n-threads",
@@ -41,16 +67,29 @@ def run_gpqa_eval(model_name: str, base_url: str) -> float:
]
try:
# Set up environment for the evaluation subprocess
# Inherit current environment and add required variables
eval_env = os.environ.copy()
eval_env["OPENAI_API_KEY"] = "dummy"
# Run the evaluation
result = subprocess.run(
cmd,
text=True,
capture_output=True,
timeout=1800, # 30 minute timeout
env={"OPENAI_API_KEY": "dummy"},
env=eval_env,
)
print("Evaluation process output:\n", result.stdout)
print("Evaluation process stdout:\n", result.stdout)
print("Evaluation process stderr:\n", result.stderr)
print(f"Evaluation process return code: {result.returncode}")
if result.returncode != 0:
raise RuntimeError(
f"Evaluation failed with exit code {result.returncode}:\n"
f"stdout: {result.stdout}\nstderr: {result.stderr}"
)
# Parse the output to extract the score
match = re.search(r"'metric':\s*([\d.]+)", result.stdout)
@@ -64,47 +103,62 @@ def run_gpqa_eval(model_name: str, base_url: str) -> float:
except subprocess.TimeoutExpired as e:
raise RuntimeError("Evaluation timed out") from e
except subprocess.CalledProcessError as e:
raise RuntimeError(
f"Evaluation failed with exit code {e.returncode}:\n"
f"stdout: {e.stdout}\nstderr: {e.stderr}"
) from e
def test_gpqa_correctness(request):
"""Test GPQA correctness for GPT-OSS model."""
def test_gpqa_correctness(config_filename):
"""Test GPQA correctness for a given model configuration."""
# Ensure tiktoken files are downloaded
ensure_tiktoken_files()
# Get command line arguments
model_name = request.config.getoption("--model")
expected_metric = request.config.getoption("--metric")
server_args_str = request.config.getoption("--server-args")
# Verify tiktoken files exist
for filename in TIKTOKEN_FILES:
filepath = TIKTOKEN_DATA_DIR / filename
assert filepath.exists(), f"Tiktoken file not found: {filepath}"
# Parse server arguments
server_args = []
if server_args_str:
server_args = server_args_str.split()
eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8"))
# Parse server arguments from config (use shlex to handle quoted strings)
server_args_str = eval_config.get("server_args", "")
server_args = shlex.split(server_args_str) if server_args_str else []
# Add standard server arguments
server_args.extend(
[
"--trust-remote-code",
"--enforce-eager",
"--disable-uvicorn-access-log",
]
)
print(f"Starting GPQA evaluation for model: {model_name}")
print(f"Expected metric threshold: {expected_metric}")
# Build server environment with tiktoken path and any config-specified vars
server_env = {"TIKTOKEN_ENCODINGS_BASE": str(TIKTOKEN_DATA_DIR)}
if eval_config.get("env"):
server_env.update(eval_config["env"])
reasoning_effort = eval_config.get("reasoning_effort", "low")
print(f"Starting GPQA evaluation for model: {eval_config['model_name']}")
print(f"Expected metric threshold: {eval_config['metric_threshold']}")
print(f"Reasoning effort: {reasoning_effort}")
print(f"Server args: {' '.join(server_args)}")
print(f"Server environment variables: {server_env}")
# Launch server and run evaluation
with RemoteOpenAIServer(
model_name, server_args, max_wait_seconds=1800
eval_config["model_name"],
server_args,
env_dict=server_env,
max_wait_seconds=eval_config.get("startup_max_wait_seconds", 1800),
) as remote_server:
base_url = remote_server.url_for("v1")
print(f"Server started at: {base_url}")
measured_metric = run_gpqa_eval(model_name, base_url)
measured_metric = run_gpqa_eval(
eval_config["model_name"], base_url, reasoning_effort
)
expected_metric = eval_config["metric_threshold"]
print(f"GPQA Results for {model_name}:")
print(f"GPQA Results for {eval_config['model_name']}:")
print(f" Measured metric: {measured_metric:.4f}")
print(f" Expected metric: {expected_metric:.4f}")
print(f" Tolerance: {TOL:.4f}")
@@ -115,4 +169,4 @@ def test_gpqa_correctness(request):
f"{expected_metric:.4f} - {TOL:.4f} = {expected_metric - TOL:.4f}"
)
print(f"GPQA test passed for {model_name}")
print(f"GPQA test passed for {eval_config['model_name']}")

View File

@@ -242,6 +242,10 @@ class FusedMoEQuantConfig:
def quant_dtype(self) -> torch.dtype | str | None:
return self._a1.dtype
@property
def weight_quant_dtype(self) -> torch.dtype | str | None:
    # Quantization dtype of the weights, taken from the w1 descriptor.
    # May be a string tag (e.g. "mxfp4") rather than a torch.dtype, or
    # None when the weights are unquantized.
    return self._w1.dtype
@property
def is_quantized(self) -> bool:
return self.quant_dtype is not None

View File

@@ -4,6 +4,7 @@
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
@@ -18,6 +19,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8Dynamic128Sym,
kFp8Static128BlockSym,
kFp8StaticTensorSym,
kMxfp4Static,
kMxfp8Dynamic,
kNvfp4Dynamic,
kNvfp4Static,
)
@@ -64,10 +67,18 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
quant_config: FusedMoEQuantConfig,
):
super().__init__(moe_config, quant_config)
assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), (
"Only nvfp4, fp8, bfloat16 and"
assert quant_config.weight_quant_dtype in (
"mxfp4",
"nvfp4",
torch.float8_e4m3fn,
None,
), (
"Only mxfp4, nvfp4, fp8, bfloat16 and"
" float16 quantization are currently supported."
)
self.device = moe_config.device
self.num_experts = moe_config.num_local_experts
self.ep_rank = moe_config.moe_parallel_config.ep_rank
self.ep_size = moe_config.moe_parallel_config.ep_size
self.tp_rank = moe_config.moe_parallel_config.tp_rank
@@ -78,6 +89,28 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
# - pass per-block weight scales to the kernel
# - skip input activation quantization (kernel applies scaling)
self.use_deepseek_fp8_block_scale = quant_config.is_block_quantized
self.max_capture_size = (
get_current_vllm_config().compilation_config.max_cudagraph_capture_size
)
if quant_config.weight_quant_dtype == "mxfp4":
# This value is used specifically for gpt-oss,
# Need to revisit this for other models
self.gemm1_alpha = torch.tensor(
[1.702] * self.num_experts, dtype=torch.float32, device=self.device
)
self.gemm1_beta = torch.tensor(
[1.0] * self.num_experts, dtype=torch.float32, device=self.device
)
self.gemm1_clamp_limit = torch.tensor(
[7.0] * self.num_experts, dtype=torch.float32, device=self.device
)
if quant_config.quant_dtype == "mxfp8":
self.fake_input_scale = torch.ones(
self.num_experts,
device=self.device,
dtype=torch.float32,
)
@property
def expects_unquantized_inputs(self) -> bool:
@@ -119,20 +152,33 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
]
and p.has_device_capability(90)
)
# fp8 block-scale on 9.0
# fp8 block-scale, wmxfp4a16 on 9.0
or (
scheme == (kFp8Static128BlockSym, kFp8Dynamic128Sym)
scheme
in [
(kMxfp4Static, None),
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
]
and p.is_device_capability(90)
)
# nvfp4 on 10.0+
# nvfp4, wmxfp4amxfp8 on 10.0+
or (
scheme == (kNvfp4Static, kNvfp4Dynamic) and p.has_device_capability(100)
scheme
in [
(kMxfp4Static, kMxfp8Dynamic),
(kNvfp4Static, kNvfp4Dynamic),
]
and p.has_device_capability(100)
)
)
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
return activation in [
MoEActivation.SILU,
MoEActivation.RELU2_NO_MUL,
MoEActivation.SWIGLUOAI,
]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
@@ -216,12 +262,23 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation_str_to_value_map = {
MoEActivation.SILU: ActivationType.Swiglu, # This is the default
MoEActivation.SWIGLUOAI: ActivationType.Swiglu, # gpt-oss alias
MoEActivation.RELU2_NO_MUL: ActivationType.Relu2,
}
assert activation in activation_str_to_value_map, (
f"{activation=} missing from {activation_str_to_value_map.keys()=}"
)
quant_scales = None
fc1_expert_weights = None
fc2_expert_weights = None
fc1_expert_biases = None
fc2_expert_biases = None
swiglu_alpha = None
swiglu_beta = None
swiglu_limit = None
use_mxfp8_act_scaling = False
use_w4_group_scaling = False
# Select quantization metadata based on FP8 format/path
if (
self.quant_dtype == torch.float8_e4m3fn
@@ -256,6 +313,43 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
# FlashInfer API requires weight to be long for nvfp4
fc1_expert_weights = w1.view(torch.long)
fc2_expert_weights = w2.view(torch.long)
elif self.weight_quant_dtype == "mxfp4":
assert self.w1_scale is not None and self.w2_scale is not None
assert w1.is_contiguous() and w2.is_contiguous()
assert self.gemm1_alpha is not None
assert self.gemm1_beta is not None
assert self.gemm1_clamp_limit is not None
assert topk_ids.is_contiguous()
fc1_expert_biases = self.w1_bias
fc2_expert_biases = self.w2_bias
swiglu_alpha = self.gemm1_alpha
swiglu_beta = self.gemm1_beta
swiglu_limit = self.gemm1_clamp_limit
if self.quant_dtype == "mxfp8":
assert self.fake_input_scale is not None
fc1_expert_weights = w1.view(torch.long)
fc2_expert_weights = w2.view(torch.long)
quant_scales = [
self.w1_scale.view(torch.int32),
self.fake_input_scale,
self.w2_scale.view(torch.int32),
self.fake_input_scale,
]
use_mxfp8_act_scaling = True
else:
assert hidden_states.dtype == torch.bfloat16
fc1_expert_weights = w1
fc2_expert_weights = w2
quant_scales = [
self.w1_scale,
self.w2_scale,
]
a1q_scale = None
use_w4_group_scaling = True
elif self.use_deepseek_fp8_block_scale:
# FP8 block-scale path: provide block-scale weights, omit a1q_scale
quant_scales = [
@@ -277,6 +371,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
token_final_scales=topk_weights,
fc1_expert_weights=fc1_expert_weights,
fc2_expert_weights=fc2_expert_weights,
fc1_expert_biases=fc1_expert_biases,
fc2_expert_biases=fc2_expert_biases,
swiglu_alpha=swiglu_alpha,
swiglu_beta=swiglu_beta,
swiglu_limit=swiglu_limit,
output=output,
output_dtype=self.out_dtype,
quant_scales=quant_scales,
input_sf=a1q_scale,
@@ -284,10 +384,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
tp_rank=self.tp_rank,
ep_size=self.ep_size,
ep_rank=self.ep_rank,
output=output,
activation_type=activation_str_to_value_map[activation],
# Informs FlashInfer to use the block-scale decoding path when True
use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale,
use_mxfp8_act_scaling=use_mxfp8_act_scaling,
use_w4_group_scaling=use_w4_group_scaling,
tune_max_num_tokens=max(self.max_capture_size, 1),
)
def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:

View File

@@ -564,9 +564,13 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
#
@property
def quant_dtype(self) -> torch.dtype | None:
def quant_dtype(self) -> torch.dtype | str | None:
return self.quant_config.quant_dtype
@property
def weight_quant_dtype(self) -> torch.dtype | str | None:
    # Delegates to the attached quant_config; may be a string tag such as
    # "mxfp4" instead of a torch.dtype, or None for unquantized weights.
    return self.quant_config.weight_quant_dtype
@property
def block_shape(self) -> list[int] | None:
return self.quant_config.block_shape

View File

@@ -25,15 +25,20 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
gemm1_alpha,
gemm1_beta,
gemm1_clamp_limit,
max_capture_size,
):
super().__init__(moe_config, quant_config)
self.gemm1_alpha = gemm1_alpha
self.gemm1_beta = gemm1_beta
self.gemm1_clamp_limit = gemm1_clamp_limit
self.device = torch.cuda.current_device()
self.num_experts = moe_config.num_local_experts
self.gemm1_alpha = torch.tensor(
[1.702] * self.num_experts, dtype=torch.float32, device=self.device
)
self.gemm1_beta = torch.tensor(
[1.0] * self.num_experts, dtype=torch.float32, device=self.device
)
self.gemm1_clamp_limit = torch.tensor(
[7.0] * self.num_experts, dtype=torch.float32, device=self.device
)
self.max_capture_size = max_capture_size
@staticmethod

View File

@@ -195,11 +195,12 @@ def _mxfp8_e4m3_quantize(
A_scale: torch.Tensor | None,
per_act_token_quant: bool,
block_shape: list[int] | None = None,
is_sf_swizzled_layout: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
assert A_scale is None
assert not per_act_token_quant
assert block_shape is None
return mxfp8_e4m3_quantize(A)
return mxfp8_e4m3_quantize(A, is_sf_swizzled_layout)
def _mxfp6_e3m2_quantize(
@@ -275,7 +276,13 @@ def moe_kernel_quantize_input(
elif quant_dtype == "mxfp8":
# TODO: `quant_dtype == "mxfp8"` is ambiguous,
# should be fp8_e4m3. OCP MX also defines `fp8_e5m2`.
return _mxfp8_e4m3_quantize(A, A_scale, per_act_token_quant, block_shape)
return _mxfp8_e4m3_quantize(
A,
A_scale,
per_act_token_quant,
block_shape,
is_sf_swizzled_layout=is_fp4_scale_swizzled,
)
elif quant_dtype == "mxfp6_e3m2":
return _mxfp6_e3m2_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == "mxfp6_e2m3":

View File

@@ -256,6 +256,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
"Please check your environment and try again."
)
self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
# Initialized in process_weights_after_loading for CUTLASS/SM90 backends
self.moe_mk: mk.FusedMoEModularKernel | None = None
def create_weights(
@@ -648,19 +649,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
):
layer.gemm1_alpha = Parameter(
torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False,
)
layer.gemm1_beta = Parameter(
torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False,
)
layer.gemm1_clamp_limit = Parameter(
torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False,
)
sf_block_size = 32 # mxfp4 block size
# Common shape assertions
@@ -772,6 +760,30 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer.w2_weight_scale = torch.nn.Parameter(
w2_scales_interleaved, requires_grad=False
)
# these two backends go through the `flashinfer_cutlass_fused_moe` path
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
assert self.moe_quant_config is not None
prepare_finalize = maybe_make_prepare_finalize(
moe=self.moe,
quant_config=self.moe_quant_config,
routing_tables=layer._maybe_init_expert_routing_tables(),
allow_new_interface=True,
)
assert prepare_finalize is not None
self.moe_mk = mk.FusedMoEModularKernel(
prepare_finalize,
FlashInferExperts(
moe_config=self.moe,
quant_config=self.moe_quant_config,
),
shared_experts=None,
)
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
@@ -847,7 +859,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
)
elif self.mxfp4_backend in [Mxfp4Backend.SM100_FI_MXFP4_BF16]:
elif self.mxfp4_backend in [
Mxfp4Backend.SM100_FI_MXFP4_BF16,
Mxfp4Backend.SM90_FI_MXFP4_BF16,
]:
return mxfp4_w4a16_moe_quant_config(
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
@@ -897,9 +912,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
):
# B200 code-path
kwargs = {
"gemm1_alpha": layer.gemm1_alpha,
"gemm1_beta": layer.gemm1_beta,
"gemm1_clamp_limit": layer.gemm1_clamp_limit,
# TODO(bnell): part of quant_config
"max_capture_size": self.max_capture_size,
}
@@ -935,20 +947,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
if layer.enable_eplb:
raise NotImplementedError("EPLB is not supported for mxfp4")
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
assert self.moe_mk is not None
return self.moe_mk(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
assert _can_support_mxfp4(
layer.use_grouped_topk,
layer.topk_group,
@@ -967,68 +965,22 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
assert (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
)
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
# Backend-specific preparation
if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
from flashinfer import mxfp8_quantize
x_quant, x_scale = mxfp8_quantize(x, True, 32)
fake_input_scale = torch.ones(self.num_experts, device=x.device)
quant_scales = [
layer.w13_weight_scale.contiguous().view(torch.int32),
fake_input_scale,
layer.w2_weight_scale.contiguous().view(torch.int32),
fake_input_scale,
]
fi_input = x_quant
extra_kwargs = dict(
use_mxfp8_act_scaling=True,
input_sf=x_scale,
fc1_expert_weights=layer.w13_weight.contiguous().view(torch.long),
fc2_expert_weights=layer.w2_weight.contiguous().view(torch.long),
)
elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:
assert x.dtype == torch.bfloat16
quant_scales = [
layer.w13_weight_scale,
layer.w2_weight_scale,
]
fi_input = x
extra_kwargs = dict(
use_w4_group_scaling=True,
fc1_expert_weights=layer.w13_weight,
fc2_expert_weights=layer.w2_weight,
)
output = torch.empty_like(x, dtype=torch.bfloat16)
flashinfer_cutlass_fused_moe(
input=fi_input,
token_selected_experts=topk_ids.to(torch.int).contiguous(),
token_final_scales=topk_weights,
output_dtype=torch.bfloat16,
output=output,
quant_scales=quant_scales,
fc1_expert_biases=layer.w13_bias,
fc2_expert_biases=layer.w2_bias,
swiglu_alpha=layer.gemm1_alpha,
swiglu_beta=layer.gemm1_beta,
swiglu_limit=layer.gemm1_clamp_limit,
tp_size=self.moe.tp_size,
tp_rank=self.moe.tp_rank,
ep_size=self.moe.ep_size,
ep_rank=self.moe.ep_rank,
tune_max_num_tokens=max(self.max_capture_size, 1),
**extra_kwargs,
or self.mxfp4_backend == Mxfp4Backend.MARLIN
)
return output
assert self.moe_mk is not None
return self.moe_mk(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=layer.expert_map,
shared_experts_input=shared_experts_input,
)
def apply_monolithic(
self,

View File

@@ -19,6 +19,7 @@ if TYPE_CHECKING:
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
MXFP_SCALE_DTYPE = torch.uint8
def get_fp8_min_max() -> tuple[float, float]:
@@ -151,6 +152,18 @@ kFp8Static128BlockSym = QuantKey(FP8_DTYPE, kStatic128BlockScale, symmetric=True
kDynamic64Scale = ScaleDesc(torch.float32, False, GroupShape(1, 64))
kFp8Dynamic64Sym = QuantKey(FP8_DTYPE, kDynamic64Scale, symmetric=True)
# TODO (zyongye): Convert all the torch.dtype to scale_dtype.
# Changing that also requires updating the torch.compile fused AR+Quant
# QuantKey to avoid an assertion error.
# OCP MX-format quant keys: per-group (1, 32) scales stored with
# MXFP_SCALE_DTYPE (uint8) — presumably E8M0 shared exponents; confirm
# against the MX spec.
kMxfp4DynamicGroupScale = ScaleDesc(MXFP_SCALE_DTYPE, False, GroupShape(1, 32))
kMxfp4Dynamic = QuantKey(FP4_DTYPE, scale=kMxfp4DynamicGroupScale, symmetric=True)
kMxfp8DynamicGroupScale = ScaleDesc(MXFP_SCALE_DTYPE, False, GroupShape(1, 32))
kMxfp8Dynamic = QuantKey(FP8_DTYPE, scale=kMxfp8DynamicGroupScale, symmetric=True)
# Static variant: scales are precomputed (static=True) rather than derived
# from the activations at runtime.
kMxfp4StaticGroupScale = ScaleDesc(MXFP_SCALE_DTYPE, True, GroupShape(1, 32))
kMxfp4Static = QuantKey(FP4_DTYPE, scale=kMxfp4StaticGroupScale, symmetric=True)
# Normalize the group_shape to the full extent for any dims that are -1
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):