[MoE Refactor] MXFP4 Cutlass Experts to MK (#34542)
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
This commit is contained in:
@@ -73,3 +73,29 @@ steps:
|
||||
num_devices: 2
|
||||
commands:
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=evals/gsm8k/configs/moe-refactor-dp-ep/config-b200.txt
|
||||
|
||||
- label: GPQA Eval (GPT-OSS) (H100)
|
||||
timeout_in_minutes: 120
|
||||
device: h100
|
||||
optional: true
|
||||
num_devices: 2
|
||||
source_file_dependencies:
|
||||
- csrc/
|
||||
- vllm/model_executor/layers/quantization
|
||||
- tests/evals/gpt_oss/
|
||||
commands:
|
||||
- uv pip install --system 'gpt-oss[eval]==0.0.5'
|
||||
- pytest -s -v evals/gpt_oss/test_gpqa_correctness.py --config-list-file=configs/models-h100.txt
|
||||
|
||||
- label: GPQA Eval (GPT-OSS) (B200)
|
||||
timeout_in_minutes: 120
|
||||
device: b200
|
||||
optional: true
|
||||
num_devices: 2
|
||||
source_file_dependencies:
|
||||
- csrc/
|
||||
- vllm/model_executor/layers/quantization
|
||||
- tests/evals/gpt_oss/
|
||||
commands:
|
||||
- uv pip install --system 'gpt-oss[eval]==0.0.5'
|
||||
- pytest -s -v evals/gpt_oss/test_gpqa_correctness.py --config-list-file=configs/models-b200.txt
|
||||
|
||||
@@ -153,33 +153,6 @@ steps:
|
||||
- pytest -v -s transformers_utils
|
||||
- pytest -v -s config
|
||||
|
||||
- label: GPT-OSS Eval (H100)
|
||||
timeout_in_minutes: 60
|
||||
working_dir: "/vllm-workspace/"
|
||||
device: h100
|
||||
optional: true
|
||||
source_file_dependencies:
|
||||
- tests/evals/gpt_oss
|
||||
- vllm/model_executor/models/gpt_oss.py
|
||||
- vllm/model_executor/layers/quantization/mxfp4.py
|
||||
commands:
|
||||
- uv pip install --system 'gpt-oss[eval]==0.0.5'
|
||||
- pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58
|
||||
|
||||
- label: GPT-OSS Eval (B200)
|
||||
timeout_in_minutes: 60
|
||||
working_dir: "/vllm-workspace/"
|
||||
device: b200
|
||||
optional: true
|
||||
source_file_dependencies:
|
||||
- tests/evals/gpt_oss
|
||||
- vllm/model_executor/models/gpt_oss.py
|
||||
- vllm/model_executor/layers/quantization/mxfp4.py
|
||||
- vllm/v1/attention/backends/flashinfer.py
|
||||
commands:
|
||||
- uv pip install --system 'gpt-oss[eval]==0.0.5'
|
||||
- pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58
|
||||
|
||||
- label: Batch Invariance (H100)
|
||||
timeout_in_minutes: 25
|
||||
device: h100
|
||||
|
||||
49
tests/evals/gpt_oss/README.md
Normal file
49
tests/evals/gpt_oss/README.md
Normal file
@@ -0,0 +1,49 @@
|
||||
# GPQA Evaluation using GPT-OSS
|
||||
|
||||
This directory contains GPQA evaluation tests using the GPT-OSS evaluation package and vLLM server.
|
||||
|
||||
## Usage
|
||||
|
||||
### Run tests with pytest (like buildkite)
|
||||
|
||||
```bash
|
||||
# H200
|
||||
pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py \
|
||||
--config-list-file=configs/models-h200.txt
|
||||
|
||||
# B200
|
||||
pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py \
|
||||
--config-list-file=configs/models-b200.txt
|
||||
```
|
||||
|
||||
## Configuration Format
|
||||
|
||||
Model configs in `configs/` directory use this YAML format:
|
||||
|
||||
```yaml
|
||||
model_name: "openai/gpt-oss-20b"
|
||||
metric_threshold: 0.568 # Minimum expected accuracy
|
||||
reasoning_effort: "low" # Reasoning effort level (default: "low")
|
||||
server_args: "--tensor-parallel-size 2" # Server arguments
|
||||
startup_max_wait_seconds: 1800 # Max wait for server startup (default: 1800)
|
||||
env: # Environment variables (optional)
|
||||
SOME_VAR: "value"
|
||||
```
|
||||
|
||||
The `server_args` field accepts any arguments that can be passed to `vllm serve`.
|
||||
|
||||
The `env` field accepts a dictionary of environment variables to set for the server process.
|
||||
|
||||
## Adding New Models
|
||||
|
||||
1. Create a new YAML config file in the `configs/` directory
|
||||
2. Add the filename to the appropriate `models-*.txt` file
|
||||
|
||||
## Tiktoken Encoding Files
|
||||
|
||||
The tiktoken encoding files required by the vLLM server are automatically downloaded from OpenAI's public blob storage on first run:
|
||||
|
||||
- `cl100k_base.tiktoken`
|
||||
- `o200k_base.tiktoken`
|
||||
|
||||
Files are cached in the `data/` directory. The `TIKTOKEN_ENCODINGS_BASE` environment variable is automatically set to point to this directory when running evaluations.
|
||||
6
tests/evals/gpt_oss/configs/gpt-oss-20b-baseline.yaml
Normal file
6
tests/evals/gpt_oss/configs/gpt-oss-20b-baseline.yaml
Normal file
@@ -0,0 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
model_name: "openai/gpt-oss-20b"
|
||||
metric_threshold: 0.568
|
||||
reasoning_effort: "low"
|
||||
server_args: "--tensor-parallel-size 2"
|
||||
@@ -0,0 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
model_name: "openai/gpt-oss-20b"
|
||||
metric_threshold: 0.568
|
||||
reasoning_effort: "low"
|
||||
server_args: "--tensor-parallel-size 2"
|
||||
env:
|
||||
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: "1"
|
||||
@@ -0,0 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
model_name: "openai/gpt-oss-20b"
|
||||
metric_threshold: 0.568
|
||||
reasoning_effort: "low"
|
||||
server_args: "--tensor-parallel-size 2"
|
||||
env:
|
||||
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: "1"
|
||||
8
tests/evals/gpt_oss/configs/gpt-oss-20b-marlin.yaml
Normal file
8
tests/evals/gpt_oss/configs/gpt-oss-20b-marlin.yaml
Normal file
@@ -0,0 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
model_name: "openai/gpt-oss-20b"
|
||||
metric_threshold: 0.568
|
||||
reasoning_effort: "low"
|
||||
server_args: "--tensor-parallel-size 2"
|
||||
env:
|
||||
VLLM_MXFP4_USE_MARLIN: "1"
|
||||
@@ -0,0 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
model_name: "openai/gpt-oss-20b"
|
||||
metric_threshold: 0.568
|
||||
reasoning_effort: "low"
|
||||
server_args: "--tensor-parallel-size 2"
|
||||
env:
|
||||
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: "1"
|
||||
5
tests/evals/gpt_oss/configs/models-b200.txt
Normal file
5
tests/evals/gpt_oss/configs/models-b200.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
# B200 model configurations for GPQA evaluation
|
||||
# Tests different environment variable combinations
|
||||
gpt-oss-20b-flashinfer-mxfp4-bf16.yaml
|
||||
gpt-oss-20b-flashinfer-mxfp4-mxfp8-cutlass.yaml
|
||||
gpt-oss-20b-sm100-fi-mxfp4-mxfp8-trtllm.yaml
|
||||
5
tests/evals/gpt_oss/configs/models-h100.txt
Normal file
5
tests/evals/gpt_oss/configs/models-h100.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
# H100 model configurations for GPQA evaluation
|
||||
# Tests different environment variable combinations
|
||||
gpt-oss-20b-baseline.yaml
|
||||
gpt-oss-20b-flashinfer-mxfp4-bf16.yaml
|
||||
gpt-oss-20b-marlin.yaml
|
||||
@@ -4,13 +4,61 @@
|
||||
Pytest configuration for GPT-OSS evaluation tests.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
"""Add command line options for pytest."""
|
||||
parser.addoption("--model", action="store", help="Model name to evaluate")
|
||||
"""Add custom command line options."""
|
||||
parser.addoption(
|
||||
"--metric", action="store", type=float, help="Expected metric threshold"
|
||||
)
|
||||
parser.addoption(
|
||||
"--server-args", action="store", default="", help="Additional server arguments"
|
||||
"--config-list-file",
|
||||
required=True,
|
||||
help="File containing list of config files to test",
|
||||
)
|
||||
|
||||
|
||||
def pytest_generate_tests(metafunc):
|
||||
"""Generate test parameters from config files."""
|
||||
if "config_filename" in metafunc.fixturenames:
|
||||
config_list_file = metafunc.config.getoption("--config-list-file")
|
||||
|
||||
# Handle both relative and absolute paths
|
||||
config_list_path = Path(config_list_file)
|
||||
if not config_list_path.is_absolute():
|
||||
# If relative, try relative to test directory first
|
||||
test_dir_path = Path(__file__).parent / config_list_file
|
||||
if test_dir_path.exists():
|
||||
config_list_path = test_dir_path
|
||||
else:
|
||||
# Try relative to current working directory
|
||||
config_list_path = Path.cwd() / config_list_file
|
||||
|
||||
print(f"Looking for config list at: {config_list_path}")
|
||||
|
||||
config_files = []
|
||||
if config_list_path.exists():
|
||||
# Determine config directory (same directory as the list file)
|
||||
config_dir = config_list_path.parent
|
||||
|
||||
with open(config_list_path) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#"):
|
||||
config_path = config_dir / line
|
||||
print(f"Checking config file: {config_path}")
|
||||
if config_path.exists():
|
||||
config_files.append(config_path)
|
||||
print(f" Found: {config_path}")
|
||||
else:
|
||||
print(f" Missing: {config_path}")
|
||||
else:
|
||||
print(f"Config list file not found: {config_list_path}")
|
||||
|
||||
# Generate test parameters
|
||||
if config_files:
|
||||
metafunc.parametrize(
|
||||
"config_filename",
|
||||
config_files,
|
||||
ids=[config_file.stem for config_file in config_files],
|
||||
)
|
||||
else:
|
||||
print("No config files found, test will be skipped")
|
||||
|
||||
@@ -5,22 +5,48 @@ GPQA evaluation using vLLM server and GPT-OSS evaluation package.
|
||||
|
||||
Usage:
|
||||
pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py \
|
||||
--model openai/gpt-oss-20b \
|
||||
--metric 0.58 \
|
||||
--server-args "--tensor-parallel-size 2"
|
||||
--config-list-file=configs/models-h200.txt
|
||||
"""
|
||||
|
||||
import os
|
||||
import shlex
|
||||
import subprocess
|
||||
import sys
|
||||
import urllib.request
|
||||
from pathlib import Path
|
||||
|
||||
import regex as re
|
||||
import yaml
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
TOL = 0.05 # Absolute tolerance for accuracy comparison
|
||||
|
||||
# Path to tiktoken encoding files
|
||||
TIKTOKEN_DATA_DIR = Path(__file__).parent / "data"
|
||||
|
||||
def run_gpqa_eval(model_name: str, base_url: str) -> float:
|
||||
# Tiktoken encoding files to download
|
||||
TIKTOKEN_FILES = {
|
||||
"cl100k_base.tiktoken": "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken",
|
||||
"o200k_base.tiktoken": "https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken",
|
||||
}
|
||||
|
||||
|
||||
def ensure_tiktoken_files():
|
||||
"""Download tiktoken encoding files if they don't exist."""
|
||||
TIKTOKEN_DATA_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for filename, url in TIKTOKEN_FILES.items():
|
||||
filepath = TIKTOKEN_DATA_DIR / filename
|
||||
if not filepath.exists():
|
||||
print(f"Downloading {filename} from {url}...")
|
||||
urllib.request.urlretrieve(url, filepath)
|
||||
print(f" Downloaded to {filepath}")
|
||||
else:
|
||||
print(f" {filename} already exists.")
|
||||
|
||||
|
||||
def run_gpqa_eval(model_name: str, base_url: str, reasoning_effort: str) -> float:
|
||||
"""Run GPQA evaluation using the gpt-oss evaluation package."""
|
||||
|
||||
# Build the command to run the evaluation
|
||||
@@ -33,7 +59,7 @@ def run_gpqa_eval(model_name: str, base_url: str) -> float:
|
||||
"--model",
|
||||
model_name,
|
||||
"--reasoning-effort",
|
||||
"low",
|
||||
reasoning_effort,
|
||||
"--base-url",
|
||||
base_url,
|
||||
"--n-threads",
|
||||
@@ -41,16 +67,29 @@ def run_gpqa_eval(model_name: str, base_url: str) -> float:
|
||||
]
|
||||
|
||||
try:
|
||||
# Set up environment for the evaluation subprocess
|
||||
# Inherit current environment and add required variables
|
||||
eval_env = os.environ.copy()
|
||||
eval_env["OPENAI_API_KEY"] = "dummy"
|
||||
|
||||
# Run the evaluation
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
text=True,
|
||||
capture_output=True,
|
||||
timeout=1800, # 30 minute timeout
|
||||
env={"OPENAI_API_KEY": "dummy"},
|
||||
env=eval_env,
|
||||
)
|
||||
|
||||
print("Evaluation process output:\n", result.stdout)
|
||||
print("Evaluation process stdout:\n", result.stdout)
|
||||
print("Evaluation process stderr:\n", result.stderr)
|
||||
print(f"Evaluation process return code: {result.returncode}")
|
||||
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(
|
||||
f"Evaluation failed with exit code {result.returncode}:\n"
|
||||
f"stdout: {result.stdout}\nstderr: {result.stderr}"
|
||||
)
|
||||
|
||||
# Parse the output to extract the score
|
||||
match = re.search(r"'metric':\s*([\d.]+)", result.stdout)
|
||||
@@ -64,47 +103,62 @@ def run_gpqa_eval(model_name: str, base_url: str) -> float:
|
||||
|
||||
except subprocess.TimeoutExpired as e:
|
||||
raise RuntimeError("Evaluation timed out") from e
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise RuntimeError(
|
||||
f"Evaluation failed with exit code {e.returncode}:\n"
|
||||
f"stdout: {e.stdout}\nstderr: {e.stderr}"
|
||||
) from e
|
||||
|
||||
|
||||
def test_gpqa_correctness(request):
|
||||
"""Test GPQA correctness for GPT-OSS model."""
|
||||
def test_gpqa_correctness(config_filename):
|
||||
"""Test GPQA correctness for a given model configuration."""
|
||||
# Ensure tiktoken files are downloaded
|
||||
ensure_tiktoken_files()
|
||||
|
||||
# Get command line arguments
|
||||
model_name = request.config.getoption("--model")
|
||||
expected_metric = request.config.getoption("--metric")
|
||||
server_args_str = request.config.getoption("--server-args")
|
||||
# Verify tiktoken files exist
|
||||
for filename in TIKTOKEN_FILES:
|
||||
filepath = TIKTOKEN_DATA_DIR / filename
|
||||
assert filepath.exists(), f"Tiktoken file not found: {filepath}"
|
||||
|
||||
# Parse server arguments
|
||||
server_args = []
|
||||
if server_args_str:
|
||||
server_args = server_args_str.split()
|
||||
eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8"))
|
||||
|
||||
# Parse server arguments from config (use shlex to handle quoted strings)
|
||||
server_args_str = eval_config.get("server_args", "")
|
||||
server_args = shlex.split(server_args_str) if server_args_str else []
|
||||
|
||||
# Add standard server arguments
|
||||
server_args.extend(
|
||||
[
|
||||
"--trust-remote-code",
|
||||
"--enforce-eager",
|
||||
"--disable-uvicorn-access-log",
|
||||
]
|
||||
)
|
||||
|
||||
print(f"Starting GPQA evaluation for model: {model_name}")
|
||||
print(f"Expected metric threshold: {expected_metric}")
|
||||
# Build server environment with tiktoken path and any config-specified vars
|
||||
server_env = {"TIKTOKEN_ENCODINGS_BASE": str(TIKTOKEN_DATA_DIR)}
|
||||
if eval_config.get("env"):
|
||||
server_env.update(eval_config["env"])
|
||||
|
||||
reasoning_effort = eval_config.get("reasoning_effort", "low")
|
||||
|
||||
print(f"Starting GPQA evaluation for model: {eval_config['model_name']}")
|
||||
print(f"Expected metric threshold: {eval_config['metric_threshold']}")
|
||||
print(f"Reasoning effort: {reasoning_effort}")
|
||||
print(f"Server args: {' '.join(server_args)}")
|
||||
print(f"Server environment variables: {server_env}")
|
||||
|
||||
# Launch server and run evaluation
|
||||
with RemoteOpenAIServer(
|
||||
model_name, server_args, max_wait_seconds=1800
|
||||
eval_config["model_name"],
|
||||
server_args,
|
||||
env_dict=server_env,
|
||||
max_wait_seconds=eval_config.get("startup_max_wait_seconds", 1800),
|
||||
) as remote_server:
|
||||
base_url = remote_server.url_for("v1")
|
||||
print(f"Server started at: {base_url}")
|
||||
|
||||
measured_metric = run_gpqa_eval(model_name, base_url)
|
||||
measured_metric = run_gpqa_eval(
|
||||
eval_config["model_name"], base_url, reasoning_effort
|
||||
)
|
||||
expected_metric = eval_config["metric_threshold"]
|
||||
|
||||
print(f"GPQA Results for {model_name}:")
|
||||
print(f"GPQA Results for {eval_config['model_name']}:")
|
||||
print(f" Measured metric: {measured_metric:.4f}")
|
||||
print(f" Expected metric: {expected_metric:.4f}")
|
||||
print(f" Tolerance: {TOL:.4f}")
|
||||
@@ -115,4 +169,4 @@ def test_gpqa_correctness(request):
|
||||
f"{expected_metric:.4f} - {TOL:.4f} = {expected_metric - TOL:.4f}"
|
||||
)
|
||||
|
||||
print(f"✅ GPQA test passed for {model_name}")
|
||||
print(f"GPQA test passed for {eval_config['model_name']}")
|
||||
|
||||
@@ -242,6 +242,10 @@ class FusedMoEQuantConfig:
|
||||
def quant_dtype(self) -> torch.dtype | str | None:
|
||||
return self._a1.dtype
|
||||
|
||||
@property
|
||||
def weight_quant_dtype(self) -> torch.dtype | str | None:
|
||||
return self._w1.dtype
|
||||
|
||||
@property
|
||||
def is_quantized(self) -> bool:
|
||||
return self.quant_dtype is not None
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
@@ -18,6 +19,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8Static128BlockSym,
|
||||
kFp8StaticTensorSym,
|
||||
kMxfp4Static,
|
||||
kMxfp8Dynamic,
|
||||
kNvfp4Dynamic,
|
||||
kNvfp4Static,
|
||||
)
|
||||
@@ -64,10 +67,18 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
):
|
||||
super().__init__(moe_config, quant_config)
|
||||
assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), (
|
||||
"Only nvfp4, fp8, bfloat16 and"
|
||||
|
||||
assert quant_config.weight_quant_dtype in (
|
||||
"mxfp4",
|
||||
"nvfp4",
|
||||
torch.float8_e4m3fn,
|
||||
None,
|
||||
), (
|
||||
"Only mxfp4, nvfp4, fp8, bfloat16 and"
|
||||
" float16 quantization are currently supported."
|
||||
)
|
||||
self.device = moe_config.device
|
||||
self.num_experts = moe_config.num_local_experts
|
||||
self.ep_rank = moe_config.moe_parallel_config.ep_rank
|
||||
self.ep_size = moe_config.moe_parallel_config.ep_size
|
||||
self.tp_rank = moe_config.moe_parallel_config.tp_rank
|
||||
@@ -78,6 +89,28 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
# - pass per-block weight scales to the kernel
|
||||
# - skip input activation quantization (kernel applies scaling)
|
||||
self.use_deepseek_fp8_block_scale = quant_config.is_block_quantized
|
||||
self.max_capture_size = (
|
||||
get_current_vllm_config().compilation_config.max_cudagraph_capture_size
|
||||
)
|
||||
|
||||
if quant_config.weight_quant_dtype == "mxfp4":
|
||||
# This value is used specifically for gpt-oss,
|
||||
# Need to revisit this for other models
|
||||
self.gemm1_alpha = torch.tensor(
|
||||
[1.702] * self.num_experts, dtype=torch.float32, device=self.device
|
||||
)
|
||||
self.gemm1_beta = torch.tensor(
|
||||
[1.0] * self.num_experts, dtype=torch.float32, device=self.device
|
||||
)
|
||||
self.gemm1_clamp_limit = torch.tensor(
|
||||
[7.0] * self.num_experts, dtype=torch.float32, device=self.device
|
||||
)
|
||||
if quant_config.quant_dtype == "mxfp8":
|
||||
self.fake_input_scale = torch.ones(
|
||||
self.num_experts,
|
||||
device=self.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
@property
|
||||
def expects_unquantized_inputs(self) -> bool:
|
||||
@@ -119,20 +152,33 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
]
|
||||
and p.has_device_capability(90)
|
||||
)
|
||||
# fp8 block-scale on 9.0
|
||||
# fp8 block-scale, wmxfp4a16 on 9.0
|
||||
or (
|
||||
scheme == (kFp8Static128BlockSym, kFp8Dynamic128Sym)
|
||||
scheme
|
||||
in [
|
||||
(kMxfp4Static, None),
|
||||
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
|
||||
]
|
||||
and p.is_device_capability(90)
|
||||
)
|
||||
# nvfp4 on 10.0+
|
||||
# nvfp4, wmxfp4amxfp8 on 10.0+
|
||||
or (
|
||||
scheme == (kNvfp4Static, kNvfp4Dynamic) and p.has_device_capability(100)
|
||||
scheme
|
||||
in [
|
||||
(kMxfp4Static, kMxfp8Dynamic),
|
||||
(kNvfp4Static, kNvfp4Dynamic),
|
||||
]
|
||||
and p.has_device_capability(100)
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||
return activation in [
|
||||
MoEActivation.SILU,
|
||||
MoEActivation.RELU2_NO_MUL,
|
||||
MoEActivation.SWIGLUOAI,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
@@ -216,12 +262,23 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
activation_str_to_value_map = {
|
||||
MoEActivation.SILU: ActivationType.Swiglu, # This is the default
|
||||
MoEActivation.SWIGLUOAI: ActivationType.Swiglu, # gpt-oss alias
|
||||
MoEActivation.RELU2_NO_MUL: ActivationType.Relu2,
|
||||
}
|
||||
assert activation in activation_str_to_value_map, (
|
||||
f"{activation=} missing from {activation_str_to_value_map.keys()=}"
|
||||
)
|
||||
|
||||
quant_scales = None
|
||||
fc1_expert_weights = None
|
||||
fc2_expert_weights = None
|
||||
fc1_expert_biases = None
|
||||
fc2_expert_biases = None
|
||||
swiglu_alpha = None
|
||||
swiglu_beta = None
|
||||
swiglu_limit = None
|
||||
use_mxfp8_act_scaling = False
|
||||
use_w4_group_scaling = False
|
||||
# Select quantization metadata based on FP8 format/path
|
||||
if (
|
||||
self.quant_dtype == torch.float8_e4m3fn
|
||||
@@ -256,6 +313,43 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
# FlashInfer API requires weight to be long for nvfp4
|
||||
fc1_expert_weights = w1.view(torch.long)
|
||||
fc2_expert_weights = w2.view(torch.long)
|
||||
elif self.weight_quant_dtype == "mxfp4":
|
||||
assert self.w1_scale is not None and self.w2_scale is not None
|
||||
assert w1.is_contiguous() and w2.is_contiguous()
|
||||
assert self.gemm1_alpha is not None
|
||||
assert self.gemm1_beta is not None
|
||||
assert self.gemm1_clamp_limit is not None
|
||||
assert topk_ids.is_contiguous()
|
||||
|
||||
fc1_expert_biases = self.w1_bias
|
||||
fc2_expert_biases = self.w2_bias
|
||||
swiglu_alpha = self.gemm1_alpha
|
||||
swiglu_beta = self.gemm1_beta
|
||||
swiglu_limit = self.gemm1_clamp_limit
|
||||
|
||||
if self.quant_dtype == "mxfp8":
|
||||
assert self.fake_input_scale is not None
|
||||
fc1_expert_weights = w1.view(torch.long)
|
||||
fc2_expert_weights = w2.view(torch.long)
|
||||
|
||||
quant_scales = [
|
||||
self.w1_scale.view(torch.int32),
|
||||
self.fake_input_scale,
|
||||
self.w2_scale.view(torch.int32),
|
||||
self.fake_input_scale,
|
||||
]
|
||||
use_mxfp8_act_scaling = True
|
||||
else:
|
||||
assert hidden_states.dtype == torch.bfloat16
|
||||
fc1_expert_weights = w1
|
||||
fc2_expert_weights = w2
|
||||
quant_scales = [
|
||||
self.w1_scale,
|
||||
self.w2_scale,
|
||||
]
|
||||
a1q_scale = None
|
||||
use_w4_group_scaling = True
|
||||
|
||||
elif self.use_deepseek_fp8_block_scale:
|
||||
# FP8 block-scale path: provide block-scale weights, omit a1q_scale
|
||||
quant_scales = [
|
||||
@@ -277,6 +371,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
token_final_scales=topk_weights,
|
||||
fc1_expert_weights=fc1_expert_weights,
|
||||
fc2_expert_weights=fc2_expert_weights,
|
||||
fc1_expert_biases=fc1_expert_biases,
|
||||
fc2_expert_biases=fc2_expert_biases,
|
||||
swiglu_alpha=swiglu_alpha,
|
||||
swiglu_beta=swiglu_beta,
|
||||
swiglu_limit=swiglu_limit,
|
||||
output=output,
|
||||
output_dtype=self.out_dtype,
|
||||
quant_scales=quant_scales,
|
||||
input_sf=a1q_scale,
|
||||
@@ -284,10 +384,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
tp_rank=self.tp_rank,
|
||||
ep_size=self.ep_size,
|
||||
ep_rank=self.ep_rank,
|
||||
output=output,
|
||||
activation_type=activation_str_to_value_map[activation],
|
||||
# Informs FlashInfer to use the block-scale decoding path when True
|
||||
use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale,
|
||||
use_mxfp8_act_scaling=use_mxfp8_act_scaling,
|
||||
use_w4_group_scaling=use_w4_group_scaling,
|
||||
tune_max_num_tokens=max(self.max_capture_size, 1),
|
||||
)
|
||||
|
||||
def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
|
||||
|
||||
@@ -564,9 +564,13 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
#
|
||||
|
||||
@property
|
||||
def quant_dtype(self) -> torch.dtype | None:
|
||||
def quant_dtype(self) -> torch.dtype | str | None:
|
||||
return self.quant_config.quant_dtype
|
||||
|
||||
@property
|
||||
def weight_quant_dtype(self) -> torch.dtype | str | None:
|
||||
return self.quant_config.weight_quant_dtype
|
||||
|
||||
@property
|
||||
def block_shape(self) -> list[int] | None:
|
||||
return self.quant_config.block_shape
|
||||
|
||||
@@ -25,15 +25,20 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
gemm1_alpha,
|
||||
gemm1_beta,
|
||||
gemm1_clamp_limit,
|
||||
max_capture_size,
|
||||
):
|
||||
super().__init__(moe_config, quant_config)
|
||||
self.gemm1_alpha = gemm1_alpha
|
||||
self.gemm1_beta = gemm1_beta
|
||||
self.gemm1_clamp_limit = gemm1_clamp_limit
|
||||
self.device = torch.cuda.current_device()
|
||||
self.num_experts = moe_config.num_local_experts
|
||||
self.gemm1_alpha = torch.tensor(
|
||||
[1.702] * self.num_experts, dtype=torch.float32, device=self.device
|
||||
)
|
||||
self.gemm1_beta = torch.tensor(
|
||||
[1.0] * self.num_experts, dtype=torch.float32, device=self.device
|
||||
)
|
||||
self.gemm1_clamp_limit = torch.tensor(
|
||||
[7.0] * self.num_experts, dtype=torch.float32, device=self.device
|
||||
)
|
||||
self.max_capture_size = max_capture_size
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -195,11 +195,12 @@ def _mxfp8_e4m3_quantize(
|
||||
A_scale: torch.Tensor | None,
|
||||
per_act_token_quant: bool,
|
||||
block_shape: list[int] | None = None,
|
||||
is_sf_swizzled_layout: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert A_scale is None
|
||||
assert not per_act_token_quant
|
||||
assert block_shape is None
|
||||
return mxfp8_e4m3_quantize(A)
|
||||
return mxfp8_e4m3_quantize(A, is_sf_swizzled_layout)
|
||||
|
||||
|
||||
def _mxfp6_e3m2_quantize(
|
||||
@@ -275,7 +276,13 @@ def moe_kernel_quantize_input(
|
||||
elif quant_dtype == "mxfp8":
|
||||
# TODO: `quant_dtype == "mxfp8"` is ambiguous,
|
||||
# should be fp8_e4m3. OCP MX also defines `fp8_e5m2`.
|
||||
return _mxfp8_e4m3_quantize(A, A_scale, per_act_token_quant, block_shape)
|
||||
return _mxfp8_e4m3_quantize(
|
||||
A,
|
||||
A_scale,
|
||||
per_act_token_quant,
|
||||
block_shape,
|
||||
is_sf_swizzled_layout=is_fp4_scale_swizzled,
|
||||
)
|
||||
elif quant_dtype == "mxfp6_e3m2":
|
||||
return _mxfp6_e3m2_quantize(A, A_scale, per_act_token_quant, block_shape)
|
||||
elif quant_dtype == "mxfp6_e2m3":
|
||||
|
||||
@@ -256,6 +256,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
"Please check your environment and try again."
|
||||
)
|
||||
self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
|
||||
# Initialized in process_weights_after_loading for CUTLASS/SM90 backends
|
||||
self.moe_mk: mk.FusedMoEModularKernel | None = None
|
||||
|
||||
def create_weights(
|
||||
@@ -648,19 +649,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
|
||||
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
|
||||
):
|
||||
layer.gemm1_alpha = Parameter(
|
||||
torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.gemm1_beta = Parameter(
|
||||
torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.gemm1_clamp_limit = Parameter(
|
||||
torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
sf_block_size = 32 # mxfp4 block size
|
||||
|
||||
# Common shape assertions
|
||||
@@ -772,6 +760,30 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_weight_scale = torch.nn.Parameter(
|
||||
w2_scales_interleaved, requires_grad=False
|
||||
)
|
||||
|
||||
# theses two kernels go through the `flashinfer_cutlass_fused_moe` path
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
FlashInferExperts,
|
||||
)
|
||||
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
assert self.moe_quant_config is not None
|
||||
prepare_finalize = maybe_make_prepare_finalize(
|
||||
moe=self.moe,
|
||||
quant_config=self.moe_quant_config,
|
||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||
allow_new_interface=True,
|
||||
)
|
||||
assert prepare_finalize is not None
|
||||
|
||||
self.moe_mk = mk.FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
FlashInferExperts(
|
||||
moe_config=self.moe,
|
||||
quant_config=self.moe_quant_config,
|
||||
),
|
||||
shared_experts=None,
|
||||
)
|
||||
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
|
||||
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
|
||||
|
||||
@@ -847,7 +859,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
)
|
||||
elif self.mxfp4_backend in [Mxfp4Backend.SM100_FI_MXFP4_BF16]:
|
||||
elif self.mxfp4_backend in [
|
||||
Mxfp4Backend.SM100_FI_MXFP4_BF16,
|
||||
Mxfp4Backend.SM90_FI_MXFP4_BF16,
|
||||
]:
|
||||
return mxfp4_w4a16_moe_quant_config(
|
||||
w1_bias=layer.w13_bias,
|
||||
w2_bias=layer.w2_bias,
|
||||
@@ -897,9 +912,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
):
|
||||
# B200 code-path
|
||||
kwargs = {
|
||||
"gemm1_alpha": layer.gemm1_alpha,
|
||||
"gemm1_beta": layer.gemm1_beta,
|
||||
"gemm1_clamp_limit": layer.gemm1_clamp_limit,
|
||||
# TODO(bnell): part of quant_config
|
||||
"max_capture_size": self.max_capture_size,
|
||||
}
|
||||
@@ -935,20 +947,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
if layer.enable_eplb:
|
||||
raise NotImplementedError("EPLB is not supported for mxfp4")
|
||||
|
||||
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
|
||||
assert self.moe_mk is not None
|
||||
|
||||
return self.moe_mk(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
assert _can_support_mxfp4(
|
||||
layer.use_grouped_topk,
|
||||
layer.topk_group,
|
||||
@@ -967,68 +965,22 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
assert (
|
||||
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
|
||||
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
|
||||
)
|
||||
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
|
||||
|
||||
# Backend-specific preparation
|
||||
if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
|
||||
from flashinfer import mxfp8_quantize
|
||||
|
||||
x_quant, x_scale = mxfp8_quantize(x, True, 32)
|
||||
|
||||
fake_input_scale = torch.ones(self.num_experts, device=x.device)
|
||||
quant_scales = [
|
||||
layer.w13_weight_scale.contiguous().view(torch.int32),
|
||||
fake_input_scale,
|
||||
layer.w2_weight_scale.contiguous().view(torch.int32),
|
||||
fake_input_scale,
|
||||
]
|
||||
|
||||
fi_input = x_quant
|
||||
extra_kwargs = dict(
|
||||
use_mxfp8_act_scaling=True,
|
||||
input_sf=x_scale,
|
||||
fc1_expert_weights=layer.w13_weight.contiguous().view(torch.long),
|
||||
fc2_expert_weights=layer.w2_weight.contiguous().view(torch.long),
|
||||
)
|
||||
elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:
|
||||
assert x.dtype == torch.bfloat16
|
||||
|
||||
quant_scales = [
|
||||
layer.w13_weight_scale,
|
||||
layer.w2_weight_scale,
|
||||
]
|
||||
|
||||
fi_input = x
|
||||
extra_kwargs = dict(
|
||||
use_w4_group_scaling=True,
|
||||
fc1_expert_weights=layer.w13_weight,
|
||||
fc2_expert_weights=layer.w2_weight,
|
||||
)
|
||||
|
||||
output = torch.empty_like(x, dtype=torch.bfloat16)
|
||||
|
||||
flashinfer_cutlass_fused_moe(
|
||||
input=fi_input,
|
||||
token_selected_experts=topk_ids.to(torch.int).contiguous(),
|
||||
token_final_scales=topk_weights,
|
||||
output_dtype=torch.bfloat16,
|
||||
output=output,
|
||||
quant_scales=quant_scales,
|
||||
fc1_expert_biases=layer.w13_bias,
|
||||
fc2_expert_biases=layer.w2_bias,
|
||||
swiglu_alpha=layer.gemm1_alpha,
|
||||
swiglu_beta=layer.gemm1_beta,
|
||||
swiglu_limit=layer.gemm1_clamp_limit,
|
||||
tp_size=self.moe.tp_size,
|
||||
tp_rank=self.moe.tp_rank,
|
||||
ep_size=self.moe.ep_size,
|
||||
ep_rank=self.moe.ep_rank,
|
||||
tune_max_num_tokens=max(self.max_capture_size, 1),
|
||||
**extra_kwargs,
|
||||
or self.mxfp4_backend == Mxfp4Backend.MARLIN
|
||||
)
|
||||
|
||||
return output
|
||||
assert self.moe_mk is not None
|
||||
return self.moe_mk(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
expert_map=layer.expert_map,
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
|
||||
def apply_monolithic(
|
||||
self,
|
||||
|
||||
@@ -19,6 +19,7 @@ if TYPE_CHECKING:
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
MXFP_SCALE_DTYPE = torch.uint8
|
||||
|
||||
|
||||
def get_fp8_min_max() -> tuple[float, float]:
|
||||
@@ -151,6 +152,18 @@ kFp8Static128BlockSym = QuantKey(FP8_DTYPE, kStatic128BlockScale, symmetric=True
|
||||
kDynamic64Scale = ScaleDesc(torch.float32, False, GroupShape(1, 64))
|
||||
kFp8Dynamic64Sym = QuantKey(FP8_DTYPE, kDynamic64Scale, symmetric=True)
|
||||
|
||||
# TODO (zyongye): Convert all the torch.dtype to scale_dtype
|
||||
# Changing that requires changing torch compile fused AR+Quant Quant key
|
||||
# to avoid assertion error
|
||||
kMxfp4DynamicGroupScale = ScaleDesc(MXFP_SCALE_DTYPE, False, GroupShape(1, 32))
|
||||
kMxfp4Dynamic = QuantKey(FP4_DTYPE, scale=kMxfp4DynamicGroupScale, symmetric=True)
|
||||
|
||||
kMxfp8DynamicGroupScale = ScaleDesc(MXFP_SCALE_DTYPE, False, GroupShape(1, 32))
|
||||
kMxfp8Dynamic = QuantKey(FP8_DTYPE, scale=kMxfp8DynamicGroupScale, symmetric=True)
|
||||
|
||||
kMxfp4StaticGroupScale = ScaleDesc(MXFP_SCALE_DTYPE, True, GroupShape(1, 32))
|
||||
kMxfp4Static = QuantKey(FP4_DTYPE, scale=kMxfp4StaticGroupScale, symmetric=True)
|
||||
|
||||
|
||||
# Normalize the group_shape to the full extent for any dims that are -1
|
||||
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
|
||||
|
||||
Reference in New Issue
Block a user