From 7b1a7423bea1705bd51d838e34bef99e8a01cbbd Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Fri, 3 Apr 2026 11:58:39 -0400 Subject: [PATCH] [Frontend] new online quantization frontend (#38138) Signed-off-by: Vasiliy Kuznetsov --- tests/quantization/test_online.py | 179 +++++ tests/quantization/utils.py | 75 +++ vllm/config/model.py | 5 + vllm/config/quantization.py | 121 ++++ vllm/config/vllm.py | 1 + vllm/engine/arg_utils.py | 9 + vllm/entrypoints/llm.py | 7 + .../layers/quantization/__init__.py | 22 + .../model_executor/layers/quantization/fp8.py | 4 + .../layers/quantization/online/__init__.py | 2 + .../layers/quantization/online/base.py | 116 ++++ .../layers/quantization/online/fp8.py | 632 ++++++++++++++++++ .../model_loader/weight_utils.py | 32 + 13 files changed, 1205 insertions(+) create mode 100644 tests/quantization/test_online.py create mode 100644 vllm/config/quantization.py create mode 100644 vllm/model_executor/layers/quantization/online/__init__.py create mode 100644 vllm/model_executor/layers/quantization/online/base.py create mode 100644 vllm/model_executor/layers/quantization/online/fp8.py diff --git a/tests/quantization/test_online.py b/tests/quantization/test_online.py new file mode 100644 index 000000000..89f9676e4 --- /dev/null +++ b/tests/quantization/test_online.py @@ -0,0 +1,179 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests online quantization.""" + +import pytest +import torch + +from tests.quantization.utils import ( + _test_online_quant_peak_mem_impl, + is_quant_method_supported, +) +from vllm.model_executor.layers.linear import UnquantizedLinearMethod +from vllm.model_executor.layers.quantization.online.fp8 import ( + Fp8PerBlockOnlineLinearMethod, + Fp8PerBlockOnlineMoEMethod, + Fp8PerTensorOnlineLinearMethod, + Fp8PerTensorOnlineMoEMethod, +) +from vllm.platforms import current_platform + + +@pytest.mark.skipif( + not is_quant_method_supported("fp8"), 
+ reason="FP8 is not supported on this GPU type.", +) +@pytest.mark.parametrize( + "quant_scheme,online_quant_args,expected_linear_cls,expected_moe_cls", + [ + # simple case - quantization='fp8_per_tensor' + ( + "fp8_per_tensor", + None, + Fp8PerTensorOnlineLinearMethod, + Fp8PerTensorOnlineMoEMethod, + ), + # simple case - quantization='fp8_per_block' + ( + "fp8_per_block", + None, + Fp8PerBlockOnlineLinearMethod, + Fp8PerBlockOnlineMoEMethod, + ), + # quantization='online with linear_scheme_override and + # moe_scheme_override + ( + "online", + { + "linear_scheme_override": "fp8_per_block", + "moe_scheme_override": "fp8_per_tensor", + }, + Fp8PerBlockOnlineLinearMethod, + Fp8PerTensorOnlineMoEMethod, + ), + # ignore with direct layer name + ( + "fp8_per_tensor", + # qkv_proj is fused from q_proj/k_proj/v_proj, so currently the + # ignore regex must match the unfused shard names + # TODO(future PR): also make 're:.*qkv_proj.*' work + {"ignore": ["model.layers.1.self_attn.o_proj", "re:.*[qkv]_proj"]}, + Fp8PerTensorOnlineLinearMethod, + Fp8PerTensorOnlineMoEMethod, + ), + ], +) +@pytest.mark.parametrize( + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False] +) +def test_online_quantization( + vllm_runner, + quant_scheme: str, + online_quant_args: dict | None, + expected_linear_cls, + expected_moe_cls, + use_rocm_aiter: bool, + monkeypatch, +) -> None: + """ + Tests that online quantization frontend configuration works - + selecting quant schemes, overriding quant schemes by type, ignoring + layers. + + Does not test performance, peak memory usage, etc. + """ + + if use_rocm_aiter: + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + + # `LLM.apply_model` requires pickling a function. 
+ monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + + # a tiny model with both dense and MoE layers + model_name = "ibm-granite/granite-3.0-1b-a400m-base" + + runner_kwargs = dict( + quantization=quant_scheme, + enforce_eager=True, + ) + if online_quant_args is not None: + runner_kwargs["quantization_config"] = online_quant_args + + with vllm_runner( + model_name, + **runner_kwargs, + ) as llm: + + def check_model(model): + # checks further down in the test case are hardcoded for this + # model + assert model_name == "ibm-granite/granite-3.0-1b-a400m-base" + + o_proj = model.model.layers[0].self_attn.o_proj + moe = model.model.layers[0].block_sparse_moe.experts + + # o_proj and moe in layer 0 are always quantized (never ignored) + # because of how we craft the test case inputs + assert isinstance(o_proj.quant_method, expected_linear_cls) + if moe is not None: + assert isinstance(moe.quant_method, expected_moe_cls) + + if current_platform.is_cuda(): + assert o_proj.weight.dtype == torch.float8_e4m3fn + elif current_platform.is_rocm(): + assert o_proj.weight.dtype == current_platform.fp8_dtype() + else: + pytest.skip("Only runs on CUDA and ROCm.") + + # Verify ignored layers are unquantized. 
+ if isinstance(online_quant_args, dict) and "ignore" in online_quant_args: + # only .*1.self_attn_o_proj is skipped + for layer_idx in range(len(model.model.layers)): + o_proj = model.model.layers[layer_idx].self_attn.o_proj + if layer_idx == 1: + assert isinstance(o_proj.quant_method, UnquantizedLinearMethod) + else: + assert isinstance(o_proj.quant_method, expected_linear_cls) + + # every .*self_attn.qkv_proj is skipped + for layer_idx in range(len(model.model.layers)): + qkv_proj = model.model.layers[layer_idx].self_attn.qkv_proj + assert isinstance(qkv_proj.quant_method, UnquantizedLinearMethod) + + llm.apply_model(check_model) + + outputs = llm.generate_greedy(["Hello my name is"], max_tokens=4) + print(outputs[0][1]) + + +@pytest.mark.skipif( + not is_quant_method_supported("fp8"), + reason="FP8 is not supported on this GPU type.", +) +def test_online_quant_peak_mem( + vllm_runner, + caplog_mp_spawn, + monkeypatch, +) -> None: + _test_online_quant_peak_mem_impl( + "fp8_per_tensor", vllm_runner, caplog_mp_spawn, monkeypatch + ) + + +@pytest.mark.skipif( + not is_quant_method_supported("fp8"), + reason="FP8 is not supported on this GPU type.", +) +def test_online_quant_load_format_dummy( + vllm_runner, + monkeypatch, + caplog, +) -> None: + with vllm_runner( + "ibm-granite/granite-3.0-1b-a400m-base", + quantization="fp8_per_tensor", + enforce_eager=True, + load_format="dummy", + ) as llm: + outputs = llm.generate_greedy(["The future of AI is"], max_tokens=4) + print(outputs[0][1]) diff --git a/tests/quantization/utils.py b/tests/quantization/utils.py index cf3da37b0..e5eea5772 100644 --- a/tests/quantization/utils.py +++ b/tests/quantization/utils.py @@ -1,6 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import logging + +import regex as re + from vllm.model_executor.layers.quantization import get_quantization_config from vllm.platforms import current_platform @@ -21,3 +25,74 @@ def 
is_quant_method_supported(quant_method: str) -> bool: min_capability = get_quantization_config(quant_method).get_min_capability() return capability.to_int() >= min_capability + + +def _test_online_quant_peak_mem_impl( + quantization_arg_value, + vllm_runner, + caplog_mp_spawn, + monkeypatch, +) -> None: + # Note: `allenai/OLMoE-1B-7B-0125-Instruct` was selected because: + # 1. it covers both Linear and MoE paths + # 2. it is already used by other tests in CI, so adding it here + # does not increase disk space for CI runners + # I really wanted to use `ibm-granite/granite-3.0-1b-a400m-base` + # which I think is the smallest MoE model in vLLM (2.5 GiB bf16, + # 1.3 GiB fp8), but could not as adding one more model makes CI + # run out of disk space. + model_name = "allenai/OLMoE-1B-7B-0125-Instruct" + + # Force spawn to ensure caplog_mp_spawn works consistently + # (it relies on VLLM_LOGGING_CONFIG_PATH which spawn reads but fork ignores) + monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + + with ( + caplog_mp_spawn(logging.DEBUG) as log_holder, + vllm_runner( + model_name, + quantization=quantization_arg_value, + enforce_eager=True, + ) as llm, + ): + outputs = llm.generate_greedy(["The future of AI is"], max_tokens=4) + print(outputs[0][1]) + + log_text = log_holder.text + + # Parse memory usage from captured logs + model_memory_gib = None + peak_memory_gib = None + for line in log_text.splitlines(): + if model_memory_gib is None: + match = re.search(r"Model loading took ([\d.]+) GiB memory", line) + if match: + model_memory_gib = float(match.group(1)) + if peak_memory_gib is None: + match = re.search( + r"Peak GPU memory after loading weights: ([\d.]+) GiB", line + ) + if match: + peak_memory_gib = float(match.group(1)) + + assert model_memory_gib is not None, "Could not find model loading memory log" + assert peak_memory_gib is not None, "Could not find peak memory log" + print(f"GPU memory used after loading weights: {model_memory_gib} GiB") + 
print(f"Peak GPU memory usage while loading weights: {peak_memory_gib} GiB") + + # model specific, allenai/OLMoE-1B-7B-0125-Instruct fp8 online quant + # uses 6.65 GiB for weight loading (bf16 checkpoint is ~12.89 GiB) + expected_model_memory_gib = 6.7 + + # for allenai/OLMoE-1B-7B-0125-Instruct the number we see today is 9.06 + # GiB, which is 1.36x above model_memory_gib. A slightly higher number is + # expected as when we load and quantize weights in a streaming fashion we + # need to have individual weights in bf16 + fp8 alive at the same time. + expected_peak_memory_gib = expected_model_memory_gib * 1.4 + + assert model_memory_gib < expected_model_memory_gib, ( + f"{model_memory_gib=} higher than {expected_model_memory_gib}" + ) + assert peak_memory_gib < expected_peak_memory_gib, ( + f"{peak_memory_gib=} higher than {expected_peak_memory_gib}" + ) diff --git a/vllm/config/model.py b/vllm/config/model.py index 7bb3655f2..cea2e56ae 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -21,6 +21,7 @@ from vllm.config.multimodal import ( MultiModalConfig, ) from vllm.config.pooler import PoolerConfig +from vllm.config.quantization import OnlineQuantizationConfigArgs from vllm.config.scheduler import RunnerType from vllm.config.utils import config, getattr_iter from vllm.logger import init_logger @@ -199,6 +200,10 @@ class ModelConfig: `quantization_config` attribute in the model config file. If that is `None`, we assume the model weights are not quantized and use `dtype` to determine the data type of the weights.""" + quantization_config: dict[str, Any] | OnlineQuantizationConfigArgs | None = None + """Arguments for online quantization. 
+ Auto-created when `quantization` equals to one of the string values of + the `OnlineQuantScheme` enum.""" allow_deprecated_quantization: bool = False """Whether to allow deprecated quantization methods.""" enforce_eager: bool = False diff --git a/vllm/config/quantization.py b/vllm/config/quantization.py new file mode 100644 index 000000000..1b7022380 --- /dev/null +++ b/vllm/config/quantization.py @@ -0,0 +1,121 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from enum import Enum +from typing import Any + +from pydantic import Field, field_validator + +from vllm.config.utils import config + + +class OnlineQuantScheme(Enum): + """Supported online quantization schemes.""" + + # fp8, weights and activations scaled per-tensor + FP8_PER_TENSOR = "fp8_per_tensor" + + # fp8, activations scaled in blocks of 1x128 elements, weights scaled in + # blocks of 128x128 elements (popularized by DeepSeek) + FP8_PER_BLOCK = "fp8_per_block" + + # TODO(future PRs): add more online quant schemes here: mxfp8, etc + + +@config +class OnlineQuantizationConfigArgs: + """Configuration for online quantization. + + Controls how ``OnlineQuantizationConfig`` is applied to a model. + At least one of ``global_scheme``, ``linear_scheme_override``, or + ``moe_scheme_override`` must be set. + """ + + global_scheme: OnlineQuantScheme | None = None + """Quantization scheme applied to every supported layer.""" + + linear_scheme_override: OnlineQuantScheme | None = None + """Quantization scheme override for ``LinearBase`` layers.""" + + moe_scheme_override: OnlineQuantScheme | None = None + """Quantization scheme override for ``FusedMoE`` layers.""" + + ignore: list[str] = Field(default_factory=list) + """Layers to skip quantization for. Supports exact names and regex + patterns with ``re:`` prefix (e.g. 
``re:.*attn.*``), consistent with + compressed_tensors layer skipping.""" + + @field_validator( + "global_scheme", "linear_scheme_override", "moe_scheme_override", mode="before" + ) + @classmethod + def _coerce_scheme( + cls, v: str | OnlineQuantScheme | None + ) -> OnlineQuantScheme | None: + if isinstance(v, str): + return OnlineQuantScheme(v) + return v + + +def resolve_online_quant_config( + quantization: str | None, + quantization_config: dict[str, Any] | OnlineQuantizationConfigArgs | None, +) -> OnlineQuantizationConfigArgs | None: + """Resolve online quant scheme shorthand into a quantization config. + + If ``quantization`` is an online quant scheme (e.g. ``'fp8_per_tensor'``), + ensures ``quantization_config`` has a matching ``global_scheme`` and casts + it to :class:`OnlineQuantizationConfigArgs` if needed. + """ + online_quant_values = {s.value for s in OnlineQuantScheme} + valid_quantization_values = online_quant_values | {"online"} + if quantization not in valid_quantization_values: + if quantization_config is not None: + raise ValueError( + f"quantization_config is only supported when quantization " + f"is one of {sorted(valid_quantization_values)}, " + f"got quantization={quantization!r}" + ) + return None + + if quantization in online_quant_values: + scheme = OnlineQuantScheme(quantization) + + if quantization_config is None: + quantization_config = { + "global_scheme": scheme.value, + } + elif isinstance(quantization_config, OnlineQuantizationConfigArgs): + if quantization_config.global_scheme is None: + quantization_config.global_scheme = scheme + elif quantization_config.global_scheme != scheme: + raise ValueError( + f"quantization={quantization!r} conflicts with " + f"quantization_config.global_scheme=" + f"{quantization_config.global_scheme.value!r}. " + f"These must match when both are specified." 
+ ) + elif isinstance(quantization_config, dict): + existing = quantization_config.get("global_scheme") + if existing is None: + quantization_config["global_scheme"] = scheme.value + else: + # Coerce to enum for comparison + existing_scheme = ( + OnlineQuantScheme(existing) + if isinstance(existing, str) + else existing + ) + if existing_scheme != scheme: + raise ValueError( + f"quantization={quantization!r} conflicts " + f"with quantization_config" + f"['global_scheme']={existing!r}. " + f"These must match when both are specified." + ) + + # Cast dict to OnlineQuantizationConfigArgs + if isinstance(quantization_config, dict): + quantization_config = OnlineQuantizationConfigArgs(**quantization_config) + + return quantization_config diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index fad3e0ed2..6551526d1 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -1713,6 +1713,7 @@ class VllmConfig: f"dcp_comm_backend={self.parallel_config.dcp_comm_backend}, " # noqa f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa f"quantization={self.model_config.quantization}, " + f"quantization_config={self.model_config.quantization_config}, " # noqa f"enforce_eager={self.model_config.enforce_eager}, " f"enable_return_routed_experts={self.model_config.enable_return_routed_experts}, " # noqa f"kv_cache_dtype={self.cache_config.cache_dtype}, " diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d498135ce..90e10c1cb 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -112,6 +112,7 @@ from vllm.v1.sample.logits_processor import LogitsProcessor from vllm.version import __version__ as VLLM_VERSION if TYPE_CHECKING: + from vllm.config.quantization import OnlineQuantizationConfigArgs from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.model_loader import LoadFormats from vllm.usage.usage_lib import UsageContext @@ -483,6 +484,7 @@ class EngineArgs: 
hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides") tokenizer_revision: str | None = ModelConfig.tokenizer_revision quantization: QuantizationMethods | str | None = ModelConfig.quantization + quantization_config: "dict[str, Any] | OnlineQuantizationConfigArgs | None" = None allow_deprecated_quantization: bool = ModelConfig.allow_deprecated_quantization enforce_eager: bool = ModelConfig.enforce_eager disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce @@ -661,6 +663,12 @@ class EngineArgs: if isinstance(self.ir_op_priority, dict): self.ir_op_priority = IrOpPriorityConfig(**self.ir_op_priority) + from vllm.config.quantization import resolve_online_quant_config + + self.quantization_config = resolve_online_quant_config( + self.quantization, self.quantization_config + ) + # Setup plugins from vllm.plugins import load_general_plugins @@ -1431,6 +1439,7 @@ class EngineArgs: tokenizer_revision=self.tokenizer_revision, max_model_len=self.max_model_len, quantization=self.quantization, + quantization_config=self.quantization_config, allow_deprecated_quantization=self.allow_deprecated_quantization, enforce_eager=self.enforce_eager, enable_return_routed_experts=self.enable_return_routed_experts, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b9eea8745..1be2cdd5c 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -34,6 +34,9 @@ from vllm.config.model import ( RunnerOption, TokenizerMode, ) +from vllm.config.quantization import ( + OnlineQuantizationConfigArgs, +) from vllm.distributed.weight_transfer.base import ( WeightTransferInitRequest, WeightTransferUpdateRequest, @@ -247,6 +250,9 @@ class LLM: attention_config: dict[str, Any] | AttentionConfig | None = None, kv_cache_memory_bytes: int | None = None, compilation_config: int | dict[str, Any] | CompilationConfig | None = None, + quantization_config: dict[str, Any] + | OnlineQuantizationConfigArgs + | None = None, logits_processors: list[str | 
type[LogitsProcessor]] | None = None, **kwargs: Any, ) -> None: @@ -367,6 +373,7 @@ class LLM: profiler_config=profiler_config_instance, attention_config=attention_config_instance, compilation_config=compilation_config_instance, + quantization_config=quantization_config, logits_processors=logits_processors, **kwargs, ) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 9aceb3be0..d897e0d99 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -33,6 +33,13 @@ QuantizationMethods = Literal[ "mxfp8", "petit_nvfp4", "cpu_awq", + "online", + # Below are values of the OnlineQuantScheme enum, specified as strings to + # avoid circular import issues. This is here to provide a shortcut where + # the user can specify "LLM(..., quantization='fp8_per_tensor')" as + # shorthand for creating a more complicated online quant config object + "fp8_per_tensor", + "fp8_per_block", ] QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) @@ -103,6 +110,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: raise ValueError(f"Invalid quantization method: {quantization}") # lazy import to avoid triggering `torch.compile` too early + from vllm.config.quantization import OnlineQuantScheme from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig from .awq import AWQConfig @@ -129,6 +137,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: from .moe_wna16 import MoeWNA16Config from .mxfp4 import Mxfp4Config from .mxfp8 import Mxfp8Config + from .online.base import OnlineQuantizationConfig from .petit import PetitNvFp4Config from .torchao import TorchAOConfig @@ -157,7 +166,20 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "mxfp8": Mxfp8Config, "petit_nvfp4": PetitNvFp4Config, "cpu_awq": CPUAWQConfig, + "online": 
OnlineQuantizationConfig, } + + # Below are values of the OnlineQuantScheme enum. This is here to provide + # a shortcut where the user can specify + # "LLM(..., quantization='fp8_per_tensor')" as shorthand for creating a + # more complicated online quant config object + for scheme in OnlineQuantScheme: + assert scheme.value not in method_to_config, ( + f"Online quant scheme {scheme.value!r} conflicts with an " + f"existing quantization method" + ) + method_to_config[scheme.value] = OnlineQuantizationConfig + # Update the `method_to_config` with customized quantization methods. method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 259a7d1f6..2816f7656 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -497,6 +497,8 @@ class Fp8LinearMethod(LinearMethodBase): return self.fp8_linear.apply_weights(layer, x, bias) +# TODO(future PR): remove this class in favor of +# online/fp8.py::Fp8PerTensorOnlineLinearMethod class Fp8OnlineLinearMethod(Fp8LinearMethod): """Online version of Fp8LinearMethod which loads a full precision checkpoint and quantizes weights during loading.""" @@ -919,6 +921,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): ) +# TODO(future PR): remove this class in favor of +# online/fp8.py::Fp8PerTensorOnlineMoEMethod class Fp8OnlineMoEMethod(Fp8MoEMethod): """MoE method for online FP8 quantization. 
Supports loading quantized FP16/BF16 model checkpoints with dynamic diff --git a/vllm/model_executor/layers/quantization/online/__init__.py b/vllm/model_executor/layers/quantization/online/__init__.py new file mode 100644 index 000000000..208f01a7c --- /dev/null +++ b/vllm/model_executor/layers/quantization/online/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/vllm/model_executor/layers/quantization/online/base.py b/vllm/model_executor/layers/quantization/online/base.py new file mode 100644 index 000000000..87997f8ef --- /dev/null +++ b/vllm/model_executor/layers/quantization/online/base.py @@ -0,0 +1,116 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any + +import torch + +from vllm.config.quantization import ( + OnlineQuantizationConfigArgs, + OnlineQuantScheme, +) +from vllm.model_executor.layers.fused_moe import ( + FusedMoE, +) +from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import ( + UnquantizedFusedMoEMethod, +) +from vllm.model_executor.layers.linear import ( + LinearBase, + UnquantizedLinearMethod, +) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + should_ignore_layer, +) +from vllm.model_executor.layers.quantization.online.fp8 import ( + Fp8PerBlockOnlineLinearMethod, + Fp8PerBlockOnlineMoEMethod, + Fp8PerTensorOnlineLinearMethod, + Fp8PerTensorOnlineMoEMethod, +) + + +class OnlineQuantizationConfig(QuantizationConfig): + """Model-level config class for online quantization (quantize fp16/bf16 weights + during model loading, without requiring a pre-quantized checkpoint).""" + + def __init__( + self, + args: 
OnlineQuantizationConfigArgs, + ) -> None: + super().__init__() + if ( + args.global_scheme is None + and args.linear_scheme_override is None + and args.moe_scheme_override is None + ): + raise ValueError( + "OnlineQuantizationConfig requires at least one of " + "global_scheme, linear_scheme_override, or " + "moe_scheme_override to be set." + ) + self.args = args + self.quant_scheme = args.global_scheme + self.ignored_layers: list[str] = args.ignore + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "online" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + # Note: as more online quant schemes will be added, this + # value will become the minimum across all supported schemes. + return 75 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return [] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "OnlineQuantizationConfig": + raise NotImplementedError( + "OnlineQuantizationConfig does not support loading from a " + "checkpoint config. Use quantization_config or " + "quantization='fp8_per_tensor'/'fp8_per_block' instead." 
+ ) + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> "QuantizeMethodBase | None": + if isinstance(layer, LinearBase): + if should_ignore_layer( + prefix, + ignore=self.ignored_layers, + fused_mapping=self.packed_modules_mapping, + ): + return UnquantizedLinearMethod() + + linear_scheme = self.args.linear_scheme_override or self.args.global_scheme + if linear_scheme == OnlineQuantScheme.FP8_PER_BLOCK: + return Fp8PerBlockOnlineLinearMethod() + else: + return Fp8PerTensorOnlineLinearMethod() + elif isinstance(layer, FusedMoE): + if should_ignore_layer( + prefix, + ignore=self.ignored_layers, + fused_mapping=self.packed_modules_mapping, + ): + return UnquantizedFusedMoEMethod(layer.moe_config) + + moe_scheme = self.args.moe_scheme_override or self.args.global_scheme + if moe_scheme == OnlineQuantScheme.FP8_PER_BLOCK: + return Fp8PerBlockOnlineMoEMethod(layer=layer) + else: + return Fp8PerTensorOnlineMoEMethod(layer=layer) + return None diff --git a/vllm/model_executor/layers/quantization/online/fp8.py b/vllm/model_executor/layers/quantization/online/fp8.py new file mode 100644 index 000000000..941ae25b1 --- /dev/null +++ b/vllm/model_executor/layers/quantization/online/fp8.py @@ -0,0 +1,632 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import TYPE_CHECKING + +import torch +from torch.nn import Module + +if TYPE_CHECKING: + import vllm.model_executor.layers.fused_moe.modular_kernel as mk + from vllm.model_executor.layers.fused_moe import FusedMoE + from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, + ) + from vllm.model_executor.layers.fused_moe.oracle.fp8 import Fp8MoeBackend + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops +from vllm.model_executor.kernels.linear import init_fp8_linear_kernel +from vllm.model_executor.layers.fused_moe import ( + 
FusedMoEMethodBase, +) +from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( + select_fp8_moe_backend, +) +from vllm.model_executor.layers.linear import ( + LinearMethodBase, +) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + W8A8BlockFp8LinearOp, + maybe_post_process_fp8_weight_block, + process_fp8_weight_block_strategy, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + kFp8Dynamic128Sym, + kFp8DynamicTensorSym, + kFp8DynamicTokenSym, + kFp8Static128BlockSym, + kFp8StaticTensorSym, +) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + cutlass_block_fp8_supported, + cutlass_fp8_supported, +) +from vllm.model_executor.model_loader.reload.layerwise import ( + initialize_online_processing, +) +from vllm.model_executor.parameter import ModelWeightParameter +from vllm.model_executor.utils import replace_parameter, set_weight_attrs +from vllm.platforms import current_platform +from vllm.utils.deep_gemm import is_deep_gemm_supported, per_block_cast_to_fp8 + +# --------------------------------------------------------------------------- +# Online FP8 Linear Methods +# --------------------------------------------------------------------------- + + +class _Fp8OnlineLinearBase(LinearMethodBase): + """Shared base for online FP8 linear methods. 
Loads fp16/bf16 checkpoint + weights onto meta device and materializes them just-in-time.""" + + uses_meta_device: bool = True + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + layer.weight_block_size = None + + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + device="meta", # materialized and processed during loading + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + initialize_online_processing(layer) + + +class Fp8PerTensorOnlineLinearMethod(_Fp8OnlineLinearBase): + """Online tensorwise FP8 linear quantization. 
+ Loads fp16/bf16 weights and quantizes them per-tensor during loading.""" + + def __init__(self): + self.out_dtype = torch.get_default_dtype() + + # Use per-token quantization for better perf if dynamic and cutlass + if cutlass_fp8_supported(): + activation_quant_key = kFp8DynamicTokenSym + else: + activation_quant_key = kFp8DynamicTensorSym + + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=activation_quant_key, + weight_quant_key=kFp8StaticTensorSym, + out_dtype=torch.get_default_dtype(), + module_name=self.__class__.__name__, + ) + + def process_weights_after_loading(self, layer: Module) -> None: + if getattr(layer, "_already_called_process_weights_after_loading", False): + return + + layer.input_scale = None + qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None) + + # Update layer with new values. + replace_parameter(layer, "weight", qweight.t().data) + replace_parameter(layer, "weight_scale", weight_scale.data) + + # Prevent duplicate processing (e.g., during weight reload) + layer._already_called_process_weights_after_loading = True + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + # if batch invariant mode is enabled, use BF16 dequant + if envs.VLLM_BATCH_INVARIANT: + weight_fp8 = layer.weight.to(torch.bfloat16) + weight_scale = layer.weight_scale.to(torch.bfloat16) + if weight_scale.numel() == 1: + # Per-tensor: simple scalar multiplication + weight_bf16 = weight_fp8 * weight_scale + else: + # Multiple scales (fused modules like QKV) + if ( + weight_scale.dim() == 1 + and weight_scale.shape[0] == weight_fp8.shape[0] + ): + # Per-row scaling + weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1) + else: + # Fallback + weight_bf16 = weight_fp8 * weight_scale + return torch.nn.functional.linear(x, weight_bf16.t(), bias) + + return self.fp8_linear.apply_weights(layer, x, bias) + + +class Fp8PerBlockOnlineLinearMethod(_Fp8OnlineLinearBase): + 
"""Online blockwise FP8 linear quantization. + Loads fp16/bf16 weights and quantizes them per-block during loading.""" + + def __init__(self): + self.out_dtype = torch.get_default_dtype() + self.weight_block_size = [128, 128] + + self.use_deep_gemm = is_deep_gemm_supported() + self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled() + self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() + + self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( + weight_group_shape=GroupShape(*self.weight_block_size), + act_quant_group_shape=GroupShape(1, self.weight_block_size[0]), + cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, + use_aiter_and_is_supported=self.use_aiter_and_is_supported, + use_deep_gemm=self.use_deep_gemm, + ) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + super().create_weights( + layer, + input_size_per_partition, + output_partition_sizes, + input_size, + output_size, + params_dtype, + **extra_weight_attrs, + ) + layer.weight_block_size = self.weight_block_size + + def process_weights_after_loading(self, layer: Module) -> None: + if getattr(layer, "_already_called_process_weights_after_loading", False): + return + + layer.input_scale = None + block_size = self.weight_block_size + + qweight, weight_scale_inv = per_block_cast_to_fp8( + layer.weight, block_size=block_size, use_ue8m0=False + ) + + qweight, weight_scale_inv = process_fp8_weight_block_strategy( + qweight, weight_scale_inv + ) + + replace_parameter(layer, "weight", qweight.data) + replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data) + + maybe_post_process_fp8_weight_block(layer) + + # Prevent duplicate processing (e.g., during weight reload) + layer._already_called_process_weights_after_loading = True + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + 
bias: torch.Tensor | None = None, + ) -> torch.Tensor: + assert self.weight_block_size is not None + + # Note: batch invariance already handled in the function below + return self.w8a8_block_fp8_linear.apply( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale_inv, + input_scale=layer.input_scale, + bias=bias, + ) + + +# --------------------------------------------------------------------------- +# Online FP8 MoE Methods +# --------------------------------------------------------------------------- + + +class _Fp8OnlineMoEBase(FusedMoEMethodBase): + """Shared base for online FP8 MoE methods. Loads fp16/bf16 checkpoint + weights onto meta device and materializes them just-in-time.""" + + uses_meta_device: bool = True + + # Declared here for mypy; actual values are set in __init__. + fp8_backend: "Fp8MoeBackend" + experts_cls: "type[mk.FusedMoEExperts] | None" + weight_scale_name: str + weight_block_size: list[int] | None + moe: "FusedMoEConfig" + is_monolithic: bool + moe_quant_config: "FusedMoEQuantConfig | None" + moe_kernel: "mk.FusedMoEKernel | None" + + def __init__( + self, + *, + weight_block_size: list[int] | None, + layer: torch.nn.Module, + ): + super().__init__(layer.moe_config) + self.weight_block_size = weight_block_size + self.block_quant: bool = self.weight_block_size is not None + self.weight_scale_name = ( + "weight_scale_inv" if self.block_quant else "weight_scale" + ) + + # Set weight key and activation key for kernel compatibility + if self.block_quant: + weight_key = kFp8Static128BlockSym + activation_key = kFp8Dynamic128Sym + else: + weight_key = kFp8StaticTensorSym + activation_key = kFp8DynamicTensorSym + + # Select Fp8 MoE backend + self.fp8_backend, self.experts_cls = select_fp8_moe_backend( + config=self.moe, + weight_key=weight_key, + activation_key=activation_key, + allow_vllm_cutlass=False, + ) + + def create_weights( + self, + layer: Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, 
+ params_dtype: torch.dtype, + **extra_weight_attrs, + ): + layer.num_experts = num_experts + layer.orig_dtype = params_dtype + layer.weight_block_size = None + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + device="meta", + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + device="meta", # materialized and processed during loading + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # BIASES (for models like GPT-OSS that have biased MoE) + if self.moe.has_bias: + w13_bias = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition, + device="meta", # materialized and processed during loading + dtype=layer.orig_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_bias", w13_bias) + set_weight_attrs(w13_bias, extra_weight_attrs) + + w2_bias = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + device="meta", # materialized and processed during loading + dtype=layer.orig_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_bias", w2_bias) + set_weight_attrs(w2_bias, extra_weight_attrs) + + layer.w13_input_scale = None + layer.w2_input_scale = None + + initialize_online_processing(layer) + + def _setup_kernel( + self, + layer: "FusedMoE", + w13: torch.Tensor, + w2: torch.Tensor, + w13_scale: torch.Tensor, + w2_scale: torch.Tensor, + w13_input_scale: torch.Tensor | None, + w2_input_scale: torch.Tensor | None, + ) -> None: + from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( + convert_to_fp8_moe_kernel_format, + make_fp8_moe_kernel, + ) + + # Shuffle weights to runtime 
format. + w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format( + fp8_backend=self.fp8_backend, + layer=layer, + w13=w13, + w2=w2, + w13_scale=w13_scale, + w2_scale=w2_scale, + w13_input_scale=w13_input_scale, + w2_input_scale=w2_input_scale, + ) + + # Replace parameters with updated versions. Note that this helper + # function ensures the replacement is compatible with RL weight reloads. + replace_parameter(layer, "w13_weight", w13) + replace_parameter(layer, "w2_weight", w2) + replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale) + replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale) + + self.moe_quant_config = self.get_fused_moe_quant_config(layer) + if self.moe_quant_config: + assert self.experts_cls is not None + self.moe_kernel = make_fp8_moe_kernel( + moe_quant_config=self.moe_quant_config, + moe_config=self.moe, + fp8_backend=self.fp8_backend, + experts_cls=self.experts_cls, + routing_tables=layer._maybe_init_expert_routing_tables(), + shared_experts=layer.shared_experts, + ) + + def maybe_make_prepare_finalize( + self, + routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, + ) -> "mk.FusedMoEPrepareAndFinalizeModular | None": + raise ValueError( + f"{self.__class__.__name__} uses the new modular kernel " + "initialization logic. This function should not be called." 
+ ) + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> "FusedMoEQuantConfig": + from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( + make_fp8_moe_quant_config, + ) + + w1_scale = getattr(layer, f"w13_{self.weight_scale_name}") + w2_scale = getattr(layer, f"w2_{self.weight_scale_name}") + a1_scale = layer.w13_input_scale + a2_scale = layer.w2_input_scale + + quant_config = make_fp8_moe_quant_config( + fp8_backend=self.fp8_backend, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=self.weight_block_size, + ) + + # Inject biases into the quant config if the model has them + # (e.g. GPT-OSS biased MoE) + if quant_config is not None and self.moe.has_bias: + w13_bias = getattr(layer, "w13_bias", None) + w2_bias = getattr(layer, "w2_bias", None) + if w13_bias is not None: + quant_config._w1.bias = w13_bias + if w2_bias is not None: + quant_config._w2.bias = w2_bias + + return quant_config + + @property + def supports_eplb(self) -> bool: + return True + + def apply_monolithic( + self, + layer: "FusedMoE", + x: torch.Tensor, + router_logits: torch.Tensor, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + assert self.is_monolithic + assert self.moe_kernel is not None + return self.moe_kernel.apply_monolithic( + x, + layer.w13_weight, + layer.w2_weight, + router_logits, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + e_score_correction_bias=layer.e_score_correction_bias, + routed_scaling_factor=layer.routed_scaling_factor, + ) + + def apply( + self, + layer: "FusedMoE", + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + shared_experts_input: torch.Tensor | None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + assert not self.is_monolithic + assert 
self.moe_kernel is not None
+        return self.moe_kernel.apply(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights,
+            topk_ids,
+            activation=layer.activation,
+            global_num_experts=layer.global_num_experts,
+            expert_map=layer.expert_map,
+            apply_router_weight_on_input=layer.apply_router_weight_on_input,
+            shared_experts_input=shared_experts_input,
+        )
+
+
+class Fp8PerTensorOnlineMoEMethod(_Fp8OnlineMoEBase):
+    """Online tensorwise FP8 MoE quantization.
+    Loads fp16/bf16 weights and quantizes them per-tensor during loading."""
+
+    def __init__(
+        self,
+        *,
+        layer: torch.nn.Module,
+    ):
+        super().__init__(
+            weight_block_size=None,
+            layer=layer,
+        )
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        # TODO(@ksayers): inplace fp8 quant kernel, initialize scales with ones
+        if getattr(layer, "_already_called_process_weights_after_loading", False):
+            return
+
+        # Quantize fp16/bf16 weights to fp8 expert-by-expert into newly
+        fp8_dtype = current_platform.fp8_dtype()
+        w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
+        w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
+        w13_scale = torch.ones(
+            layer.num_experts, device=w13.device, dtype=torch.float32
+        )
+        w2_scale = torch.ones(layer.num_experts, device=w2.device, dtype=torch.float32)
+        layer.w13_input_scale = None
+        layer.w2_input_scale = None
+
+        for expert in range(layer.local_num_experts):
+            w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant(
+                layer.w13_weight[expert, :, :]
+            )
+            w2[expert, :, :], w2_scale[expert] = ops.scaled_fp8_quant(
+                layer.w2_weight[expert, :, :]
+            )
+
+        # Shuffle weights to runtime format and setup kernel. 
+ self._setup_kernel( + layer, + w13, + w2, + w13_scale, + w2_scale, + w13_input_scale=layer.w13_input_scale, + w2_input_scale=layer.w2_input_scale, + ) + + # Prevent duplicate processing (e.g., during weight reload) + layer._already_called_process_weights_after_loading = True + + +class Fp8PerBlockOnlineMoEMethod(_Fp8OnlineMoEBase): + """Online blockwise FP8 MoE quantization. + Loads fp16/bf16 weights and quantizes them per-block during loading.""" + + def __init__( + self, + *, + layer: torch.nn.Module, + ): + super().__init__( + weight_block_size=[128, 128], + layer=layer, + ) + + def process_weights_after_loading(self, layer: Module) -> None: + if getattr(layer, "_already_called_process_weights_after_loading", False): + return + + fp8_dtype = current_platform.fp8_dtype() + w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype) + w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype) + + block_size = self.weight_block_size + assert block_size is not None + block_n, block_k = block_size + + # Create block-shaped scales (computed here rather than in + # create_weights because online quant doesn't need them until now). + num_experts = layer.local_num_experts + _, w13_out, w13_in = layer.w13_weight.shape + _, w2_out, w2_in = layer.w2_weight.shape + + w13_scale = torch.ones( + num_experts, + (w13_out + block_n - 1) // block_n, + (w13_in + block_k - 1) // block_k, + dtype=torch.float32, + device=w13.device, + ) + w2_scale = torch.ones( + num_experts, + (w2_out + block_n - 1) // block_n, + (w2_in + block_k - 1) // block_k, + dtype=torch.float32, + device=w2.device, + ) + + for expert in range(num_experts): + w13[expert], w13_scale[expert] = per_block_cast_to_fp8( + layer.w13_weight[expert], + block_size=block_size, + use_ue8m0=False, + ) + w2[expert], w2_scale[expert] = per_block_cast_to_fp8( + layer.w2_weight[expert], + block_size=block_size, + use_ue8m0=False, + ) + + layer.weight_block_size = block_size + + # Shuffle weights to runtime format and setup kernel. 
+ self._setup_kernel( + layer, + w13, + w2, + w13_scale, + w2_scale, + layer.w13_input_scale, + layer.w2_input_scale, + ) + + # Prevent duplicate processing (e.g., during weight reload) + layer._already_called_process_weights_after_loading = True diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 2d17e611d..e3689dcd8 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -296,6 +296,13 @@ def get_quant_config( ) if hf_quant_config is not None: + if model_config.quantization_config is not None: + raise ValueError( + "Setting `quantization_config` for online " + "quantization when the model checkpoint already " + "has a `quantization_config` is not supported" + ) + # For modelopt_mixed, config.json's quantization_config may or may # not contain the per-layer quantized_layers map. Newer checkpoints # embed it directly; older ones keep it only in hf_quant_config.json. 
@@ -319,6 +326,12 @@ def get_quant_config( quantization_config_file = hf_overrides.get("quantization_config_file", None) if quantization_config_file is not None: if hasattr(quant_cls, "from_config_file"): + if model_config.quantization_config is not None: + raise ValueError( + "Setting `quantization_config` for online " + "quantization when the model checkpoint already " + "has a `quantization_config` is not supported" + ) return quant_cls.from_config_file(quantization_config_file) else: raise NotImplementedError( @@ -329,6 +342,12 @@ def get_quant_config( quantization_config_json = hf_overrides.get("quantization_config_dict_json", None) if quantization_config_json is not None: if hasattr(quant_cls, "from_config_dict_json"): + if model_config.quantization_config is not None: + raise ValueError( + "Setting `quantization_config` for online " + "quantization when the model checkpoint already " + "has a `quantization_config` is not supported" + ) return quant_cls.from_config_dict_json(quantization_config_json) else: raise NotImplementedError( @@ -337,6 +356,19 @@ def get_quant_config( f"{quant_cls}" ) + # Online quantization doesn't read from checkpoint configs — it quantizes + # fp16/bf16 weights on the fly during loading. + if model_config.quantization_config is not None: + from vllm.config.quantization import OnlineQuantizationConfigArgs + from vllm.model_executor.layers.quantization.online.base import ( + OnlineQuantizationConfig, + ) + + assert isinstance( + model_config.quantization_config, OnlineQuantizationConfigArgs + ) + return OnlineQuantizationConfig(args=model_config.quantization_config) + # Inflight BNB quantization if model_config.quantization == "bitsandbytes": return quant_cls.from_config({})