[Frontend] new online quantization frontend (#38138)

Signed-off-by: Vasiliy Kuznetsov <vasiliy@meta.com>
This commit is contained in:
Vasiliy Kuznetsov
2026-04-03 11:58:39 -04:00
committed by GitHub
parent 97f92c6b47
commit 7b1a7423be
13 changed files with 1205 additions and 0 deletions

View File

@@ -0,0 +1,179 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests online quantization."""
import pytest
import torch
from tests.quantization.utils import (
_test_online_quant_peak_mem_impl,
is_quant_method_supported,
)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.online.fp8 import (
Fp8PerBlockOnlineLinearMethod,
Fp8PerBlockOnlineMoEMethod,
Fp8PerTensorOnlineLinearMethod,
Fp8PerTensorOnlineMoEMethod,
)
from vllm.platforms import current_platform
@pytest.mark.skipif(
    not is_quant_method_supported("fp8"),
    reason="FP8 is not supported on this GPU type.",
)
@pytest.mark.parametrize(
    "quant_scheme,online_quant_args,expected_linear_cls,expected_moe_cls",
    [
        # simple case - quantization='fp8_per_tensor'
        (
            "fp8_per_tensor",
            None,
            Fp8PerTensorOnlineLinearMethod,
            Fp8PerTensorOnlineMoEMethod,
        ),
        # simple case - quantization='fp8_per_block'
        (
            "fp8_per_block",
            None,
            Fp8PerBlockOnlineLinearMethod,
            Fp8PerBlockOnlineMoEMethod,
        ),
        # quantization='online' with linear_scheme_override and
        # moe_scheme_override
        (
            "online",
            {
                "linear_scheme_override": "fp8_per_block",
                "moe_scheme_override": "fp8_per_tensor",
            },
            Fp8PerBlockOnlineLinearMethod,
            Fp8PerTensorOnlineMoEMethod,
        ),
        # ignore with direct layer name
        (
            "fp8_per_tensor",
            # qkv_proj is fused from q_proj/k_proj/v_proj, so currently the
            # ignore regex must match the unfused shard names
            # TODO(future PR): also make 're:.*qkv_proj.*' work
            {"ignore": ["model.layers.1.self_attn.o_proj", "re:.*[qkv]_proj"]},
            Fp8PerTensorOnlineLinearMethod,
            Fp8PerTensorOnlineMoEMethod,
        ),
    ],
)
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_online_quantization(
    vllm_runner,
    quant_scheme: str,
    online_quant_args: dict | None,
    expected_linear_cls,
    expected_moe_cls,
    use_rocm_aiter: bool,
    monkeypatch,
) -> None:
    """
    Tests that online quantization frontend configuration works -
    selecting quant schemes, overriding quant schemes by type, ignoring
    layers.

    Does not test performance, peak memory usage, etc.
    """
    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
    # `LLM.apply_model` requires pickling a function.
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

    # a tiny model with both dense and MoE layers
    model_name = "ibm-granite/granite-3.0-1b-a400m-base"

    runner_kwargs = dict(
        quantization=quant_scheme,
        enforce_eager=True,
    )
    # quantization_config is only passed when the parametrized case needs
    # overrides or ignore lists; plain scheme shorthands omit it.
    if online_quant_args is not None:
        runner_kwargs["quantization_config"] = online_quant_args

    with vllm_runner(
        model_name,
        **runner_kwargs,
    ) as llm:

        def check_model(model):
            # checks further down in the test case are hardcoded for this
            # model
            assert model_name == "ibm-granite/granite-3.0-1b-a400m-base"

            o_proj = model.model.layers[0].self_attn.o_proj
            moe = model.model.layers[0].block_sparse_moe.experts

            # o_proj and moe in layer 0 are always quantized (never ignored)
            # because of how we craft the test case inputs
            assert isinstance(o_proj.quant_method, expected_linear_cls)
            if moe is not None:
                assert isinstance(moe.quant_method, expected_moe_cls)

            if current_platform.is_cuda():
                assert o_proj.weight.dtype == torch.float8_e4m3fn
            elif current_platform.is_rocm():
                assert o_proj.weight.dtype == current_platform.fp8_dtype()
            else:
                pytest.skip("Only runs on CUDA and ROCm.")

            # Verify ignored layers are unquantized.
            if isinstance(online_quant_args, dict) and "ignore" in online_quant_args:
                # only .*1.self_attn_o_proj is skipped
                for layer_idx in range(len(model.model.layers)):
                    o_proj = model.model.layers[layer_idx].self_attn.o_proj
                    if layer_idx == 1:
                        assert isinstance(o_proj.quant_method, UnquantizedLinearMethod)
                    else:
                        assert isinstance(o_proj.quant_method, expected_linear_cls)
                # every .*self_attn.qkv_proj is skipped
                for layer_idx in range(len(model.model.layers)):
                    qkv_proj = model.model.layers[layer_idx].self_attn.qkv_proj
                    assert isinstance(qkv_proj.quant_method, UnquantizedLinearMethod)

        llm.apply_model(check_model)

        # Generate a few tokens as a smoke test that the quantized model runs.
        outputs = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        print(outputs[0][1])
@pytest.mark.skipif(
    not is_quant_method_supported("fp8"),
    reason="FP8 is not supported on this GPU type.",
)
def test_online_quant_peak_mem(
    vllm_runner,
    caplog_mp_spawn,
    monkeypatch,
) -> None:
    """Peak-memory regression check for the 'fp8_per_tensor' online scheme.

    All of the heavy lifting lives in the shared helper so other schemes can
    reuse it.
    """
    _test_online_quant_peak_mem_impl(
        quantization_arg_value="fp8_per_tensor",
        vllm_runner=vllm_runner,
        caplog_mp_spawn=caplog_mp_spawn,
        monkeypatch=monkeypatch,
    )
@pytest.mark.skipif(
    not is_quant_method_supported("fp8"),
    reason="FP8 is not supported on this GPU type.",
)
def test_online_quant_load_format_dummy(
    vllm_runner,
    monkeypatch,
    caplog,
) -> None:
    """Smoke test: online quantization must also work with randomly
    initialized ("dummy") weights, i.e. without a real checkpoint on disk.
    """
    runner_kwargs = {
        "quantization": "fp8_per_tensor",
        "enforce_eager": True,
        "load_format": "dummy",
    }
    with vllm_runner("ibm-granite/granite-3.0-1b-a400m-base", **runner_kwargs) as llm:
        outputs = llm.generate_greedy(["The future of AI is"], max_tokens=4)
        print(outputs[0][1])

View File

@@ -1,6 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import regex as re
from vllm.model_executor.layers.quantization import get_quantization_config
from vllm.platforms import current_platform
@@ -21,3 +25,74 @@ def is_quant_method_supported(quant_method: str) -> bool:
min_capability = get_quantization_config(quant_method).get_min_capability()
return capability.to_int() >= min_capability
def _test_online_quant_peak_mem_impl(
    quantization_arg_value,
    vllm_runner,
    caplog_mp_spawn,
    monkeypatch,
) -> None:
    """Run a small MoE model with online quantization and assert that the
    post-load and peak GPU memory (parsed from the engine logs) stay below
    model-specific budgets.

    Improvement over the original: the log patterns are line-local, so a
    single ``re.search`` over the whole captured text finds the same first
    occurrence as the old line-by-line scan, without iterating every log
    line after both values have already been found.
    """
    # Note: `allenai/OLMoE-1B-7B-0125-Instruct` was selected because:
    # 1. it covers both Linear and MoE paths
    # 2. it is already used by other tests in CI, so adding it here
    #    does not increase disk space for CI runners
    # I really wanted to use `ibm-granite/granite-3.0-1b-a400m-base`
    # which I think is the smallest MoE model in vLLM (2.5 GiB bf16,
    # 1.3 GiB fp8), but could not as adding one more model makes CI
    # run out of disk space.
    model_name = "allenai/OLMoE-1B-7B-0125-Instruct"

    # Force spawn to ensure caplog_mp_spawn works consistently
    # (it relies on VLLM_LOGGING_CONFIG_PATH which spawn reads but fork ignores)
    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")

    with (
        caplog_mp_spawn(logging.DEBUG) as log_holder,
        vllm_runner(
            model_name,
            quantization=quantization_arg_value,
            enforce_eager=True,
        ) as llm,
    ):
        outputs = llm.generate_greedy(["The future of AI is"], max_tokens=4)
        print(outputs[0][1])
    log_text = log_holder.text

    # Parse memory usage from captured logs (first occurrence of each).
    model_match = re.search(r"Model loading took ([\d.]+) GiB memory", log_text)
    peak_match = re.search(
        r"Peak GPU memory after loading weights: ([\d.]+) GiB", log_text
    )
    assert model_match is not None, "Could not find model loading memory log"
    assert peak_match is not None, "Could not find peak memory log"
    model_memory_gib = float(model_match.group(1))
    peak_memory_gib = float(peak_match.group(1))
    print(f"GPU memory used after loading weights: {model_memory_gib} GiB")
    print(f"Peak GPU memory usage while loading weights: {peak_memory_gib} GiB")

    # model specific, allenai/OLMoE-1B-7B-0125-Instruct fp8 online quant
    # uses 6.65 GiB for weight loading (bf16 checkpoint is ~12.89 GiB)
    expected_model_memory_gib = 6.7
    # for allenai/OLMoE-1B-7B-0125-Instruct the number we see today is 9.06
    # GiB, which is 1.36x above model_memory_gib. A slightly higher number is
    # expected as when we load and quantize weights in a streaming fashion we
    # need to have individual weights in bf16 + fp8 alive at the same time.
    expected_peak_memory_gib = expected_model_memory_gib * 1.4

    assert model_memory_gib < expected_model_memory_gib, (
        f"{model_memory_gib=} higher than {expected_model_memory_gib}"
    )
    assert peak_memory_gib < expected_peak_memory_gib, (
        f"{peak_memory_gib=} higher than {expected_peak_memory_gib}"
    )

View File

@@ -21,6 +21,7 @@ from vllm.config.multimodal import (
MultiModalConfig,
)
from vllm.config.pooler import PoolerConfig
from vllm.config.quantization import OnlineQuantizationConfigArgs
from vllm.config.scheduler import RunnerType
from vllm.config.utils import config, getattr_iter
from vllm.logger import init_logger
@@ -199,6 +200,10 @@ class ModelConfig:
`quantization_config` attribute in the model config file. If that is
`None`, we assume the model weights are not quantized and use `dtype` to
determine the data type of the weights."""
quantization_config: dict[str, Any] | OnlineQuantizationConfigArgs | None = None
"""Arguments for online quantization.
Auto-created when `quantization` equals to one of the string values of
the `OnlineQuantScheme` enum."""
allow_deprecated_quantization: bool = False
"""Whether to allow deprecated quantization methods."""
enforce_eager: bool = False

121
vllm/config/quantization.py Normal file
View File

@@ -0,0 +1,121 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
from typing import Any
from pydantic import Field, field_validator
from vllm.config.utils import config
class OnlineQuantScheme(Enum):
    """Supported online quantization schemes.

    The enum string values double as user-facing shorthands for the
    ``quantization=`` argument (e.g. ``quantization='fp8_per_tensor'``).
    """

    # fp8, weights and activations scaled per-tensor
    FP8_PER_TENSOR = "fp8_per_tensor"

    # fp8, activations scaled in blocks of 1x128 elements, weights scaled in
    # blocks of 128x128 elements (popularized by DeepSeek)
    FP8_PER_BLOCK = "fp8_per_block"

    # TODO(future PRs): add more online quant schemes here: mxfp8, etc
@config
class OnlineQuantizationConfigArgs:
    """User-facing arguments controlling online quantization.

    These args configure how ``OnlineQuantizationConfig`` is applied to a
    model. At least one of ``global_scheme``, ``linear_scheme_override``, or
    ``moe_scheme_override`` must be set.
    """

    global_scheme: OnlineQuantScheme | None = None
    """Quantization scheme applied to every supported layer."""

    linear_scheme_override: OnlineQuantScheme | None = None
    """Quantization scheme override for ``LinearBase`` layers."""

    moe_scheme_override: OnlineQuantScheme | None = None
    """Quantization scheme override for ``FusedMoE`` layers."""

    ignore: list[str] = Field(default_factory=list)
    """Layers to skip quantization for. Supports exact names and regex
    patterns with ``re:`` prefix (e.g. ``re:.*attn.*``), consistent with
    compressed_tensors layer skipping."""

    @field_validator(
        "global_scheme", "linear_scheme_override", "moe_scheme_override", mode="before"
    )
    @classmethod
    def _coerce_scheme(
        cls, v: str | OnlineQuantScheme | None
    ) -> OnlineQuantScheme | None:
        # Accept plain strings (e.g. from CLI/JSON input) and coerce them to
        # the enum; pass enum instances and None through untouched.
        return OnlineQuantScheme(v) if isinstance(v, str) else v
def resolve_online_quant_config(
    quantization: str | None,
    quantization_config: dict[str, Any] | OnlineQuantizationConfigArgs | None,
) -> OnlineQuantizationConfigArgs | None:
    """Resolve online quant scheme shorthand into a quantization config.

    If ``quantization`` is an online quant scheme (e.g. ``'fp8_per_tensor'``),
    ensures ``quantization_config`` has a matching ``global_scheme`` and casts
    it to :class:`OnlineQuantizationConfigArgs` if needed.
    """
    online_quant_values = {s.value for s in OnlineQuantScheme}
    valid_quantization_values = online_quant_values | {"online"}

    # Non-online quantization methods do not accept a quantization_config.
    if quantization not in valid_quantization_values:
        if quantization_config is not None:
            raise ValueError(
                f"quantization_config is only supported when quantization "
                f"is one of {sorted(valid_quantization_values)}, "
                f"got quantization={quantization!r}"
            )
        return None

    # A scheme shorthand (e.g. 'fp8_per_tensor') must agree with any
    # global_scheme already present in the config; fill it in if absent.
    if quantization in online_quant_values:
        scheme = OnlineQuantScheme(quantization)
        if quantization_config is None:
            quantization_config = {"global_scheme": scheme.value}
        elif isinstance(quantization_config, OnlineQuantizationConfigArgs):
            if quantization_config.global_scheme is None:
                quantization_config.global_scheme = scheme
            elif quantization_config.global_scheme != scheme:
                raise ValueError(
                    f"quantization={quantization!r} conflicts with "
                    f"quantization_config.global_scheme="
                    f"{quantization_config.global_scheme.value!r}. "
                    f"These must match when both are specified."
                )
        elif isinstance(quantization_config, dict):
            current = quantization_config.get("global_scheme")
            if current is None:
                quantization_config["global_scheme"] = scheme.value
            else:
                # Coerce to enum for comparison
                current_scheme = (
                    OnlineQuantScheme(current) if isinstance(current, str) else current
                )
                if current_scheme != scheme:
                    raise ValueError(
                        f"quantization={quantization!r} conflicts "
                        f"with quantization_config"
                        f"['global_scheme']={current!r}. "
                        f"These must match when both are specified."
                    )

    # Cast dict to OnlineQuantizationConfigArgs
    if isinstance(quantization_config, dict):
        quantization_config = OnlineQuantizationConfigArgs(**quantization_config)
    return quantization_config

View File

@@ -1713,6 +1713,7 @@ class VllmConfig:
f"dcp_comm_backend={self.parallel_config.dcp_comm_backend}, " # noqa
f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa
f"quantization={self.model_config.quantization}, "
f"quantization_config={self.model_config.quantization_config}, " # noqa
f"enforce_eager={self.model_config.enforce_eager}, "
f"enable_return_routed_experts={self.model_config.enable_return_routed_experts}, " # noqa
f"kv_cache_dtype={self.cache_config.cache_dtype}, "

View File

@@ -112,6 +112,7 @@ from vllm.v1.sample.logits_processor import LogitsProcessor
from vllm.version import __version__ as VLLM_VERSION
if TYPE_CHECKING:
from vllm.config.quantization import OnlineQuantizationConfigArgs
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.model_loader import LoadFormats
from vllm.usage.usage_lib import UsageContext
@@ -483,6 +484,7 @@ class EngineArgs:
hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides")
tokenizer_revision: str | None = ModelConfig.tokenizer_revision
quantization: QuantizationMethods | str | None = ModelConfig.quantization
quantization_config: "dict[str, Any] | OnlineQuantizationConfigArgs | None" = None
allow_deprecated_quantization: bool = ModelConfig.allow_deprecated_quantization
enforce_eager: bool = ModelConfig.enforce_eager
disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
@@ -661,6 +663,12 @@ class EngineArgs:
if isinstance(self.ir_op_priority, dict):
self.ir_op_priority = IrOpPriorityConfig(**self.ir_op_priority)
from vllm.config.quantization import resolve_online_quant_config
self.quantization_config = resolve_online_quant_config(
self.quantization, self.quantization_config
)
# Setup plugins
from vllm.plugins import load_general_plugins
@@ -1431,6 +1439,7 @@ class EngineArgs:
tokenizer_revision=self.tokenizer_revision,
max_model_len=self.max_model_len,
quantization=self.quantization,
quantization_config=self.quantization_config,
allow_deprecated_quantization=self.allow_deprecated_quantization,
enforce_eager=self.enforce_eager,
enable_return_routed_experts=self.enable_return_routed_experts,

View File

@@ -34,6 +34,9 @@ from vllm.config.model import (
RunnerOption,
TokenizerMode,
)
from vllm.config.quantization import (
OnlineQuantizationConfigArgs,
)
from vllm.distributed.weight_transfer.base import (
WeightTransferInitRequest,
WeightTransferUpdateRequest,
@@ -247,6 +250,9 @@ class LLM:
attention_config: dict[str, Any] | AttentionConfig | None = None,
kv_cache_memory_bytes: int | None = None,
compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
quantization_config: dict[str, Any]
| OnlineQuantizationConfigArgs
| None = None,
logits_processors: list[str | type[LogitsProcessor]] | None = None,
**kwargs: Any,
) -> None:
@@ -367,6 +373,7 @@ class LLM:
profiler_config=profiler_config_instance,
attention_config=attention_config_instance,
compilation_config=compilation_config_instance,
quantization_config=quantization_config,
logits_processors=logits_processors,
**kwargs,
)

View File

@@ -33,6 +33,13 @@ QuantizationMethods = Literal[
"mxfp8",
"petit_nvfp4",
"cpu_awq",
"online",
# Below are values of the OnlineQuantScheme enum, specified as strings to
# avoid circular import issues. This is here to provide a shortcut where
# the user can specify "LLM(..., quantization='fp8_per_tensor')" as
# shorthand for creating a more complicated online quant config object
"fp8_per_tensor",
"fp8_per_block",
]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
@@ -103,6 +110,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
raise ValueError(f"Invalid quantization method: {quantization}")
# lazy import to avoid triggering `torch.compile` too early
from vllm.config.quantization import OnlineQuantScheme
from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
from .awq import AWQConfig
@@ -129,6 +137,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .moe_wna16 import MoeWNA16Config
from .mxfp4 import Mxfp4Config
from .mxfp8 import Mxfp8Config
from .online.base import OnlineQuantizationConfig
from .petit import PetitNvFp4Config
from .torchao import TorchAOConfig
@@ -157,7 +166,20 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"mxfp8": Mxfp8Config,
"petit_nvfp4": PetitNvFp4Config,
"cpu_awq": CPUAWQConfig,
"online": OnlineQuantizationConfig,
}
# Below are values of the OnlineQuantScheme enum. This is here to provide
# a shortcut where the user can specify
# "LLM(..., quantization='fp8_per_tensor')" as shorthand for creating a
# more complicated online quant config object
for scheme in OnlineQuantScheme:
assert scheme.value not in method_to_config, (
f"Online quant scheme {scheme.value!r} conflicts with an "
f"existing quantization method"
)
method_to_config[scheme.value] = OnlineQuantizationConfig
# Update the `method_to_config` with customized quantization methods.
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)

View File

@@ -497,6 +497,8 @@ class Fp8LinearMethod(LinearMethodBase):
return self.fp8_linear.apply_weights(layer, x, bias)
# TODO(future PR): remove this class in favor of
# online/fp8.py::Fp8PerTensorOnlineLinearMethod
class Fp8OnlineLinearMethod(Fp8LinearMethod):
"""Online version of Fp8LinearMethod which loads a full precision checkpoint
and quantizes weights during loading."""
@@ -919,6 +921,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
# TODO(future PR): remove this class in favor of
# online/fp8.py::Fp8PerTensorOnlineMoEMethod
class Fp8OnlineMoEMethod(Fp8MoEMethod):
"""MoE method for online FP8 quantization.
Supports loading quantized FP16/BF16 model checkpoints with dynamic

View File

@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

View File

@@ -0,0 +1,116 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
from vllm.config.quantization import (
OnlineQuantizationConfigArgs,
OnlineQuantScheme,
)
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
)
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.linear import (
LinearBase,
UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
should_ignore_layer,
)
from vllm.model_executor.layers.quantization.online.fp8 import (
Fp8PerBlockOnlineLinearMethod,
Fp8PerBlockOnlineMoEMethod,
Fp8PerTensorOnlineLinearMethod,
Fp8PerTensorOnlineMoEMethod,
)
class OnlineQuantizationConfig(QuantizationConfig):
    """Model-level config class for online quantization (quantize fp16/bf16 weights
    during model loading, without requiring a pre-quantized checkpoint)."""

    def __init__(
        self,
        args: OnlineQuantizationConfigArgs,
    ) -> None:
        super().__init__()
        # An args object with no scheme at all would quantize nothing;
        # fail fast instead of silently doing nothing.
        if (
            args.global_scheme is None
            and args.linear_scheme_override is None
            and args.moe_scheme_override is None
        ):
            raise ValueError(
                "OnlineQuantizationConfig requires at least one of "
                "global_scheme, linear_scheme_override, or "
                "moe_scheme_override to be set."
            )
        self.args = args
        # Convenience aliases used by get_quant_method below.
        self.quant_scheme = args.global_scheme
        self.ignored_layers: list[str] = args.ignore

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "online"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # Note: as more online quant schemes will be added, this
        # value will become the minimum across all supported schemes.
        return 75

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        # Online quantization has no checkpoint-side config file.
        return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "OnlineQuantizationConfig":
        raise NotImplementedError(
            "OnlineQuantizationConfig does not support loading from a "
            "checkpoint config. Use quantization_config or "
            "quantization='fp8_per_tensor'/'fp8_per_block' instead."
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Pick the quant method for ``layer``: unquantized if ``prefix``
        matches an ignore entry, otherwise the per-layer-type override (or
        the global scheme). Returns None for unsupported layer types."""
        if isinstance(layer, LinearBase):
            if should_ignore_layer(
                prefix,
                ignore=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedLinearMethod()
            linear_scheme = self.args.linear_scheme_override or self.args.global_scheme
            # NOTE(review): if only moe_scheme_override is set, linear_scheme
            # is None here and linear layers fall through to the per-tensor
            # method — confirm that is the intended default.
            if linear_scheme == OnlineQuantScheme.FP8_PER_BLOCK:
                return Fp8PerBlockOnlineLinearMethod()
            else:
                return Fp8PerTensorOnlineLinearMethod()
        elif isinstance(layer, FusedMoE):
            if should_ignore_layer(
                prefix,
                ignore=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            moe_scheme = self.args.moe_scheme_override or self.args.global_scheme
            if moe_scheme == OnlineQuantScheme.FP8_PER_BLOCK:
                return Fp8PerBlockOnlineMoEMethod(layer=layer)
            else:
                return Fp8PerTensorOnlineMoEMethod(layer=layer)
        return None

View File

@@ -0,0 +1,632 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
import torch
from torch.nn import Module
if TYPE_CHECKING:
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.oracle.fp8 import Fp8MoeBackend
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.kernels.linear import init_fp8_linear_kernel
from vllm.model_executor.layers.fused_moe import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
select_fp8_moe_backend,
)
from vllm.model_executor.layers.linear import (
LinearMethodBase,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
maybe_post_process_fp8_weight_block,
process_fp8_weight_block_strategy,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
kFp8Dynamic128Sym,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8Static128BlockSym,
kFp8StaticTensorSym,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported,
cutlass_fp8_supported,
)
from vllm.model_executor.model_loader.reload.layerwise import (
initialize_online_processing,
)
from vllm.model_executor.parameter import ModelWeightParameter
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import is_deep_gemm_supported, per_block_cast_to_fp8
# ---------------------------------------------------------------------------
# Online FP8 Linear Methods
# ---------------------------------------------------------------------------
class _Fp8OnlineLinearBase(LinearMethodBase):
    """Shared base for online FP8 linear methods. Loads fp16/bf16 checkpoint
    weights onto meta device and materializes them just-in-time."""

    # Signals the weight loader that this method's weights start on the
    # meta device and are materialized during loading.
    uses_meta_device: bool = True

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register a meta-device high-precision weight on ``layer`` and set
        up the metadata (sizes, dtype) that subclasses' quantization needs."""
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        # Subclasses that quantize per-block overwrite this after calling us.
        layer.weight_block_size = None

        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                device="meta",  # materialized and processed during loading
                dtype=params_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # Hook the layer into the streaming load-then-quantize pipeline.
        initialize_online_processing(layer)
class Fp8PerTensorOnlineLinearMethod(_Fp8OnlineLinearBase):
    """Online tensorwise FP8 linear quantization.

    Loads fp16/bf16 weights and quantizes them per-tensor during loading."""

    def __init__(self):
        self.out_dtype = torch.get_default_dtype()
        # Use per-token quantization for better perf if dynamic and cutlass
        if cutlass_fp8_supported():
            activation_quant_key = kFp8DynamicTokenSym
        else:
            activation_quant_key = kFp8DynamicTensorSym
        self.fp8_linear = init_fp8_linear_kernel(
            activation_quant_key=activation_quant_key,
            weight_quant_key=kFp8StaticTensorSym,
            out_dtype=torch.get_default_dtype(),
            module_name=self.__class__.__name__,
        )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded high-precision weight to fp8 per-tensor and
        replace the layer's parameters in place."""
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return
        # Activations are quantized dynamically, so no static input scale.
        layer.input_scale = None
        # scale=None -> dynamic per-tensor scale derived from the weight.
        qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
        # Update layer with new values. Note the weight is stored transposed.
        replace_parameter(layer, "weight", qweight.t().data)
        replace_parameter(layer, "weight_scale", weight_scale.data)
        # Prevent duplicate processing (e.g., during weight reload)
        layer._already_called_process_weights_after_loading = True

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # if batch invariant mode is enabled, use BF16 dequant
        if envs.VLLM_BATCH_INVARIANT:
            weight_fp8 = layer.weight.to(torch.bfloat16)
            weight_scale = layer.weight_scale.to(torch.bfloat16)
            if weight_scale.numel() == 1:
                # Per-tensor: simple scalar multiplication
                weight_bf16 = weight_fp8 * weight_scale
            else:
                # Multiple scales (fused modules like QKV)
                # NOTE(review): weight is stored transposed after
                # process_weights_after_loading, so dim 0 here is the input
                # dim — confirm the per-row branch's scale layout is intended.
                if (
                    weight_scale.dim() == 1
                    and weight_scale.shape[0] == weight_fp8.shape[0]
                ):
                    # Per-row scaling
                    weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1)
                else:
                    # Fallback
                    weight_bf16 = weight_fp8 * weight_scale
            return torch.nn.functional.linear(x, weight_bf16.t(), bias)
        return self.fp8_linear.apply_weights(layer, x, bias)
class Fp8PerBlockOnlineLinearMethod(_Fp8OnlineLinearBase):
    """Online blockwise FP8 linear quantization.

    Loads fp16/bf16 weights and quantizes them per-block during loading."""

    def __init__(self):
        self.out_dtype = torch.get_default_dtype()
        # 128x128 weight blocks with 1x128 activation groups (DeepSeek-style).
        self.weight_block_size = [128, 128]
        self.use_deep_gemm = is_deep_gemm_supported()
        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(*self.weight_block_size),
            act_quant_group_shape=GroupShape(1, self.weight_block_size[0]),
            cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
            use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            use_deep_gemm=self.use_deep_gemm,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Delegate to the base, then record the block size on the layer so
        downstream processing knows this weight is block-quantized."""
        super().create_weights(
            layer,
            input_size_per_partition,
            output_partition_sizes,
            input_size,
            output_size,
            params_dtype,
            **extra_weight_attrs,
        )
        layer.weight_block_size = self.weight_block_size

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded high-precision weight to fp8 per-block and
        replace the layer's parameters in place."""
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return
        # Activations are quantized dynamically, so no static input scale.
        layer.input_scale = None
        block_size = self.weight_block_size
        qweight, weight_scale_inv = per_block_cast_to_fp8(
            layer.weight, block_size=block_size, use_ue8m0=False
        )
        qweight, weight_scale_inv = process_fp8_weight_block_strategy(
            qweight, weight_scale_inv
        )
        replace_parameter(layer, "weight", qweight.data)
        replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)
        maybe_post_process_fp8_weight_block(layer)
        # Prevent duplicate processing (e.g., during weight reload)
        layer._already_called_process_weights_after_loading = True

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        assert self.weight_block_size is not None
        # Note: batch invariance already handled in the function below
        return self.w8a8_block_fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale_inv,
            input_scale=layer.input_scale,
            bias=bias,
        )
# ---------------------------------------------------------------------------
# Online FP8 MoE Methods
# ---------------------------------------------------------------------------
class _Fp8OnlineMoEBase(FusedMoEMethodBase):
"""Shared base for online FP8 MoE methods. Loads fp16/bf16 checkpoint
weights onto meta device and materializes them just-in-time."""
uses_meta_device: bool = True
# Declared here for mypy; actual values are set in __init__.
fp8_backend: "Fp8MoeBackend"
experts_cls: "type[mk.FusedMoEExperts] | None"
weight_scale_name: str
weight_block_size: list[int] | None
moe: "FusedMoEConfig"
is_monolithic: bool
moe_quant_config: "FusedMoEQuantConfig | None"
moe_kernel: "mk.FusedMoEKernel | None"
def __init__(
    self,
    *,
    weight_block_size: list[int] | None,
    layer: torch.nn.Module,
):
    """Configure block vs. tensor quantization and pick the fp8 MoE backend.

    ``weight_block_size`` of None means per-tensor quantization; a list
    (e.g. [128, 128]) means blockwise quantization.
    """
    super().__init__(layer.moe_config)
    self.weight_block_size = weight_block_size
    self.block_quant: bool = self.weight_block_size is not None
    # Blockwise checkpoints/kernels name the scale "weight_scale_inv".
    self.weight_scale_name = (
        "weight_scale_inv" if self.block_quant else "weight_scale"
    )
    # Set weight key and activation key for kernel compatibility
    if self.block_quant:
        weight_key = kFp8Static128BlockSym
        activation_key = kFp8Dynamic128Sym
    else:
        weight_key = kFp8StaticTensorSym
        activation_key = kFp8DynamicTensorSym
    # Select Fp8 MoE backend
    self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
        config=self.moe,
        weight_key=weight_key,
        activation_key=activation_key,
        allow_vllm_cutlass=False,
    )
def create_weights(
    self,
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    """Register meta-device high-precision expert weights (and optional
    biases) on ``layer``; they are materialized and quantized during load."""
    layer.num_experts = num_experts
    layer.orig_dtype = params_dtype
    layer.weight_block_size = None
    # WEIGHTS
    # w13 holds the fused gate/up projections, hence the factor of 2.
    w13_weight = torch.nn.Parameter(
        torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size,
            device="meta",  # materialized and processed during loading
            dtype=params_dtype,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w13_weight", w13_weight)
    set_weight_attrs(w13_weight, extra_weight_attrs)

    w2_weight = torch.nn.Parameter(
        torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            device="meta",  # materialized and processed during loading
            dtype=params_dtype,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w2_weight", w2_weight)
    set_weight_attrs(w2_weight, extra_weight_attrs)

    # BIASES (for models like GPT-OSS that have biased MoE)
    if self.moe.has_bias:
        w13_bias = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                2 * intermediate_size_per_partition,
                device="meta",  # materialized and processed during loading
                dtype=layer.orig_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_bias", w13_bias)
        set_weight_attrs(w13_bias, extra_weight_attrs)

        w2_bias = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                hidden_size,
                device="meta",  # materialized and processed during loading
                dtype=layer.orig_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_bias", w2_bias)
        set_weight_attrs(w2_bias, extra_weight_attrs)

    # Activations are quantized dynamically; no static input scales.
    layer.w13_input_scale = None
    layer.w2_input_scale = None

    # Hook the layer into the streaming load-then-quantize pipeline.
    initialize_online_processing(layer)
def _setup_kernel(
self,
layer: "FusedMoE",
w13: torch.Tensor,
w2: torch.Tensor,
w13_scale: torch.Tensor,
w2_scale: torch.Tensor,
w13_input_scale: torch.Tensor | None,
w2_input_scale: torch.Tensor | None,
) -> None:
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
)
# Shuffle weights to runtime format.
w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
fp8_backend=self.fp8_backend,
layer=layer,
w13=w13,
w2=w2,
w13_scale=w13_scale,
w2_scale=w2_scale,
w13_input_scale=w13_input_scale,
w2_input_scale=w2_input_scale,
)
# Replace parameters with updated versions. Note that this helper
# function ensures the replacement is compatible with RL weight reloads.
replace_parameter(layer, "w13_weight", w13)
replace_parameter(layer, "w2_weight", w2)
replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
assert self.experts_cls is not None
self.moe_kernel = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
)
    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> "mk.FusedMoEPrepareAndFinalizeModular | None":
        """Unsupported legacy entry point; always raises.

        Online FP8 MoE methods build their kernel through ``_setup_kernel``
        (the "new modular kernel initialization logic"), so this hook from the
        older initialization path must never be invoked.

        Raises:
            ValueError: always, identifying the concrete method class.
        """
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel "
            "initialization logic. This function should not be called."
        )
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> "FusedMoEQuantConfig":
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
make_fp8_moe_quant_config,
)
w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
a1_scale = layer.w13_input_scale
a2_scale = layer.w2_input_scale
quant_config = make_fp8_moe_quant_config(
fp8_backend=self.fp8_backend,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=self.weight_block_size,
)
# Inject biases into the quant config if the model has them
# (e.g. GPT-OSS biased MoE)
if quant_config is not None and self.moe.has_bias:
w13_bias = getattr(layer, "w13_bias", None)
w2_bias = getattr(layer, "w2_bias", None)
if w13_bias is not None:
quant_config._w1.bias = w13_bias
if w2_bias is not None:
quant_config._w2.bias = w2_bias
return quant_config
    @property
    def supports_eplb(self) -> bool:
        """Online FP8 MoE methods support expert-parallel load balancing."""
        return True
def apply_monolithic(
self,
layer: "FusedMoE",
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic
assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
x,
layer.w13_weight,
layer.w2_weight,
router_logits,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
e_score_correction_bias=layer.e_score_correction_bias,
routed_scaling_factor=layer.routed_scaling_factor,
)
def apply(
self,
layer: "FusedMoE",
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
assert self.moe_kernel is not None
return self.moe_kernel.apply(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=shared_experts_input,
)
class Fp8PerTensorOnlineMoEMethod(_Fp8OnlineMoEBase):
    """Online tensorwise FP8 MoE quantization.

    Loads fp16/bf16 weights and quantizes them per-tensor during loading."""

    def __init__(
        self,
        *,
        layer: torch.nn.Module,
    ):
        super().__init__(
            weight_block_size=None,
            layer=layer,
        )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded high-precision weights to FP8 with one scale
        per expert tensor, then hand them to the fused-MoE kernel."""
        # TODO(@ksayers): inplace fp8 quant kernel, initialize scales with ones
        # Guard against duplicate processing (e.g., during weight reload).
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return
        # If checkpoint is fp16, quantize in place.
        fp8_dtype = current_platform.fp8_dtype()
        quant_w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
        quant_w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
        # NOTE(review): scales are sized by `layer.num_experts` while the
        # quantization loop covers `layer.local_num_experts` — presumably these
        # match for this layer; verify against create_weights' caller.
        scale_w13 = torch.ones(
            layer.num_experts, device=quant_w13.device, dtype=torch.float32
        )
        scale_w2 = torch.ones(
            layer.num_experts, device=quant_w2.device, dtype=torch.float32
        )
        # Activation quantization is dynamic: no static input scales.
        layer.w13_input_scale = None
        layer.w2_input_scale = None
        for eid in range(layer.local_num_experts):
            quant_w13[eid], scale_w13[eid] = ops.scaled_fp8_quant(
                layer.w13_weight[eid]
            )
            quant_w2[eid], scale_w2[eid] = ops.scaled_fp8_quant(
                layer.w2_weight[eid]
            )
        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer,
            quant_w13,
            quant_w2,
            scale_w13,
            scale_w2,
            w13_input_scale=layer.w13_input_scale,
            w2_input_scale=layer.w2_input_scale,
        )
        # Prevent duplicate processing (e.g., during weight reload)
        layer._already_called_process_weights_after_loading = True
class Fp8PerBlockOnlineMoEMethod(_Fp8OnlineMoEBase):
    """Online blockwise FP8 MoE quantization.

    Loads fp16/bf16 weights and quantizes them per-block during loading."""

    def __init__(
        self,
        *,
        layer: torch.nn.Module,
    ):
        super().__init__(
            weight_block_size=[128, 128],
            layer=layer,
        )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded high-precision weights to FP8 with one scale
        per (block_n x block_k) weight block, then hand them to the kernel."""
        # Guard against duplicate processing (e.g., during weight reload).
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return
        fp8_dtype = current_platform.fp8_dtype()
        quant_w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
        quant_w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
        block_size = self.weight_block_size
        assert block_size is not None
        block_n, block_k = block_size

        def _ceil_div(a: int, b: int) -> int:
            # Number of blocks needed to cover a dimension of size `a`.
            return (a + b - 1) // b

        # Create block-shaped scales (computed here rather than in
        # create_weights because online quant doesn't need them until now).
        num_experts = layer.local_num_experts
        _, w13_out, w13_in = layer.w13_weight.shape
        _, w2_out, w2_in = layer.w2_weight.shape
        scale_w13 = torch.ones(
            num_experts,
            _ceil_div(w13_out, block_n),
            _ceil_div(w13_in, block_k),
            dtype=torch.float32,
            device=quant_w13.device,
        )
        scale_w2 = torch.ones(
            num_experts,
            _ceil_div(w2_out, block_n),
            _ceil_div(w2_in, block_k),
            dtype=torch.float32,
            device=quant_w2.device,
        )
        for eid in range(num_experts):
            quant_w13[eid], scale_w13[eid] = per_block_cast_to_fp8(
                layer.w13_weight[eid],
                block_size=block_size,
                use_ue8m0=False,
            )
            quant_w2[eid], scale_w2[eid] = per_block_cast_to_fp8(
                layer.w2_weight[eid],
                block_size=block_size,
                use_ue8m0=False,
            )
        layer.weight_block_size = block_size
        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer,
            quant_w13,
            quant_w2,
            scale_w13,
            scale_w2,
            layer.w13_input_scale,
            layer.w2_input_scale,
        )
        # Prevent duplicate processing (e.g., during weight reload)
        layer._already_called_process_weights_after_loading = True

View File

@@ -296,6 +296,13 @@ def get_quant_config(
)
if hf_quant_config is not None:
if model_config.quantization_config is not None:
raise ValueError(
"Setting `quantization_config` for online "
"quantization when the model checkpoint already "
"has a `quantization_config` is not supported"
)
# For modelopt_mixed, config.json's quantization_config may or may
# not contain the per-layer quantized_layers map. Newer checkpoints
# embed it directly; older ones keep it only in hf_quant_config.json.
@@ -319,6 +326,12 @@ def get_quant_config(
quantization_config_file = hf_overrides.get("quantization_config_file", None)
if quantization_config_file is not None:
if hasattr(quant_cls, "from_config_file"):
if model_config.quantization_config is not None:
raise ValueError(
"Setting `quantization_config` for online "
"quantization when the model checkpoint already "
"has a `quantization_config` is not supported"
)
return quant_cls.from_config_file(quantization_config_file)
else:
raise NotImplementedError(
@@ -329,6 +342,12 @@ def get_quant_config(
quantization_config_json = hf_overrides.get("quantization_config_dict_json", None)
if quantization_config_json is not None:
if hasattr(quant_cls, "from_config_dict_json"):
if model_config.quantization_config is not None:
raise ValueError(
"Setting `quantization_config` for online "
"quantization when the model checkpoint already "
"has a `quantization_config` is not supported"
)
return quant_cls.from_config_dict_json(quantization_config_json)
else:
raise NotImplementedError(
@@ -337,6 +356,19 @@ def get_quant_config(
f"{quant_cls}"
)
# Online quantization doesn't read from checkpoint configs — it quantizes
# fp16/bf16 weights on the fly during loading.
if model_config.quantization_config is not None:
from vllm.config.quantization import OnlineQuantizationConfigArgs
from vllm.model_executor.layers.quantization.online.base import (
OnlineQuantizationConfig,
)
assert isinstance(
model_config.quantization_config, OnlineQuantizationConfigArgs
)
return OnlineQuantizationConfig(args=model_config.quantization_config)
# Inflight BNB quantization
if model_config.quantization == "bitsandbytes":
return quant_cls.from_config({})