[Frontend] new online quantization frontend (#38138)

Signed-off-by: Vasiliy Kuznetsov <vasiliy@meta.com>
This commit is contained in:
Vasiliy Kuznetsov
2026-04-03 11:58:39 -04:00
committed by GitHub
parent 97f92c6b47
commit 7b1a7423be
13 changed files with 1205 additions and 0 deletions

View File

@@ -0,0 +1,179 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests online quantization."""
import pytest
import torch
from tests.quantization.utils import (
_test_online_quant_peak_mem_impl,
is_quant_method_supported,
)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.online.fp8 import (
Fp8PerBlockOnlineLinearMethod,
Fp8PerBlockOnlineMoEMethod,
Fp8PerTensorOnlineLinearMethod,
Fp8PerTensorOnlineMoEMethod,
)
from vllm.platforms import current_platform
@pytest.mark.skipif(
    not is_quant_method_supported("fp8"),
    reason="FP8 is not supported on this GPU type.",
)
@pytest.mark.parametrize(
    "quant_scheme,online_quant_args,expected_linear_cls,expected_moe_cls",
    [
        # simple case - quantization='fp8_per_tensor'
        (
            "fp8_per_tensor",
            None,
            Fp8PerTensorOnlineLinearMethod,
            Fp8PerTensorOnlineMoEMethod,
        ),
        # simple case - quantization='fp8_per_block'
        (
            "fp8_per_block",
            None,
            Fp8PerBlockOnlineLinearMethod,
            Fp8PerBlockOnlineMoEMethod,
        ),
        # quantization='online with linear_scheme_override and
        # moe_scheme_override
        (
            "online",
            {
                "linear_scheme_override": "fp8_per_block",
                "moe_scheme_override": "fp8_per_tensor",
            },
            Fp8PerBlockOnlineLinearMethod,
            Fp8PerTensorOnlineMoEMethod,
        ),
        # ignore with direct layer name
        (
            "fp8_per_tensor",
            # qkv_proj is fused from q_proj/k_proj/v_proj, so currently the
            # ignore regex must match the unfused shard names
            # TODO(future PR): also make 're:.*qkv_proj.*' work
            {"ignore": ["model.layers.1.self_attn.o_proj", "re:.*[qkv]_proj"]},
            Fp8PerTensorOnlineLinearMethod,
            Fp8PerTensorOnlineMoEMethod,
        ),
    ],
)
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_online_quantization(
    vllm_runner,
    quant_scheme: str,
    online_quant_args: dict | None,
    expected_linear_cls,
    expected_moe_cls,
    use_rocm_aiter: bool,
    monkeypatch,
) -> None:
    """Exercise the online quantization frontend configuration.

    Covers scheme selection, per-layer-type scheme overrides, and
    ignoring layers. Does not test performance, peak memory usage, etc.
    """
    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
    # `LLM.apply_model` requires pickling a function.
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

    # a tiny model with both dense and MoE layers
    model_name = "ibm-granite/granite-3.0-1b-a400m-base"
    runner_kwargs: dict = {"quantization": quant_scheme, "enforce_eager": True}
    if online_quant_args is not None:
        runner_kwargs["quantization_config"] = online_quant_args

    with vllm_runner(model_name, **runner_kwargs) as llm:

        def verify_quant_methods(model):
            # the checks below are hardcoded for this model
            assert model_name == "ibm-granite/granite-3.0-1b-a400m-base"
            layers = model.model.layers
            o_proj = layers[0].self_attn.o_proj
            moe = layers[0].block_sparse_moe.experts

            # o_proj and moe in layer 0 are always quantized (never ignored)
            # because of how we craft the test case inputs
            assert isinstance(o_proj.quant_method, expected_linear_cls)
            if moe is not None:
                assert isinstance(moe.quant_method, expected_moe_cls)

            if current_platform.is_cuda():
                assert o_proj.weight.dtype == torch.float8_e4m3fn
            elif current_platform.is_rocm():
                assert o_proj.weight.dtype == current_platform.fp8_dtype()
            else:
                pytest.skip("Only runs on CUDA and ROCm.")

            # Verify ignored layers are unquantized.
            if isinstance(online_quant_args, dict) and "ignore" in online_quant_args:
                for layer_idx, layer in enumerate(layers):
                    # only model.layers.1.self_attn.o_proj is skipped by name
                    expected_o_proj_cls = (
                        UnquantizedLinearMethod
                        if layer_idx == 1
                        else expected_linear_cls
                    )
                    assert isinstance(
                        layer.self_attn.o_proj.quant_method, expected_o_proj_cls
                    )
                    # every .*self_attn.qkv_proj is skipped by the regex
                    assert isinstance(
                        layer.self_attn.qkv_proj.quant_method,
                        UnquantizedLinearMethod,
                    )

        llm.apply_model(verify_quant_methods)
        outputs = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        print(outputs[0][1])
@pytest.mark.skipif(
    not is_quant_method_supported("fp8"),
    reason="FP8 is not supported on this GPU type.",
)
def test_online_quant_peak_mem(
    vllm_runner,
    caplog_mp_spawn,
    monkeypatch,
) -> None:
    """Peak-memory regression check for online fp8_per_tensor quantization.

    Thin wrapper around the shared implementation so other schemes can
    reuse the same memory-bound assertions.
    """
    _test_online_quant_peak_mem_impl(
        "fp8_per_tensor", vllm_runner, caplog_mp_spawn, monkeypatch
    )
@pytest.mark.skipif(
    not is_quant_method_supported("fp8"),
    reason="FP8 is not supported on this GPU type.",
)
def test_online_quant_load_format_dummy(
    vllm_runner,
    monkeypatch,
    caplog,
) -> None:
    """Smoke test: online quantization works with randomly initialized
    (load_format='dummy') weights, i.e. without a real checkpoint."""
    runner_kwargs = dict(
        quantization="fp8_per_tensor",
        enforce_eager=True,
        load_format="dummy",
    )
    with vllm_runner("ibm-granite/granite-3.0-1b-a400m-base", **runner_kwargs) as llm:
        outputs = llm.generate_greedy(["The future of AI is"], max_tokens=4)
        print(outputs[0][1])

View File

@@ -1,6 +1,10 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import regex as re
from vllm.model_executor.layers.quantization import get_quantization_config from vllm.model_executor.layers.quantization import get_quantization_config
from vllm.platforms import current_platform from vllm.platforms import current_platform
@@ -21,3 +25,74 @@ def is_quant_method_supported(quant_method: str) -> bool:
min_capability = get_quantization_config(quant_method).get_min_capability() min_capability = get_quantization_config(quant_method).get_min_capability()
return capability.to_int() >= min_capability return capability.to_int() >= min_capability
def _test_online_quant_peak_mem_impl(
    quantization_arg_value,
    vllm_runner,
    caplog_mp_spawn,
    monkeypatch,
) -> None:
    """Assert that online quantization keeps model-load and peak GPU memory
    below expected bounds, parsed from the engine's captured logs."""
    # Note: `allenai/OLMoE-1B-7B-0125-Instruct` was selected because:
    # 1. it covers both Linear and MoE paths
    # 2. it is already used by other tests in CI, so adding it here
    #    does not increase disk space for CI runners
    # I really wanted to use `ibm-granite/granite-3.0-1b-a400m-base`
    # which I think is the smallest MoE model in vLLM (2.5 GiB bf16,
    # 1.3 GiB fp8), but could not as adding one more model makes CI
    # run out of disk space.
    model_name = "allenai/OLMoE-1B-7B-0125-Instruct"

    # Force spawn to ensure caplog_mp_spawn works consistently
    # (it relies on VLLM_LOGGING_CONFIG_PATH which spawn reads but fork ignores)
    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")

    with (
        caplog_mp_spawn(logging.DEBUG) as log_holder,
        vllm_runner(
            model_name,
            quantization=quantization_arg_value,
            enforce_eager=True,
        ) as llm,
    ):
        outputs = llm.generate_greedy(["The future of AI is"], max_tokens=4)
        print(outputs[0][1])
    log_text = log_holder.text

    def _first_captured_gib(pattern: str) -> float | None:
        # Return the first GiB value captured by `pattern`, scanning the
        # captured log line by line; None if no line matches.
        for line in log_text.splitlines():
            found = re.search(pattern, line)
            if found:
                return float(found.group(1))
        return None

    model_memory_gib = _first_captured_gib(r"Model loading took ([\d.]+) GiB memory")
    peak_memory_gib = _first_captured_gib(
        r"Peak GPU memory after loading weights: ([\d.]+) GiB"
    )
    assert model_memory_gib is not None, "Could not find model loading memory log"
    assert peak_memory_gib is not None, "Could not find peak memory log"
    print(f"GPU memory used after loading weights: {model_memory_gib} GiB")
    print(f"Peak GPU memory usage while loading weights: {peak_memory_gib} GiB")

    # model specific, allenai/OLMoE-1B-7B-0125-Instruct fp8 online quant
    # uses 6.65 GiB for weight loading (bf16 checkpoint is ~12.89 GiB)
    expected_model_memory_gib = 6.7
    # for allenai/OLMoE-1B-7B-0125-Instruct the number we see today is 9.06
    # GiB, which is 1.36x above model_memory_gib. A slightly higher number is
    # expected as when we load and quantize weights in a streaming fashion we
    # need to have individual weights in bf16 + fp8 alive at the same time.
    expected_peak_memory_gib = expected_model_memory_gib * 1.4

    assert model_memory_gib < expected_model_memory_gib, (
        f"{model_memory_gib=} higher than {expected_model_memory_gib}"
    )
    assert peak_memory_gib < expected_peak_memory_gib, (
        f"{peak_memory_gib=} higher than {expected_peak_memory_gib}"
    )

View File

@@ -21,6 +21,7 @@ from vllm.config.multimodal import (
MultiModalConfig, MultiModalConfig,
) )
from vllm.config.pooler import PoolerConfig from vllm.config.pooler import PoolerConfig
from vllm.config.quantization import OnlineQuantizationConfigArgs
from vllm.config.scheduler import RunnerType from vllm.config.scheduler import RunnerType
from vllm.config.utils import config, getattr_iter from vllm.config.utils import config, getattr_iter
from vllm.logger import init_logger from vllm.logger import init_logger
@@ -199,6 +200,10 @@ class ModelConfig:
`quantization_config` attribute in the model config file. If that is `quantization_config` attribute in the model config file. If that is
`None`, we assume the model weights are not quantized and use `dtype` to `None`, we assume the model weights are not quantized and use `dtype` to
determine the data type of the weights.""" determine the data type of the weights."""
quantization_config: dict[str, Any] | OnlineQuantizationConfigArgs | None = None
"""Arguments for online quantization.
Auto-created when `quantization` equals to one of the string values of
the `OnlineQuantScheme` enum."""
allow_deprecated_quantization: bool = False allow_deprecated_quantization: bool = False
"""Whether to allow deprecated quantization methods.""" """Whether to allow deprecated quantization methods."""
enforce_eager: bool = False enforce_eager: bool = False

121
vllm/config/quantization.py Normal file
View File

@@ -0,0 +1,121 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
from typing import Any
from pydantic import Field, field_validator
from vllm.config.utils import config
class OnlineQuantScheme(Enum):
    """Supported online quantization schemes.

    The string values double as shorthand for the `quantization=` engine
    argument (e.g. `quantization='fp8_per_tensor'`), which is expanded into
    a full config by `resolve_online_quant_config`.
    """

    # fp8, weights and activations scaled per-tensor
    FP8_PER_TENSOR = "fp8_per_tensor"
    # fp8, activations scaled in blocks of 1x128 elements, weights scaled in
    # blocks of 128x128 elements (popularized by DeepSeek)
    FP8_PER_BLOCK = "fp8_per_block"
    # TODO(future PRs): add more online quant schemes here: mxfp8, etc
@config
class OnlineQuantizationConfigArgs:
    """Configuration for online quantization.

    Controls how ``OnlineQuantizationConfig`` is applied to a model.
    At least one of ``global_scheme``, ``linear_scheme_override``, or
    ``moe_scheme_override`` must be set.
    """

    global_scheme: OnlineQuantScheme | None = None
    """Quantization scheme applied to every supported layer."""

    linear_scheme_override: OnlineQuantScheme | None = None
    """Quantization scheme override for ``LinearBase`` layers."""

    moe_scheme_override: OnlineQuantScheme | None = None
    """Quantization scheme override for ``FusedMoE`` layers."""

    ignore: list[str] = Field(default_factory=list)
    """Layers to skip quantization for. Supports exact names and regex
    patterns with ``re:`` prefix (e.g. ``re:.*attn.*``), consistent with
    compressed_tensors layer skipping."""

    # mode="before" runs prior to pydantic's own validation, so plain string
    # scheme values (as passed on the CLI or in a dict config) are accepted.
    @field_validator(
        "global_scheme", "linear_scheme_override", "moe_scheme_override", mode="before"
    )
    @classmethod
    def _coerce_scheme(
        cls, v: str | OnlineQuantScheme | None
    ) -> OnlineQuantScheme | None:
        """Coerce a string scheme value into the enum; pass through
        enum instances and None unchanged."""
        if isinstance(v, str):
            return OnlineQuantScheme(v)
        return v
def resolve_online_quant_config(
    quantization: str | None,
    quantization_config: dict[str, Any] | OnlineQuantizationConfigArgs | None,
) -> OnlineQuantizationConfigArgs | None:
    """Resolve online quant scheme shorthand into a quantization config.

    If ``quantization`` is an online quant scheme (e.g. ``'fp8_per_tensor'``),
    ensures ``quantization_config`` has a matching ``global_scheme`` and casts
    it to :class:`OnlineQuantizationConfigArgs` if needed.
    """
    scheme_values = {member.value for member in OnlineQuantScheme}
    allowed_values = scheme_values | {"online"}

    # quantization_config only makes sense for online quantization.
    if quantization not in allowed_values:
        if quantization_config is not None:
            raise ValueError(
                f"quantization_config is only supported when quantization "
                f"is one of {sorted(allowed_values)}, "
                f"got quantization={quantization!r}"
            )
        return None

    if quantization in scheme_values:
        # Shorthand form: the scheme name itself was passed as `quantization`.
        scheme = OnlineQuantScheme(quantization)
        if quantization_config is None:
            quantization_config = {"global_scheme": scheme.value}
        elif isinstance(quantization_config, OnlineQuantizationConfigArgs):
            if quantization_config.global_scheme is None:
                quantization_config.global_scheme = scheme
            elif quantization_config.global_scheme != scheme:
                raise ValueError(
                    f"quantization={quantization!r} conflicts with "
                    f"quantization_config.global_scheme="
                    f"{quantization_config.global_scheme.value!r}. "
                    f"These must match when both are specified."
                )
        elif isinstance(quantization_config, dict):
            # Fill in the scheme if absent; otherwise it must agree.
            existing = quantization_config.setdefault("global_scheme", scheme.value)
            existing_scheme = (
                OnlineQuantScheme(existing) if isinstance(existing, str) else existing
            )
            if existing_scheme != scheme:
                raise ValueError(
                    f"quantization={quantization!r} conflicts "
                    f"with quantization_config"
                    f"['global_scheme']={existing!r}. "
                    f"These must match when both are specified."
                )

    # Cast dict to OnlineQuantizationConfigArgs
    if isinstance(quantization_config, dict):
        return OnlineQuantizationConfigArgs(**quantization_config)
    return quantization_config

View File

@@ -1713,6 +1713,7 @@ class VllmConfig:
f"dcp_comm_backend={self.parallel_config.dcp_comm_backend}, " # noqa f"dcp_comm_backend={self.parallel_config.dcp_comm_backend}, " # noqa
f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa
f"quantization={self.model_config.quantization}, " f"quantization={self.model_config.quantization}, "
f"quantization_config={self.model_config.quantization_config}, " # noqa
f"enforce_eager={self.model_config.enforce_eager}, " f"enforce_eager={self.model_config.enforce_eager}, "
f"enable_return_routed_experts={self.model_config.enable_return_routed_experts}, " # noqa f"enable_return_routed_experts={self.model_config.enable_return_routed_experts}, " # noqa
f"kv_cache_dtype={self.cache_config.cache_dtype}, " f"kv_cache_dtype={self.cache_config.cache_dtype}, "

View File

@@ -112,6 +112,7 @@ from vllm.v1.sample.logits_processor import LogitsProcessor
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config.quantization import OnlineQuantizationConfigArgs
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.model_loader import LoadFormats from vllm.model_executor.model_loader import LoadFormats
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
@@ -483,6 +484,7 @@ class EngineArgs:
hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides") hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides")
tokenizer_revision: str | None = ModelConfig.tokenizer_revision tokenizer_revision: str | None = ModelConfig.tokenizer_revision
quantization: QuantizationMethods | str | None = ModelConfig.quantization quantization: QuantizationMethods | str | None = ModelConfig.quantization
quantization_config: "dict[str, Any] | OnlineQuantizationConfigArgs | None" = None
allow_deprecated_quantization: bool = ModelConfig.allow_deprecated_quantization allow_deprecated_quantization: bool = ModelConfig.allow_deprecated_quantization
enforce_eager: bool = ModelConfig.enforce_eager enforce_eager: bool = ModelConfig.enforce_eager
disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
@@ -661,6 +663,12 @@ class EngineArgs:
if isinstance(self.ir_op_priority, dict): if isinstance(self.ir_op_priority, dict):
self.ir_op_priority = IrOpPriorityConfig(**self.ir_op_priority) self.ir_op_priority = IrOpPriorityConfig(**self.ir_op_priority)
from vllm.config.quantization import resolve_online_quant_config
self.quantization_config = resolve_online_quant_config(
self.quantization, self.quantization_config
)
# Setup plugins # Setup plugins
from vllm.plugins import load_general_plugins from vllm.plugins import load_general_plugins
@@ -1431,6 +1439,7 @@ class EngineArgs:
tokenizer_revision=self.tokenizer_revision, tokenizer_revision=self.tokenizer_revision,
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
quantization=self.quantization, quantization=self.quantization,
quantization_config=self.quantization_config,
allow_deprecated_quantization=self.allow_deprecated_quantization, allow_deprecated_quantization=self.allow_deprecated_quantization,
enforce_eager=self.enforce_eager, enforce_eager=self.enforce_eager,
enable_return_routed_experts=self.enable_return_routed_experts, enable_return_routed_experts=self.enable_return_routed_experts,

View File

@@ -34,6 +34,9 @@ from vllm.config.model import (
RunnerOption, RunnerOption,
TokenizerMode, TokenizerMode,
) )
from vllm.config.quantization import (
OnlineQuantizationConfigArgs,
)
from vllm.distributed.weight_transfer.base import ( from vllm.distributed.weight_transfer.base import (
WeightTransferInitRequest, WeightTransferInitRequest,
WeightTransferUpdateRequest, WeightTransferUpdateRequest,
@@ -247,6 +250,9 @@ class LLM:
attention_config: dict[str, Any] | AttentionConfig | None = None, attention_config: dict[str, Any] | AttentionConfig | None = None,
kv_cache_memory_bytes: int | None = None, kv_cache_memory_bytes: int | None = None,
compilation_config: int | dict[str, Any] | CompilationConfig | None = None, compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
quantization_config: dict[str, Any]
| OnlineQuantizationConfigArgs
| None = None,
logits_processors: list[str | type[LogitsProcessor]] | None = None, logits_processors: list[str | type[LogitsProcessor]] | None = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
@@ -367,6 +373,7 @@ class LLM:
profiler_config=profiler_config_instance, profiler_config=profiler_config_instance,
attention_config=attention_config_instance, attention_config=attention_config_instance,
compilation_config=compilation_config_instance, compilation_config=compilation_config_instance,
quantization_config=quantization_config,
logits_processors=logits_processors, logits_processors=logits_processors,
**kwargs, **kwargs,
) )

View File

@@ -33,6 +33,13 @@ QuantizationMethods = Literal[
"mxfp8", "mxfp8",
"petit_nvfp4", "petit_nvfp4",
"cpu_awq", "cpu_awq",
"online",
# Below are values of the OnlineQuantScheme enum, specified as strings to
# avoid circular import issues. This is here to provide a shortcut where
# the user can specify "LLM(..., quantization='fp8_per_tensor')" as
# shorthand for creating a more complicated online quant config object
"fp8_per_tensor",
"fp8_per_block",
] ]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
@@ -103,6 +110,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
raise ValueError(f"Invalid quantization method: {quantization}") raise ValueError(f"Invalid quantization method: {quantization}")
# lazy import to avoid triggering `torch.compile` too early # lazy import to avoid triggering `torch.compile` too early
from vllm.config.quantization import OnlineQuantScheme
from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
from .awq import AWQConfig from .awq import AWQConfig
@@ -129,6 +137,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .moe_wna16 import MoeWNA16Config from .moe_wna16 import MoeWNA16Config
from .mxfp4 import Mxfp4Config from .mxfp4 import Mxfp4Config
from .mxfp8 import Mxfp8Config from .mxfp8 import Mxfp8Config
from .online.base import OnlineQuantizationConfig
from .petit import PetitNvFp4Config from .petit import PetitNvFp4Config
from .torchao import TorchAOConfig from .torchao import TorchAOConfig
@@ -157,7 +166,20 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"mxfp8": Mxfp8Config, "mxfp8": Mxfp8Config,
"petit_nvfp4": PetitNvFp4Config, "petit_nvfp4": PetitNvFp4Config,
"cpu_awq": CPUAWQConfig, "cpu_awq": CPUAWQConfig,
"online": OnlineQuantizationConfig,
} }
# Below are values of the OnlineQuantScheme enum. This is here to provide
# a shortcut where the user can specify
# "LLM(..., quantization='fp8_per_tensor')" as shorthand for creating a
# more complicated online quant config object
for scheme in OnlineQuantScheme:
assert scheme.value not in method_to_config, (
f"Online quant scheme {scheme.value!r} conflicts with an "
f"existing quantization method"
)
method_to_config[scheme.value] = OnlineQuantizationConfig
# Update the `method_to_config` with customized quantization methods. # Update the `method_to_config` with customized quantization methods.
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)

View File

@@ -497,6 +497,8 @@ class Fp8LinearMethod(LinearMethodBase):
return self.fp8_linear.apply_weights(layer, x, bias) return self.fp8_linear.apply_weights(layer, x, bias)
# TODO(future PR): remove this class in favor of
# online/fp8.py::Fp8PerTensorOnlineLinearMethod
class Fp8OnlineLinearMethod(Fp8LinearMethod): class Fp8OnlineLinearMethod(Fp8LinearMethod):
"""Online version of Fp8LinearMethod which loads a full precision checkpoint """Online version of Fp8LinearMethod which loads a full precision checkpoint
and quantizes weights during loading.""" and quantizes weights during loading."""
@@ -919,6 +921,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
) )
# TODO(future PR): remove this class in favor of
# online/fp8.py::Fp8PerTensorOnlineMoEMethod
class Fp8OnlineMoEMethod(Fp8MoEMethod): class Fp8OnlineMoEMethod(Fp8MoEMethod):
"""MoE method for online FP8 quantization. """MoE method for online FP8 quantization.
Supports loading quantized FP16/BF16 model checkpoints with dynamic Supports loading quantized FP16/BF16 model checkpoints with dynamic

View File

@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

View File

@@ -0,0 +1,116 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
from vllm.config.quantization import (
OnlineQuantizationConfigArgs,
OnlineQuantScheme,
)
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
)
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.linear import (
LinearBase,
UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
should_ignore_layer,
)
from vllm.model_executor.layers.quantization.online.fp8 import (
Fp8PerBlockOnlineLinearMethod,
Fp8PerBlockOnlineMoEMethod,
Fp8PerTensorOnlineLinearMethod,
Fp8PerTensorOnlineMoEMethod,
)
class OnlineQuantizationConfig(QuantizationConfig):
    """Model-level config class for online quantization (quantize fp16/bf16 weights
    during model loading, without requiring a pre-quantized checkpoint)."""

    def __init__(
        self,
        args: OnlineQuantizationConfigArgs,
    ) -> None:
        """
        Args:
            args: resolved online quantization arguments; at least one of
                ``global_scheme``, ``linear_scheme_override``, or
                ``moe_scheme_override`` must be set.

        Raises:
            ValueError: if no scheme is set in ``args``.
        """
        super().__init__()
        if (
            args.global_scheme is None
            and args.linear_scheme_override is None
            and args.moe_scheme_override is None
        ):
            raise ValueError(
                "OnlineQuantizationConfig requires at least one of "
                "global_scheme, linear_scheme_override, or "
                "moe_scheme_override to be set."
            )
        self.args = args
        self.quant_scheme = args.global_scheme
        # Exact names or `re:`-prefixed regexes of layers to leave unquantized.
        self.ignored_layers: list[str] = args.ignore

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "online"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # Note: as more online quant schemes will be added, this
        # value will become the minimum across all supported schemes.
        return 75

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        # Online quantization is configured via engine args, not checkpoint
        # config files.
        return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "OnlineQuantizationConfig":
        raise NotImplementedError(
            "OnlineQuantizationConfig does not support loading from a "
            "checkpoint config. Use quantization_config or "
            "quantization='fp8_per_tensor'/'fp8_per_block' instead."
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Select the quantize method for ``layer``.

        Ignored layers and layer kinds with no configured scheme get an
        unquantized method; layer types not handled by online quantization
        get ``None`` (caller falls back to its default).
        """
        if isinstance(layer, LinearBase):
            if should_ignore_layer(
                prefix,
                ignore=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedLinearMethod()
            linear_scheme = self.args.linear_scheme_override or self.args.global_scheme
            if linear_scheme is None:
                # Only moe_scheme_override was set: leave linear layers
                # unquantized instead of silently defaulting to per-tensor.
                return UnquantizedLinearMethod()
            if linear_scheme == OnlineQuantScheme.FP8_PER_BLOCK:
                return Fp8PerBlockOnlineLinearMethod()
            if linear_scheme == OnlineQuantScheme.FP8_PER_TENSOR:
                return Fp8PerTensorOnlineLinearMethod()
            # Fail loudly on schemes added to the enum but not wired up here,
            # rather than silently falling back to per-tensor.
            raise ValueError(
                f"Unsupported online quant scheme for linear layers: "
                f"{linear_scheme}"
            )
        elif isinstance(layer, FusedMoE):
            if should_ignore_layer(
                prefix,
                ignore=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            moe_scheme = self.args.moe_scheme_override or self.args.global_scheme
            if moe_scheme is None:
                # Only linear_scheme_override was set: leave MoE layers
                # unquantized instead of silently defaulting to per-tensor.
                return UnquantizedFusedMoEMethod(layer.moe_config)
            if moe_scheme == OnlineQuantScheme.FP8_PER_BLOCK:
                return Fp8PerBlockOnlineMoEMethod(layer=layer)
            if moe_scheme == OnlineQuantScheme.FP8_PER_TENSOR:
                return Fp8PerTensorOnlineMoEMethod(layer=layer)
            raise ValueError(
                f"Unsupported online quant scheme for MoE layers: {moe_scheme}"
            )
        return None

View File

@@ -0,0 +1,632 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
import torch
from torch.nn import Module
if TYPE_CHECKING:
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.oracle.fp8 import Fp8MoeBackend
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.kernels.linear import init_fp8_linear_kernel
from vllm.model_executor.layers.fused_moe import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
select_fp8_moe_backend,
)
from vllm.model_executor.layers.linear import (
LinearMethodBase,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
maybe_post_process_fp8_weight_block,
process_fp8_weight_block_strategy,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
kFp8Dynamic128Sym,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8Static128BlockSym,
kFp8StaticTensorSym,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported,
cutlass_fp8_supported,
)
from vllm.model_executor.model_loader.reload.layerwise import (
initialize_online_processing,
)
from vllm.model_executor.parameter import ModelWeightParameter
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import is_deep_gemm_supported, per_block_cast_to_fp8
# ---------------------------------------------------------------------------
# Online FP8 Linear Methods
# ---------------------------------------------------------------------------
class _Fp8OnlineLinearBase(LinearMethodBase):
    """Shared base for online FP8 linear methods.

    Checkpoint weights (fp16/bf16) are registered on the meta device and
    materialized just-in-time during loading.
    """

    # Signals the loader that the weight starts on the meta device.
    uses_meta_device: bool = True

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register a meta-device full-precision weight on ``layer`` and set
        up the bookkeeping attributes used during/after loading."""
        total_output_size = sum(output_partition_sizes)

        # Attributes consumed by process_weights_after_loading / apply.
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = total_output_size
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        meta_weight = torch.empty(
            total_output_size,
            input_size_per_partition,
            device="meta",  # materialized and processed during loading
            dtype=params_dtype,
        )
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=meta_weight,
                input_dim=1,
                output_dim=0,
                weight_loader=extra_weight_attrs.get("weight_loader"),
            ),
        )
        initialize_online_processing(layer)
class Fp8PerTensorOnlineLinearMethod(_Fp8OnlineLinearBase):
    """Online tensorwise FP8 linear quantization.
    Loads fp16/bf16 weights and quantizes them per-tensor during loading."""

    def __init__(self):
        self.out_dtype = torch.get_default_dtype()
        # Use per-token quantization for better perf if dynamic and cutlass
        if cutlass_fp8_supported():
            activation_quant_key = kFp8DynamicTokenSym
        else:
            activation_quant_key = kFp8DynamicTensorSym
        self.fp8_linear = init_fp8_linear_kernel(
            activation_quant_key=activation_quant_key,
            weight_quant_key=kFp8StaticTensorSym,
            out_dtype=torch.get_default_dtype(),
            module_name=self.__class__.__name__,
        )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the freshly loaded full-precision weight to per-tensor
        FP8 and store it (transposed) back on the layer."""
        # Guard so a second call (e.g. weight reload) does not re-quantize
        # an already-quantized weight.
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return
        # Activations are quantized dynamically, so there is no static
        # input scale.
        layer.input_scale = None
        # scale=None -> the per-tensor scale is computed from the weight.
        qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
        # Update layer with new values.
        replace_parameter(layer, "weight", qweight.t().data)
        replace_parameter(layer, "weight_scale", weight_scale.data)
        # Prevent duplicate processing (e.g., during weight reload)
        layer._already_called_process_weights_after_loading = True

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``x @ W^T + bias`` using the FP8 kernel, or a bf16
        dequantized fallback in batch-invariant mode."""
        # if batch invariant mode is enabled, use BF16 dequant
        if envs.VLLM_BATCH_INVARIANT:
            # NOTE(review): `layer.weight` was stored transposed by
            # process_weights_after_loading, hence the `.t()` before the
            # linear call below.
            weight_fp8 = layer.weight.to(torch.bfloat16)
            weight_scale = layer.weight_scale.to(torch.bfloat16)
            if weight_scale.numel() == 1:
                # Per-tensor: simple scalar multiplication
                weight_bf16 = weight_fp8 * weight_scale
            else:
                # Multiple scales (fused modules like QKV)
                # NOTE(review): this branches on the stored (transposed)
                # weight's dim 0 — confirm the fused-module scale layout
                # matches this assumption.
                if (
                    weight_scale.dim() == 1
                    and weight_scale.shape[0] == weight_fp8.shape[0]
                ):
                    # Per-row scaling
                    weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1)
                else:
                    # Fallback
                    weight_bf16 = weight_fp8 * weight_scale
            return torch.nn.functional.linear(x, weight_bf16.t(), bias)
        return self.fp8_linear.apply_weights(layer, x, bias)
class Fp8PerBlockOnlineLinearMethod(_Fp8OnlineLinearBase):
    """Online blockwise FP8 linear quantization.
    Loads fp16/bf16 weights and quantizes them per-block during loading."""

    def __init__(self):
        self.out_dtype = torch.get_default_dtype()
        # 128x128 weight blocks (DeepSeek-style blockwise fp8).
        self.weight_block_size = [128, 128]
        # Backend capability probes, resolved once at construction time.
        self.use_deep_gemm = is_deep_gemm_supported()
        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(*self.weight_block_size),
            # Activations quantized in 1x128 groups to match the weight blocks.
            act_quant_group_shape=GroupShape(1, self.weight_block_size[0]),
            cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
            use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            use_deep_gemm=self.use_deep_gemm,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the meta-device weight (via the base class) and record
        the block size so downstream kernels know the weight layout."""
        super().create_weights(
            layer,
            input_size_per_partition,
            output_partition_sizes,
            input_size,
            output_size,
            params_dtype,
            **extra_weight_attrs,
        )
        layer.weight_block_size = self.weight_block_size

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded full-precision weight to 128x128-block FP8
        and apply backend-specific post-processing."""
        # Guard so a second call (e.g. weight reload) does not re-quantize
        # an already-quantized weight.
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return
        # Activations are quantized dynamically; no static input scale.
        layer.input_scale = None
        block_size = self.weight_block_size
        qweight, weight_scale_inv = per_block_cast_to_fp8(
            layer.weight, block_size=block_size, use_ue8m0=False
        )
        qweight, weight_scale_inv = process_fp8_weight_block_strategy(
            qweight, weight_scale_inv
        )
        replace_parameter(layer, "weight", qweight.data)
        replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)
        maybe_post_process_fp8_weight_block(layer)
        # Prevent duplicate processing (e.g., during weight reload)
        layer._already_called_process_weights_after_loading = True

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the blockwise W8A8 FP8 linear op on ``x``."""
        assert self.weight_block_size is not None
        # Note: batch invariance already handled in the function below
        return self.w8a8_block_fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale_inv,
            input_scale=layer.input_scale,
            bias=bias,
        )
# ---------------------------------------------------------------------------
# Online FP8 MoE Methods
# ---------------------------------------------------------------------------
class _Fp8OnlineMoEBase(FusedMoEMethodBase):
    """Shared base for online FP8 MoE methods. Loads fp16/bf16 checkpoint
    weights onto meta device and materializes them just-in-time."""

    # Tells the loader that parameters created below live on the meta device
    # and must be materialized while the checkpoint streams in.
    uses_meta_device: bool = True
    # Declared here for mypy; actual values are set in __init__.
    fp8_backend: "Fp8MoeBackend"
    experts_cls: "type[mk.FusedMoEExperts] | None"
    # "weight_scale_inv" when block-quantized, otherwise "weight_scale".
    weight_scale_name: str
    weight_block_size: list[int] | None
    moe: "FusedMoEConfig"
    is_monolithic: bool
    moe_quant_config: "FusedMoEQuantConfig | None"
    moe_kernel: "mk.FusedMoEKernel | None"

    def __init__(
        self,
        *,
        weight_block_size: list[int] | None,
        layer: torch.nn.Module,
    ):
        """Configure the FP8 quantization scheme and pick a MoE backend.

        Args:
            weight_block_size: ``[block_n, block_k]`` for blockwise
                quantization, or ``None`` for per-tensor quantization.
            layer: the MoE layer; its ``moe_config`` is forwarded to the
                base class.
        """
        super().__init__(layer.moe_config)
        self.weight_block_size = weight_block_size
        self.block_quant: bool = self.weight_block_size is not None
        self.weight_scale_name = (
            "weight_scale_inv" if self.block_quant else "weight_scale"
        )
        # Set weight key and activation key for kernel compatibility
        if self.block_quant:
            weight_key = kFp8Static128BlockSym
            activation_key = kFp8Dynamic128Sym
        else:
            weight_key = kFp8StaticTensorSym
            activation_key = kFp8DynamicTensorSym
        # Select Fp8 MoE backend
        self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
            config=self.moe,
            weight_key=weight_key,
            activation_key=activation_key,
            allow_vllm_cutlass=False,
        )

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register high-precision MoE weights (and optional biases).

        All tensors are allocated with ``device="meta"`` so no real memory is
        used until the loader materializes them; quantization happens later
        in ``process_weights_after_loading``.
        """
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None
        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                device="meta",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                device="meta",  # materialized and processed during loading
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)
        # BIASES (for models like GPT-OSS that have biased MoE)
        if self.moe.has_bias:
            w13_bias = torch.nn.Parameter(
                torch.zeros(
                    num_experts,
                    2 * intermediate_size_per_partition,
                    device="meta",  # materialized and processed during loading
                    dtype=layer.orig_dtype,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_bias", w13_bias)
            set_weight_attrs(w13_bias, extra_weight_attrs)
            w2_bias = torch.nn.Parameter(
                torch.zeros(
                    num_experts,
                    hidden_size,
                    device="meta",  # materialized and processed during loading
                    dtype=layer.orig_dtype,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w2_bias", w2_bias)
            set_weight_attrs(w2_bias, extra_weight_attrs)
        # Online quant uses dynamic activation quantization, so there are no
        # static input scales.
        layer.w13_input_scale = None
        layer.w2_input_scale = None
        initialize_online_processing(layer)

    def _setup_kernel(
        self,
        layer: "FusedMoE",
        w13: torch.Tensor,
        w2: torch.Tensor,
        w13_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        w13_input_scale: torch.Tensor | None,
        w2_input_scale: torch.Tensor | None,
    ) -> None:
        """Convert quantized weights to the backend's runtime layout, swap
        them into the layer, and build the fused MoE kernel."""
        from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
            convert_to_fp8_moe_kernel_format,
            make_fp8_moe_kernel,
        )

        # Shuffle weights to runtime format.
        w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
            fp8_backend=self.fp8_backend,
            layer=layer,
            w13=w13,
            w2=w2,
            w13_scale=w13_scale,
            w2_scale=w2_scale,
            w13_input_scale=w13_input_scale,
            w2_input_scale=w2_input_scale,
        )
        # Replace parameters with updated versions. Note that this helper
        # function ensures the replacement is compatible with RL weight reloads.
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
        replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            assert self.experts_cls is not None
            self.moe_kernel = make_fp8_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,
                experts_cls=self.experts_cls,
                routing_tables=layer._maybe_init_expert_routing_tables(),
                shared_experts=layer.shared_experts,
            )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> "mk.FusedMoEPrepareAndFinalizeModular | None":
        """Unsupported here: the kernel is built in ``_setup_kernel``.

        Raises:
            ValueError: always; the legacy prepare/finalize path must not
                be used with this method.
        """
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel "
            "initialization logic. This function should not be called."
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> "FusedMoEQuantConfig | None":
        """Build the quant config for the fused MoE kernel from the layer's
        current scales (and biases, when present)."""
        from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
            make_fp8_moe_quant_config,
        )

        w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        a1_scale = layer.w13_input_scale
        a2_scale = layer.w2_input_scale
        quant_config = make_fp8_moe_quant_config(
            fp8_backend=self.fp8_backend,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            block_shape=self.weight_block_size,
        )
        # Inject biases into the quant config if the model has them
        # (e.g. GPT-OSS biased MoE)
        if quant_config is not None and self.moe.has_bias:
            w13_bias = getattr(layer, "w13_bias", None)
            w2_bias = getattr(layer, "w2_bias", None)
            if w13_bias is not None:
                quant_config._w1.bias = w13_bias
            if w2_bias is not None:
                quant_config._w2.bias = w2_bias
        return quant_config

    @property
    def supports_eplb(self) -> bool:
        """Expert-parallel load balancing is supported by these methods."""
        return True

    def apply_monolithic(
        self,
        layer: "FusedMoE",
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run routing + experts in one fused call (monolithic kernels)."""
        assert self.is_monolithic
        assert self.moe_kernel is not None
        return self.moe_kernel.apply_monolithic(
            x,
            layer.w13_weight,
            layer.w2_weight,
            router_logits,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            num_expert_group=layer.num_expert_group,
            topk_group=layer.topk_group,
            e_score_correction_bias=layer.e_score_correction_bias,
            routed_scaling_factor=layer.routed_scaling_factor,
        )

    def apply(
        self,
        layer: "FusedMoE",
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the experts with routing decided by the caller (topk ids and
        weights precomputed)."""
        assert not self.is_monolithic
        assert self.moe_kernel is not None
        return self.moe_kernel.apply(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            shared_experts_input=shared_experts_input,
        )
class Fp8PerTensorOnlineMoEMethod(_Fp8OnlineMoEBase):
    """Online tensorwise FP8 MoE quantization.
    Loads fp16/bf16 weights and quantizes them per-tensor during loading."""

    def __init__(
        self,
        *,
        layer: torch.nn.Module,
    ):
        # Per-tensor scheme: no weight block size.
        super().__init__(
            weight_block_size=None,
            layer=layer,
        )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize every expert's w13/w2 weight to FP8 with a single
        per-tensor scale per expert, then set up the fused MoE kernel."""
        # TODO(@ksayers): inplace fp8 quant kernel, initialize scales with ones
        # Guard against running twice (e.g. during RL weight reloads).
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return
        # If checkpoint is fp16, quantize in place.
        fp8_dtype = current_platform.fp8_dtype()
        w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
        w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
        # NOTE(review): scales are sized with layer.num_experts but the loop
        # below fills only layer.local_num_experts entries (the per-block
        # sibling sizes by local_num_experts) — presumably equal here;
        # confirm for expert-parallel configurations.
        w13_scale = torch.ones(
            layer.num_experts, device=w13.device, dtype=torch.float32
        )
        w2_scale = torch.ones(layer.num_experts, device=w2.device, dtype=torch.float32)
        # Dynamic activation quantization: no static input scales.
        layer.w13_input_scale = None
        layer.w2_input_scale = None
        for expert in range(layer.local_num_experts):
            w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant(
                layer.w13_weight[expert, :, :]
            )
            w2[expert, :, :], w2_scale[expert] = ops.scaled_fp8_quant(
                layer.w2_weight[expert, :, :]
            )
        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer,
            w13,
            w2,
            w13_scale,
            w2_scale,
            w13_input_scale=layer.w13_input_scale,
            w2_input_scale=layer.w2_input_scale,
        )
        # Prevent duplicate processing (e.g., during weight reload)
        layer._already_called_process_weights_after_loading = True
class Fp8PerBlockOnlineMoEMethod(_Fp8OnlineMoEBase):
    """Online blockwise (128x128) FP8 MoE quantization.

    fp16/bf16 expert weights are cast to FP8 with per-block scales at load
    time; no pre-quantized checkpoint is required.
    """

    def __init__(
        self,
        *,
        layer: torch.nn.Module,
    ):
        super().__init__(
            weight_block_size=[128, 128],
            layer=layer,
        )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize each local expert's w13/w2 weight per-block and hand the
        results to the fused MoE kernel setup."""
        # Skip if already processed (e.g. during RL weight reloads).
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        fp8_dtype = current_platform.fp8_dtype()
        block_size = self.weight_block_size
        assert block_size is not None
        block_n, block_k = block_size

        n_local = layer.local_num_experts
        _, w13_rows, w13_cols = layer.w13_weight.shape
        _, w2_rows, w2_cols = layer.w2_weight.shape

        # Quantized weight buffers plus block-shaped scale tensors. Scales
        # are created here rather than in create_weights because online
        # quant has no use for them before this point. (-(-a // b) is
        # ceil-division, covering partial edge blocks.)
        w13_q = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
        w2_q = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
        w13_s = torch.ones(
            n_local,
            -(-w13_rows // block_n),
            -(-w13_cols // block_k),
            dtype=torch.float32,
            device=w13_q.device,
        )
        w2_s = torch.ones(
            n_local,
            -(-w2_rows // block_n),
            -(-w2_cols // block_k),
            dtype=torch.float32,
            device=w2_q.device,
        )

        for e in range(n_local):
            w13_q[e], w13_s[e] = per_block_cast_to_fp8(
                layer.w13_weight[e],
                block_size=block_size,
                use_ue8m0=False,
            )
            w2_q[e], w2_s[e] = per_block_cast_to_fp8(
                layer.w2_weight[e],
                block_size=block_size,
                use_ue8m0=False,
            )

        layer.weight_block_size = block_size
        # Shuffle weights to runtime format and build the kernel.
        self._setup_kernel(
            layer,
            w13_q,
            w2_q,
            w13_s,
            w2_s,
            layer.w13_input_scale,
            layer.w2_input_scale,
        )
        # Prevent duplicate processing on subsequent weight reloads.
        layer._already_called_process_weights_after_loading = True

View File

@@ -296,6 +296,13 @@ def get_quant_config(
) )
if hf_quant_config is not None: if hf_quant_config is not None:
if model_config.quantization_config is not None:
raise ValueError(
"Setting `quantization_config` for online "
"quantization when the model checkpoint already "
"has a `quantization_config` is not supported"
)
# For modelopt_mixed, config.json's quantization_config may or may # For modelopt_mixed, config.json's quantization_config may or may
# not contain the per-layer quantized_layers map. Newer checkpoints # not contain the per-layer quantized_layers map. Newer checkpoints
# embed it directly; older ones keep it only in hf_quant_config.json. # embed it directly; older ones keep it only in hf_quant_config.json.
@@ -319,6 +326,12 @@ def get_quant_config(
quantization_config_file = hf_overrides.get("quantization_config_file", None) quantization_config_file = hf_overrides.get("quantization_config_file", None)
if quantization_config_file is not None: if quantization_config_file is not None:
if hasattr(quant_cls, "from_config_file"): if hasattr(quant_cls, "from_config_file"):
if model_config.quantization_config is not None:
raise ValueError(
"Setting `quantization_config` for online "
"quantization when the model checkpoint already "
"has a `quantization_config` is not supported"
)
return quant_cls.from_config_file(quantization_config_file) return quant_cls.from_config_file(quantization_config_file)
else: else:
raise NotImplementedError( raise NotImplementedError(
@@ -329,6 +342,12 @@ def get_quant_config(
quantization_config_json = hf_overrides.get("quantization_config_dict_json", None) quantization_config_json = hf_overrides.get("quantization_config_dict_json", None)
if quantization_config_json is not None: if quantization_config_json is not None:
if hasattr(quant_cls, "from_config_dict_json"): if hasattr(quant_cls, "from_config_dict_json"):
if model_config.quantization_config is not None:
raise ValueError(
"Setting `quantization_config` for online "
"quantization when the model checkpoint already "
"has a `quantization_config` is not supported"
)
return quant_cls.from_config_dict_json(quantization_config_json) return quant_cls.from_config_dict_json(quantization_config_json)
else: else:
raise NotImplementedError( raise NotImplementedError(
@@ -337,6 +356,19 @@ def get_quant_config(
f"{quant_cls}" f"{quant_cls}"
) )
# Online quantization doesn't read from checkpoint configs — it quantizes
# fp16/bf16 weights on the fly during loading.
if model_config.quantization_config is not None:
from vllm.config.quantization import OnlineQuantizationConfigArgs
from vllm.model_executor.layers.quantization.online.base import (
OnlineQuantizationConfig,
)
assert isinstance(
model_config.quantization_config, OnlineQuantizationConfigArgs
)
return OnlineQuantizationConfig(args=model_config.quantization_config)
# Inflight BNB quantization # Inflight BNB quantization
if model_config.quantization == "bitsandbytes": if model_config.quantization == "bitsandbytes":
return quant_cls.from_config({}) return quant_cls.from_config({})