[Frontend] new online quantization frontend (#38138)
Signed-off-by: Vasiliy Kuznetsov <vasiliy@meta.com>
This commit is contained in:
committed by
GitHub
parent
97f92c6b47
commit
7b1a7423be
179
tests/quantization/test_online.py
Normal file
179
tests/quantization/test_online.py
Normal file
@@ -0,0 +1,179 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests online quantization."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.quantization.utils import (
|
||||
_test_online_quant_peak_mem_impl,
|
||||
is_quant_method_supported,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.quantization.online.fp8 import (
|
||||
Fp8PerBlockOnlineLinearMethod,
|
||||
Fp8PerBlockOnlineMoEMethod,
|
||||
Fp8PerTensorOnlineLinearMethod,
|
||||
Fp8PerTensorOnlineMoEMethod,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("fp8"),
|
||||
reason="FP8 is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"quant_scheme,online_quant_args,expected_linear_cls,expected_moe_cls",
|
||||
[
|
||||
# simple case - quantization='fp8_per_tensor'
|
||||
(
|
||||
"fp8_per_tensor",
|
||||
None,
|
||||
Fp8PerTensorOnlineLinearMethod,
|
||||
Fp8PerTensorOnlineMoEMethod,
|
||||
),
|
||||
# simple case - quantization='fp8_per_block'
|
||||
(
|
||||
"fp8_per_block",
|
||||
None,
|
||||
Fp8PerBlockOnlineLinearMethod,
|
||||
Fp8PerBlockOnlineMoEMethod,
|
||||
),
|
||||
# quantization='online with linear_scheme_override and
|
||||
# moe_scheme_override
|
||||
(
|
||||
"online",
|
||||
{
|
||||
"linear_scheme_override": "fp8_per_block",
|
||||
"moe_scheme_override": "fp8_per_tensor",
|
||||
},
|
||||
Fp8PerBlockOnlineLinearMethod,
|
||||
Fp8PerTensorOnlineMoEMethod,
|
||||
),
|
||||
# ignore with direct layer name
|
||||
(
|
||||
"fp8_per_tensor",
|
||||
# qkv_proj is fused from q_proj/k_proj/v_proj, so currently the
|
||||
# ignore regex must match the unfused shard names
|
||||
# TODO(future PR): also make 're:.*qkv_proj.*' work
|
||||
{"ignore": ["model.layers.1.self_attn.o_proj", "re:.*[qkv]_proj"]},
|
||||
Fp8PerTensorOnlineLinearMethod,
|
||||
Fp8PerTensorOnlineMoEMethod,
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
|
||||
)
|
||||
def test_online_quantization(
|
||||
vllm_runner,
|
||||
quant_scheme: str,
|
||||
online_quant_args: dict | None,
|
||||
expected_linear_cls,
|
||||
expected_moe_cls,
|
||||
use_rocm_aiter: bool,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
"""
|
||||
Tests that online quantization frontend configuration works -
|
||||
selecting quant schemes, overriding quant schemes by type, ignoring
|
||||
layers.
|
||||
|
||||
Does not test performance, peak memory usage, etc.
|
||||
"""
|
||||
|
||||
if use_rocm_aiter:
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
|
||||
# `LLM.apply_model` requires pickling a function.
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
# a tiny model with both dense and MoE layers
|
||||
model_name = "ibm-granite/granite-3.0-1b-a400m-base"
|
||||
|
||||
runner_kwargs = dict(
|
||||
quantization=quant_scheme,
|
||||
enforce_eager=True,
|
||||
)
|
||||
if online_quant_args is not None:
|
||||
runner_kwargs["quantization_config"] = online_quant_args
|
||||
|
||||
with vllm_runner(
|
||||
model_name,
|
||||
**runner_kwargs,
|
||||
) as llm:
|
||||
|
||||
def check_model(model):
|
||||
# checks further down in the test case are hardcoded for this
|
||||
# model
|
||||
assert model_name == "ibm-granite/granite-3.0-1b-a400m-base"
|
||||
|
||||
o_proj = model.model.layers[0].self_attn.o_proj
|
||||
moe = model.model.layers[0].block_sparse_moe.experts
|
||||
|
||||
# o_proj and moe in layer 0 are always quantized (never ignored)
|
||||
# because of how we craft the test case inputs
|
||||
assert isinstance(o_proj.quant_method, expected_linear_cls)
|
||||
if moe is not None:
|
||||
assert isinstance(moe.quant_method, expected_moe_cls)
|
||||
|
||||
if current_platform.is_cuda():
|
||||
assert o_proj.weight.dtype == torch.float8_e4m3fn
|
||||
elif current_platform.is_rocm():
|
||||
assert o_proj.weight.dtype == current_platform.fp8_dtype()
|
||||
else:
|
||||
pytest.skip("Only runs on CUDA and ROCm.")
|
||||
|
||||
# Verify ignored layers are unquantized.
|
||||
if isinstance(online_quant_args, dict) and "ignore" in online_quant_args:
|
||||
# only .*1.self_attn_o_proj is skipped
|
||||
for layer_idx in range(len(model.model.layers)):
|
||||
o_proj = model.model.layers[layer_idx].self_attn.o_proj
|
||||
if layer_idx == 1:
|
||||
assert isinstance(o_proj.quant_method, UnquantizedLinearMethod)
|
||||
else:
|
||||
assert isinstance(o_proj.quant_method, expected_linear_cls)
|
||||
|
||||
# every .*self_attn.qkv_proj is skipped
|
||||
for layer_idx in range(len(model.model.layers)):
|
||||
qkv_proj = model.model.layers[layer_idx].self_attn.qkv_proj
|
||||
assert isinstance(qkv_proj.quant_method, UnquantizedLinearMethod)
|
||||
|
||||
llm.apply_model(check_model)
|
||||
|
||||
outputs = llm.generate_greedy(["Hello my name is"], max_tokens=4)
|
||||
print(outputs[0][1])
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("fp8"),
|
||||
reason="FP8 is not supported on this GPU type.",
|
||||
)
|
||||
def test_online_quant_peak_mem(
|
||||
vllm_runner,
|
||||
caplog_mp_spawn,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
_test_online_quant_peak_mem_impl(
|
||||
"fp8_per_tensor", vllm_runner, caplog_mp_spawn, monkeypatch
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("fp8"),
|
||||
reason="FP8 is not supported on this GPU type.",
|
||||
)
|
||||
def test_online_quant_load_format_dummy(
|
||||
vllm_runner,
|
||||
monkeypatch,
|
||||
caplog,
|
||||
) -> None:
|
||||
with vllm_runner(
|
||||
"ibm-granite/granite-3.0-1b-a400m-base",
|
||||
quantization="fp8_per_tensor",
|
||||
enforce_eager=True,
|
||||
load_format="dummy",
|
||||
) as llm:
|
||||
outputs = llm.generate_greedy(["The future of AI is"], max_tokens=4)
|
||||
print(outputs[0][1])
|
||||
@@ -1,6 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import logging
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.model_executor.layers.quantization import get_quantization_config
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -21,3 +25,74 @@ def is_quant_method_supported(quant_method: str) -> bool:
|
||||
min_capability = get_quantization_config(quant_method).get_min_capability()
|
||||
|
||||
return capability.to_int() >= min_capability
|
||||
|
||||
|
||||
def _test_online_quant_peak_mem_impl(
|
||||
quantization_arg_value,
|
||||
vllm_runner,
|
||||
caplog_mp_spawn,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
# Note: `allenai/OLMoE-1B-7B-0125-Instruct` was selected because:
|
||||
# 1. it covers both Linear and MoE paths
|
||||
# 2. it is already used by other tests in CI, so adding it here
|
||||
# does not increase disk space for CI runners
|
||||
# I really wanted to use `ibm-granite/granite-3.0-1b-a400m-base`
|
||||
# which I think is the smallest MoE model in vLLM (2.5 GiB bf16,
|
||||
# 1.3 GiB fp8), but could not as adding one more model makes CI
|
||||
# run out of disk space.
|
||||
model_name = "allenai/OLMoE-1B-7B-0125-Instruct"
|
||||
|
||||
# Force spawn to ensure caplog_mp_spawn works consistently
|
||||
# (it relies on VLLM_LOGGING_CONFIG_PATH which spawn reads but fork ignores)
|
||||
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
|
||||
|
||||
with (
|
||||
caplog_mp_spawn(logging.DEBUG) as log_holder,
|
||||
vllm_runner(
|
||||
model_name,
|
||||
quantization=quantization_arg_value,
|
||||
enforce_eager=True,
|
||||
) as llm,
|
||||
):
|
||||
outputs = llm.generate_greedy(["The future of AI is"], max_tokens=4)
|
||||
print(outputs[0][1])
|
||||
|
||||
log_text = log_holder.text
|
||||
|
||||
# Parse memory usage from captured logs
|
||||
model_memory_gib = None
|
||||
peak_memory_gib = None
|
||||
for line in log_text.splitlines():
|
||||
if model_memory_gib is None:
|
||||
match = re.search(r"Model loading took ([\d.]+) GiB memory", line)
|
||||
if match:
|
||||
model_memory_gib = float(match.group(1))
|
||||
if peak_memory_gib is None:
|
||||
match = re.search(
|
||||
r"Peak GPU memory after loading weights: ([\d.]+) GiB", line
|
||||
)
|
||||
if match:
|
||||
peak_memory_gib = float(match.group(1))
|
||||
|
||||
assert model_memory_gib is not None, "Could not find model loading memory log"
|
||||
assert peak_memory_gib is not None, "Could not find peak memory log"
|
||||
print(f"GPU memory used after loading weights: {model_memory_gib} GiB")
|
||||
print(f"Peak GPU memory usage while loading weights: {peak_memory_gib} GiB")
|
||||
|
||||
# model specific, allenai/OLMoE-1B-7B-0125-Instruct fp8 online quant
|
||||
# uses 6.65 GiB for weight loading (bf16 checkpoint is ~12.89 GiB)
|
||||
expected_model_memory_gib = 6.7
|
||||
|
||||
# for allenai/OLMoE-1B-7B-0125-Instruct the number we see today is 9.06
|
||||
# GiB, which is 1.36x above model_memory_gib. A slightly higher number is
|
||||
# expected as when we load and quantize weights in a streaming fashion we
|
||||
# need to have individual weights in bf16 + fp8 alive at the same time.
|
||||
expected_peak_memory_gib = expected_model_memory_gib * 1.4
|
||||
|
||||
assert model_memory_gib < expected_model_memory_gib, (
|
||||
f"{model_memory_gib=} higher than {expected_model_memory_gib}"
|
||||
)
|
||||
assert peak_memory_gib < expected_peak_memory_gib, (
|
||||
f"{peak_memory_gib=} higher than {expected_peak_memory_gib}"
|
||||
)
|
||||
|
||||
@@ -21,6 +21,7 @@ from vllm.config.multimodal import (
|
||||
MultiModalConfig,
|
||||
)
|
||||
from vllm.config.pooler import PoolerConfig
|
||||
from vllm.config.quantization import OnlineQuantizationConfigArgs
|
||||
from vllm.config.scheduler import RunnerType
|
||||
from vllm.config.utils import config, getattr_iter
|
||||
from vllm.logger import init_logger
|
||||
@@ -199,6 +200,10 @@ class ModelConfig:
|
||||
`quantization_config` attribute in the model config file. If that is
|
||||
`None`, we assume the model weights are not quantized and use `dtype` to
|
||||
determine the data type of the weights."""
|
||||
quantization_config: dict[str, Any] | OnlineQuantizationConfigArgs | None = None
|
||||
"""Arguments for online quantization.
|
||||
Auto-created when `quantization` equals to one of the string values of
|
||||
the `OnlineQuantScheme` enum."""
|
||||
allow_deprecated_quantization: bool = False
|
||||
"""Whether to allow deprecated quantization methods."""
|
||||
enforce_eager: bool = False
|
||||
|
||||
121
vllm/config/quantization.py
Normal file
121
vllm/config/quantization.py
Normal file
@@ -0,0 +1,121 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from vllm.config.utils import config
|
||||
|
||||
|
||||
class OnlineQuantScheme(Enum):
|
||||
"""Supported online quantization schemes."""
|
||||
|
||||
# fp8, weights and activations scaled per-tensor
|
||||
FP8_PER_TENSOR = "fp8_per_tensor"
|
||||
|
||||
# fp8, activations scaled in blocks of 1x128 elements, weights scaled in
|
||||
# blocks of 128x128 elements (popularized by DeepSeek)
|
||||
FP8_PER_BLOCK = "fp8_per_block"
|
||||
|
||||
# TODO(future PRs): add more online quant schemes here: mxfp8, etc
|
||||
|
||||
|
||||
@config
|
||||
class OnlineQuantizationConfigArgs:
|
||||
"""Configuration for online quantization.
|
||||
|
||||
Controls how ``OnlineQuantizationConfig`` is applied to a model.
|
||||
At least one of ``global_scheme``, ``linear_scheme_override``, or
|
||||
``moe_scheme_override`` must be set.
|
||||
"""
|
||||
|
||||
global_scheme: OnlineQuantScheme | None = None
|
||||
"""Quantization scheme applied to every supported layer."""
|
||||
|
||||
linear_scheme_override: OnlineQuantScheme | None = None
|
||||
"""Quantization scheme override for ``LinearBase`` layers."""
|
||||
|
||||
moe_scheme_override: OnlineQuantScheme | None = None
|
||||
"""Quantization scheme override for ``FusedMoE`` layers."""
|
||||
|
||||
ignore: list[str] = Field(default_factory=list)
|
||||
"""Layers to skip quantization for. Supports exact names and regex
|
||||
patterns with ``re:`` prefix (e.g. ``re:.*attn.*``), consistent with
|
||||
compressed_tensors layer skipping."""
|
||||
|
||||
@field_validator(
|
||||
"global_scheme", "linear_scheme_override", "moe_scheme_override", mode="before"
|
||||
)
|
||||
@classmethod
|
||||
def _coerce_scheme(
|
||||
cls, v: str | OnlineQuantScheme | None
|
||||
) -> OnlineQuantScheme | None:
|
||||
if isinstance(v, str):
|
||||
return OnlineQuantScheme(v)
|
||||
return v
|
||||
|
||||
|
||||
def resolve_online_quant_config(
|
||||
quantization: str | None,
|
||||
quantization_config: dict[str, Any] | OnlineQuantizationConfigArgs | None,
|
||||
) -> OnlineQuantizationConfigArgs | None:
|
||||
"""Resolve online quant scheme shorthand into a quantization config.
|
||||
|
||||
If ``quantization`` is an online quant scheme (e.g. ``'fp8_per_tensor'``),
|
||||
ensures ``quantization_config`` has a matching ``global_scheme`` and casts
|
||||
it to :class:`OnlineQuantizationConfigArgs` if needed.
|
||||
"""
|
||||
online_quant_values = {s.value for s in OnlineQuantScheme}
|
||||
valid_quantization_values = online_quant_values | {"online"}
|
||||
if quantization not in valid_quantization_values:
|
||||
if quantization_config is not None:
|
||||
raise ValueError(
|
||||
f"quantization_config is only supported when quantization "
|
||||
f"is one of {sorted(valid_quantization_values)}, "
|
||||
f"got quantization={quantization!r}"
|
||||
)
|
||||
return None
|
||||
|
||||
if quantization in online_quant_values:
|
||||
scheme = OnlineQuantScheme(quantization)
|
||||
|
||||
if quantization_config is None:
|
||||
quantization_config = {
|
||||
"global_scheme": scheme.value,
|
||||
}
|
||||
elif isinstance(quantization_config, OnlineQuantizationConfigArgs):
|
||||
if quantization_config.global_scheme is None:
|
||||
quantization_config.global_scheme = scheme
|
||||
elif quantization_config.global_scheme != scheme:
|
||||
raise ValueError(
|
||||
f"quantization={quantization!r} conflicts with "
|
||||
f"quantization_config.global_scheme="
|
||||
f"{quantization_config.global_scheme.value!r}. "
|
||||
f"These must match when both are specified."
|
||||
)
|
||||
elif isinstance(quantization_config, dict):
|
||||
existing = quantization_config.get("global_scheme")
|
||||
if existing is None:
|
||||
quantization_config["global_scheme"] = scheme.value
|
||||
else:
|
||||
# Coerce to enum for comparison
|
||||
existing_scheme = (
|
||||
OnlineQuantScheme(existing)
|
||||
if isinstance(existing, str)
|
||||
else existing
|
||||
)
|
||||
if existing_scheme != scheme:
|
||||
raise ValueError(
|
||||
f"quantization={quantization!r} conflicts "
|
||||
f"with quantization_config"
|
||||
f"['global_scheme']={existing!r}. "
|
||||
f"These must match when both are specified."
|
||||
)
|
||||
|
||||
# Cast dict to OnlineQuantizationConfigArgs
|
||||
if isinstance(quantization_config, dict):
|
||||
quantization_config = OnlineQuantizationConfigArgs(**quantization_config)
|
||||
|
||||
return quantization_config
|
||||
@@ -1713,6 +1713,7 @@ class VllmConfig:
|
||||
f"dcp_comm_backend={self.parallel_config.dcp_comm_backend}, " # noqa
|
||||
f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa
|
||||
f"quantization={self.model_config.quantization}, "
|
||||
f"quantization_config={self.model_config.quantization_config}, " # noqa
|
||||
f"enforce_eager={self.model_config.enforce_eager}, "
|
||||
f"enable_return_routed_experts={self.model_config.enable_return_routed_experts}, " # noqa
|
||||
f"kv_cache_dtype={self.cache_config.cache_dtype}, "
|
||||
|
||||
@@ -112,6 +112,7 @@ from vllm.v1.sample.logits_processor import LogitsProcessor
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config.quantization import OnlineQuantizationConfigArgs
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.model_loader import LoadFormats
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
@@ -483,6 +484,7 @@ class EngineArgs:
|
||||
hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides")
|
||||
tokenizer_revision: str | None = ModelConfig.tokenizer_revision
|
||||
quantization: QuantizationMethods | str | None = ModelConfig.quantization
|
||||
quantization_config: "dict[str, Any] | OnlineQuantizationConfigArgs | None" = None
|
||||
allow_deprecated_quantization: bool = ModelConfig.allow_deprecated_quantization
|
||||
enforce_eager: bool = ModelConfig.enforce_eager
|
||||
disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
|
||||
@@ -661,6 +663,12 @@ class EngineArgs:
|
||||
if isinstance(self.ir_op_priority, dict):
|
||||
self.ir_op_priority = IrOpPriorityConfig(**self.ir_op_priority)
|
||||
|
||||
from vllm.config.quantization import resolve_online_quant_config
|
||||
|
||||
self.quantization_config = resolve_online_quant_config(
|
||||
self.quantization, self.quantization_config
|
||||
)
|
||||
|
||||
# Setup plugins
|
||||
from vllm.plugins import load_general_plugins
|
||||
|
||||
@@ -1431,6 +1439,7 @@ class EngineArgs:
|
||||
tokenizer_revision=self.tokenizer_revision,
|
||||
max_model_len=self.max_model_len,
|
||||
quantization=self.quantization,
|
||||
quantization_config=self.quantization_config,
|
||||
allow_deprecated_quantization=self.allow_deprecated_quantization,
|
||||
enforce_eager=self.enforce_eager,
|
||||
enable_return_routed_experts=self.enable_return_routed_experts,
|
||||
|
||||
@@ -34,6 +34,9 @@ from vllm.config.model import (
|
||||
RunnerOption,
|
||||
TokenizerMode,
|
||||
)
|
||||
from vllm.config.quantization import (
|
||||
OnlineQuantizationConfigArgs,
|
||||
)
|
||||
from vllm.distributed.weight_transfer.base import (
|
||||
WeightTransferInitRequest,
|
||||
WeightTransferUpdateRequest,
|
||||
@@ -247,6 +250,9 @@ class LLM:
|
||||
attention_config: dict[str, Any] | AttentionConfig | None = None,
|
||||
kv_cache_memory_bytes: int | None = None,
|
||||
compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
|
||||
quantization_config: dict[str, Any]
|
||||
| OnlineQuantizationConfigArgs
|
||||
| None = None,
|
||||
logits_processors: list[str | type[LogitsProcessor]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@@ -367,6 +373,7 @@ class LLM:
|
||||
profiler_config=profiler_config_instance,
|
||||
attention_config=attention_config_instance,
|
||||
compilation_config=compilation_config_instance,
|
||||
quantization_config=quantization_config,
|
||||
logits_processors=logits_processors,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -33,6 +33,13 @@ QuantizationMethods = Literal[
|
||||
"mxfp8",
|
||||
"petit_nvfp4",
|
||||
"cpu_awq",
|
||||
"online",
|
||||
# Below are values of the OnlineQuantScheme enum, specified as strings to
|
||||
# avoid circular import issues. This is here to provide a shortcut where
|
||||
# the user can specify "LLM(..., quantization='fp8_per_tensor')" as
|
||||
# shorthand for creating a more complicated online quant config object
|
||||
"fp8_per_tensor",
|
||||
"fp8_per_block",
|
||||
]
|
||||
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
|
||||
|
||||
@@ -103,6 +110,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
raise ValueError(f"Invalid quantization method: {quantization}")
|
||||
|
||||
# lazy import to avoid triggering `torch.compile` too early
|
||||
from vllm.config.quantization import OnlineQuantScheme
|
||||
from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
|
||||
|
||||
from .awq import AWQConfig
|
||||
@@ -129,6 +137,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
from .moe_wna16 import MoeWNA16Config
|
||||
from .mxfp4 import Mxfp4Config
|
||||
from .mxfp8 import Mxfp8Config
|
||||
from .online.base import OnlineQuantizationConfig
|
||||
from .petit import PetitNvFp4Config
|
||||
from .torchao import TorchAOConfig
|
||||
|
||||
@@ -157,7 +166,20 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
"mxfp8": Mxfp8Config,
|
||||
"petit_nvfp4": PetitNvFp4Config,
|
||||
"cpu_awq": CPUAWQConfig,
|
||||
"online": OnlineQuantizationConfig,
|
||||
}
|
||||
|
||||
# Below are values of the OnlineQuantScheme enum. This is here to provide
|
||||
# a shortcut where the user can specify
|
||||
# "LLM(..., quantization='fp8_per_tensor')" as shorthand for creating a
|
||||
# more complicated online quant config object
|
||||
for scheme in OnlineQuantScheme:
|
||||
assert scheme.value not in method_to_config, (
|
||||
f"Online quant scheme {scheme.value!r} conflicts with an "
|
||||
f"existing quantization method"
|
||||
)
|
||||
method_to_config[scheme.value] = OnlineQuantizationConfig
|
||||
|
||||
# Update the `method_to_config` with customized quantization methods.
|
||||
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
|
||||
|
||||
|
||||
@@ -497,6 +497,8 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
return self.fp8_linear.apply_weights(layer, x, bias)
|
||||
|
||||
|
||||
# TODO(future PR): remove this class in favor of
|
||||
# online/fp8.py::Fp8PerTensorOnlineLinearMethod
|
||||
class Fp8OnlineLinearMethod(Fp8LinearMethod):
|
||||
"""Online version of Fp8LinearMethod which loads a full precision checkpoint
|
||||
and quantizes weights during loading."""
|
||||
@@ -919,6 +921,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
|
||||
|
||||
# TODO(future PR): remove this class in favor of
|
||||
# online/fp8.py::Fp8PerTensorOnlineMoEMethod
|
||||
class Fp8OnlineMoEMethod(Fp8MoEMethod):
|
||||
"""MoE method for online FP8 quantization.
|
||||
Supports loading quantized FP16/BF16 model checkpoints with dynamic
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
116
vllm/model_executor/layers/quantization/online/base.py
Normal file
116
vllm/model_executor/layers/quantization/online/base.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config.quantization import (
|
||||
OnlineQuantizationConfigArgs,
|
||||
OnlineQuantScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoE,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
|
||||
UnquantizedFusedMoEMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
LinearBase,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
should_ignore_layer,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.online.fp8 import (
|
||||
Fp8PerBlockOnlineLinearMethod,
|
||||
Fp8PerBlockOnlineMoEMethod,
|
||||
Fp8PerTensorOnlineLinearMethod,
|
||||
Fp8PerTensorOnlineMoEMethod,
|
||||
)
|
||||
|
||||
|
||||
class OnlineQuantizationConfig(QuantizationConfig):
|
||||
"""Model-level config class for online quantization (quantize fp16/bf16 weights
|
||||
during model loading, without requiring a pre-quantized checkpoint)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
args: OnlineQuantizationConfigArgs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if (
|
||||
args.global_scheme is None
|
||||
and args.linear_scheme_override is None
|
||||
and args.moe_scheme_override is None
|
||||
):
|
||||
raise ValueError(
|
||||
"OnlineQuantizationConfig requires at least one of "
|
||||
"global_scheme, linear_scheme_override, or "
|
||||
"moe_scheme_override to be set."
|
||||
)
|
||||
self.args = args
|
||||
self.quant_scheme = args.global_scheme
|
||||
self.ignored_layers: list[str] = args.ignore
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "online"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16, torch.half]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# Note: as more online quant schemes will be added, this
|
||||
# value will become the minimum across all supported schemes.
|
||||
return 75
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "OnlineQuantizationConfig":
|
||||
raise NotImplementedError(
|
||||
"OnlineQuantizationConfig does not support loading from a "
|
||||
"checkpoint config. Use quantization_config or "
|
||||
"quantization='fp8_per_tensor'/'fp8_per_block' instead."
|
||||
)
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> "QuantizeMethodBase | None":
|
||||
if isinstance(layer, LinearBase):
|
||||
if should_ignore_layer(
|
||||
prefix,
|
||||
ignore=self.ignored_layers,
|
||||
fused_mapping=self.packed_modules_mapping,
|
||||
):
|
||||
return UnquantizedLinearMethod()
|
||||
|
||||
linear_scheme = self.args.linear_scheme_override or self.args.global_scheme
|
||||
if linear_scheme == OnlineQuantScheme.FP8_PER_BLOCK:
|
||||
return Fp8PerBlockOnlineLinearMethod()
|
||||
else:
|
||||
return Fp8PerTensorOnlineLinearMethod()
|
||||
elif isinstance(layer, FusedMoE):
|
||||
if should_ignore_layer(
|
||||
prefix,
|
||||
ignore=self.ignored_layers,
|
||||
fused_mapping=self.packed_modules_mapping,
|
||||
):
|
||||
return UnquantizedFusedMoEMethod(layer.moe_config)
|
||||
|
||||
moe_scheme = self.args.moe_scheme_override or self.args.global_scheme
|
||||
if moe_scheme == OnlineQuantScheme.FP8_PER_BLOCK:
|
||||
return Fp8PerBlockOnlineMoEMethod(layer=layer)
|
||||
else:
|
||||
return Fp8PerTensorOnlineMoEMethod(layer=layer)
|
||||
return None
|
||||
632
vllm/model_executor/layers/quantization/online/fp8.py
Normal file
632
vllm/model_executor/layers/quantization/online/fp8.py
Normal file
@@ -0,0 +1,632 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.oracle.fp8 import Fp8MoeBackend
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.model_executor.kernels.linear import init_fp8_linear_kernel
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoEMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
|
||||
select_fp8_moe_backend,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
LinearMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
maybe_post_process_fp8_weight_block,
|
||||
process_fp8_weight_block_strategy,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8DynamicTensorSym,
|
||||
kFp8DynamicTokenSym,
|
||||
kFp8Static128BlockSym,
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
cutlass_block_fp8_supported,
|
||||
cutlass_fp8_supported,
|
||||
)
|
||||
from vllm.model_executor.model_loader.reload.layerwise import (
|
||||
initialize_online_processing,
|
||||
)
|
||||
from vllm.model_executor.parameter import ModelWeightParameter
|
||||
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import is_deep_gemm_supported, per_block_cast_to_fp8
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Online FP8 Linear Methods
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _Fp8OnlineLinearBase(LinearMethodBase):
|
||||
"""Shared base for online FP8 linear methods. Loads fp16/bf16 checkpoint
|
||||
weights onto meta device and materializes them just-in-time."""
|
||||
|
||||
uses_meta_device: bool = True
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
layer.orig_dtype = params_dtype
|
||||
layer.weight_block_size = None
|
||||
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
device="meta", # materialized and processed during loading
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
initialize_online_processing(layer)
|
||||
|
||||
|
||||
class Fp8PerTensorOnlineLinearMethod(_Fp8OnlineLinearBase):
    """Online tensorwise FP8 linear quantization.

    Loads fp16/bf16 weights and quantizes them per-tensor during loading.
    After ``process_weights_after_loading`` runs, ``layer.weight`` holds the
    FP8 weight stored transposed (see the ``qweight.t()`` below) and
    ``layer.weight_scale`` holds the per-tensor scale(s)."""

    def __init__(self):
        # Output dtype follows the process-wide default (bf16/fp16).
        self.out_dtype = torch.get_default_dtype()

        # Use per-token quantization for better perf if dynamic and cutlass
        if cutlass_fp8_supported():
            activation_quant_key = kFp8DynamicTokenSym
        else:
            activation_quant_key = kFp8DynamicTensorSym

        # Weights are statically quantized per-tensor; activations are
        # quantized dynamically with the key chosen above.
        self.fp8_linear = init_fp8_linear_kernel(
            activation_quant_key=activation_quant_key,
            weight_quant_key=kFp8StaticTensorSym,
            out_dtype=torch.get_default_dtype(),
            module_name=self.__class__.__name__,
        )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded high-precision weight to FP8 (per-tensor).

        Idempotent: guarded by a flag so weight reloads (e.g. RL training
        loops) do not re-quantize an already-quantized weight.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # No static activation scale; activations are quantized dynamically.
        layer.input_scale = None
        # scale=None -> the kernel computes the per-tensor scale itself.
        qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)

        # Update layer with new values.
        # NOTE: the weight is stored TRANSPOSED here; consumers (the fp8
        # GEMM kernel and the batch-invariant path in apply()) rely on this.
        replace_parameter(layer, "weight", qweight.t().data)
        replace_parameter(layer, "weight_scale", weight_scale.data)

        # Prevent duplicate processing (e.g., during weight reload)
        layer._already_called_process_weights_after_loading = True

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``x @ W^T + bias`` using the FP8 weight.

        In batch-invariant mode the weight is dequantized to bf16 and a
        plain ``F.linear`` is used so results are independent of batch
        composition; otherwise the selected FP8 kernel is applied.
        """
        # if batch invariant mode is enabled, use BF16 dequant
        if envs.VLLM_BATCH_INVARIANT:
            # layer.weight was stored transposed in
            # process_weights_after_loading, so shape[0] here is the input
            # dimension of the layer.
            weight_fp8 = layer.weight.to(torch.bfloat16)
            weight_scale = layer.weight_scale.to(torch.bfloat16)
            if weight_scale.numel() == 1:
                # Per-tensor: simple scalar multiplication
                weight_bf16 = weight_fp8 * weight_scale
            else:
                # Multiple scales (fused modules like QKV)
                # NOTE(review): this compares the scale count against
                # weight_fp8.shape[0], which is the *input* dim given the
                # transposed storage above — confirm the per-row branch is
                # intended for that layout.
                if (
                    weight_scale.dim() == 1
                    and weight_scale.shape[0] == weight_fp8.shape[0]
                ):
                    # Per-row scaling
                    weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1)
                else:
                    # Fallback
                    weight_bf16 = weight_fp8 * weight_scale
            # Transpose back to (out, in) as required by F.linear.
            return torch.nn.functional.linear(x, weight_bf16.t(), bias)

        return self.fp8_linear.apply_weights(layer, x, bias)
||||
class Fp8PerBlockOnlineLinearMethod(_Fp8OnlineLinearBase):
    """Online blockwise FP8 linear quantization.

    Loads fp16/bf16 weights and quantizes them per-block during loading.
    Uses a fixed 128x128 weight block size and [1, 128] activation groups."""

    def __init__(self):
        # Output dtype follows the process-wide default (bf16/fp16).
        self.out_dtype = torch.get_default_dtype()
        # Hard-coded [block_n, block_k] = [128, 128] weight block shape.
        self.weight_block_size = [128, 128]

        # Probe which blockwise-FP8 backends are available on this platform.
        self.use_deep_gemm = is_deep_gemm_supported()
        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()

        # NOTE(review): the activation group uses weight_block_size[0]; both
        # entries are 128 so this is currently equivalent to [1, block_k],
        # but confirm index [1] isn't the intended one if the shape ever
        # becomes non-square.
        self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(*self.weight_block_size),
            act_quant_group_shape=GroupShape(1, self.weight_block_size[0]),
            cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
            use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            use_deep_gemm=self.use_deep_gemm,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create the high-precision (meta-device) weight via the base class
        and record the block size on the layer for downstream consumers."""
        super().create_weights(
            layer,
            input_size_per_partition,
            output_partition_sizes,
            input_size,
            output_size,
            params_dtype,
            **extra_weight_attrs,
        )
        layer.weight_block_size = self.weight_block_size

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded weight per 128x128 block and convert it to
        the runtime layout expected by the blockwise FP8 kernel.

        Idempotent: guarded by a flag so weight reloads do not re-quantize.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # No static activation scale; activations are quantized dynamically.
        layer.input_scale = None
        block_size = self.weight_block_size

        qweight, weight_scale_inv = per_block_cast_to_fp8(
            layer.weight, block_size=block_size, use_ue8m0=False
        )

        # Rearrange the quantized weight/scales into the block-strategy
        # runtime format.
        qweight, weight_scale_inv = process_fp8_weight_block_strategy(
            qweight, weight_scale_inv
        )

        replace_parameter(layer, "weight", qweight.data)
        replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)

        # Backend-specific post-processing (e.g. scale layout adjustments).
        maybe_post_process_fp8_weight_block(layer)

        # Prevent duplicate processing (e.g., during weight reload)
        layer._already_called_process_weights_after_loading = True

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply the blockwise-FP8 linear op to ``x``."""
        assert self.weight_block_size is not None

        # Note: batch invariance already handled in the function below
        return self.w8a8_block_fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale_inv,
            input_scale=layer.input_scale,
            bias=bias,
        )
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Online FP8 MoE Methods
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _Fp8OnlineMoEBase(FusedMoEMethodBase):
    """Shared base for online FP8 MoE methods. Loads fp16/bf16 checkpoint
    weights onto meta device and materializes them just-in-time.

    Subclasses quantize the materialized weights in
    ``process_weights_after_loading`` and then call ``_setup_kernel`` to
    shuffle them into the runtime format and build the fused-MoE kernel."""

    # Signals the loader that weights start on the meta device and must be
    # materialized during loading.
    uses_meta_device: bool = True

    # Declared here for mypy; actual values are set in __init__.
    fp8_backend: "Fp8MoeBackend"
    experts_cls: "type[mk.FusedMoEExperts] | None"
    weight_scale_name: str
    weight_block_size: list[int] | None
    moe: "FusedMoEConfig"
    is_monolithic: bool
    moe_quant_config: "FusedMoEQuantConfig | None"
    moe_kernel: "mk.FusedMoEKernel | None"

    def __init__(
        self,
        *,
        weight_block_size: list[int] | None,
        layer: torch.nn.Module,
    ):
        """Select quantization keys and the FP8 MoE backend.

        ``weight_block_size`` is None for per-tensor quantization or a
        [block_n, block_k] pair for blockwise quantization.
        """
        super().__init__(layer.moe_config)
        self.weight_block_size = weight_block_size
        self.block_quant: bool = self.weight_block_size is not None
        # Blockwise checkpoints conventionally name the scale
        # "weight_scale_inv"; per-tensor ones use "weight_scale".
        self.weight_scale_name = (
            "weight_scale_inv" if self.block_quant else "weight_scale"
        )

        # Set weight key and activation key for kernel compatibility
        if self.block_quant:
            weight_key = kFp8Static128BlockSym
            activation_key = kFp8Dynamic128Sym
        else:
            weight_key = kFp8StaticTensorSym
            activation_key = kFp8DynamicTensorSym

        # Select Fp8 MoE backend
        self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
            config=self.moe,
            weight_key=weight_key,
            activation_key=activation_key,
            allow_vllm_cutlass=False,
        )

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register high-precision expert weights (and optional biases) on
        the meta device; they are materialized during checkpoint loading."""
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        # Set at this point because quantization hasn't happened yet;
        # subclasses may update it in process_weights_after_loading.
        layer.weight_block_size = None

        # WEIGHTS
        # w13 fuses the gate and up projections: (E, 2*I, H).
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                device="meta",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # w2 is the down projection: (E, H, I).
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                device="meta",  # materialized and processed during loading
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # BIASES (for models like GPT-OSS that have biased MoE)
        if self.moe.has_bias:
            w13_bias = torch.nn.Parameter(
                torch.zeros(
                    num_experts,
                    2 * intermediate_size_per_partition,
                    device="meta",  # materialized and processed during loading
                    dtype=layer.orig_dtype,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_bias", w13_bias)
            set_weight_attrs(w13_bias, extra_weight_attrs)

            w2_bias = torch.nn.Parameter(
                torch.zeros(
                    num_experts,
                    hidden_size,
                    device="meta",  # materialized and processed during loading
                    dtype=layer.orig_dtype,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w2_bias", w2_bias)
            set_weight_attrs(w2_bias, extra_weight_attrs)

        # Activations are quantized dynamically; no static input scales.
        layer.w13_input_scale = None
        layer.w2_input_scale = None

        initialize_online_processing(layer)

    def _setup_kernel(
        self,
        layer: "FusedMoE",
        w13: torch.Tensor,
        w2: torch.Tensor,
        w13_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        w13_input_scale: torch.Tensor | None,
        w2_input_scale: torch.Tensor | None,
    ) -> None:
        """Shuffle quantized weights into the backend's runtime format,
        swap them into the layer, and build the fused-MoE kernel.

        Called by subclasses at the end of process_weights_after_loading.
        """
        # Imported lazily to avoid a circular import at module load time
        # (presumably — TODO confirm).
        from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
            convert_to_fp8_moe_kernel_format,
            make_fp8_moe_kernel,
        )

        # Shuffle weights to runtime format.
        w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
            fp8_backend=self.fp8_backend,
            layer=layer,
            w13=w13,
            w2=w2,
            w13_scale=w13_scale,
            w2_scale=w2_scale,
            w13_input_scale=w13_input_scale,
            w2_input_scale=w2_input_scale,
        )

        # Replace parameters with updated versions. Note that this helper
        # function ensures the replacement is compatible with RL weight reloads.
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
        replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)

        # Build the quant config (and kernel) only after the parameters
        # above are final, since the config captures references to them.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            assert self.experts_cls is not None
            self.moe_kernel = make_fp8_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,
                experts_cls=self.experts_cls,
                routing_tables=layer._maybe_init_expert_routing_tables(),
                shared_experts=layer.shared_experts,
            )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> "mk.FusedMoEPrepareAndFinalizeModular | None":
        """Unsupported: kernel construction happens in ``_setup_kernel``."""
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel "
            "initialization logic. This function should not be called."
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> "FusedMoEQuantConfig":
        """Assemble the quant config from the layer's (already quantized)
        scales, injecting biases for models that have them."""
        from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
            make_fp8_moe_quant_config,
        )

        # Scale attribute names depend on per-tensor vs. blockwise mode.
        w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        a1_scale = layer.w13_input_scale
        a2_scale = layer.w2_input_scale

        quant_config = make_fp8_moe_quant_config(
            fp8_backend=self.fp8_backend,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            block_shape=self.weight_block_size,
        )

        # Inject biases into the quant config if the model has them
        # (e.g. GPT-OSS biased MoE)
        if quant_config is not None and self.moe.has_bias:
            w13_bias = getattr(layer, "w13_bias", None)
            w2_bias = getattr(layer, "w2_bias", None)
            if w13_bias is not None:
                quant_config._w1.bias = w13_bias
            if w2_bias is not None:
                quant_config._w2.bias = w2_bias

        return quant_config

    @property
    def supports_eplb(self) -> bool:
        # Expert-parallel load balancing is supported by these methods.
        return True

    def apply_monolithic(
        self,
        layer: "FusedMoE",
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run routing + experts in a single fused kernel call (monolithic
        mode); the kernel consumes the raw router logits directly."""
        assert self.is_monolithic
        assert self.moe_kernel is not None
        return self.moe_kernel.apply_monolithic(
            x,
            layer.w13_weight,
            layer.w2_weight,
            router_logits,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            num_expert_group=layer.num_expert_group,
            topk_group=layer.topk_group,
            e_score_correction_bias=layer.e_score_correction_bias,
            routed_scaling_factor=layer.routed_scaling_factor,
        )

    def apply(
        self,
        layer: "FusedMoE",
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the experts given pre-computed top-k routing results
        (non-monolithic mode)."""
        assert not self.is_monolithic
        assert self.moe_kernel is not None
        return self.moe_kernel.apply(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            shared_experts_input=shared_experts_input,
        )
||||
|
||||
class Fp8PerTensorOnlineMoEMethod(_Fp8OnlineMoEBase):
    """Online tensorwise FP8 MoE quantization.

    Loads fp16/bf16 expert weights and quantizes each expert's w13/w2
    matrices per-tensor once loading has finished."""

    def __init__(
        self,
        *,
        layer: torch.nn.Module,
    ):
        # Per-tensor mode: no weight block shape.
        super().__init__(weight_block_size=None, layer=layer)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the materialized expert weights to FP8 per-tensor and
        build the fused-MoE kernel. Idempotent across weight reloads."""
        # TODO(@ksayers): inplace fp8 quant kernel, initialize scales with ones
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        quant_dtype = current_platform.fp8_dtype()
        w13_q = torch.empty_like(layer.w13_weight, dtype=quant_dtype)
        w2_q = torch.empty_like(layer.w2_weight, dtype=quant_dtype)

        # One scalar scale per expert, initialized to 1.0.
        # NOTE(review): scales are sized by layer.num_experts but only
        # local_num_experts entries are filled below — confirm these match
        # in expert-parallel setups.
        w13_scales = torch.ones(
            layer.num_experts, device=w13_q.device, dtype=torch.float32
        )
        w2_scales = torch.ones(
            layer.num_experts, device=w2_q.device, dtype=torch.float32
        )

        # Activations are quantized dynamically; no static input scales.
        layer.w13_input_scale = None
        layer.w2_input_scale = None

        # Quantize each locally-held expert's weights per-tensor.
        for expert_idx in range(layer.local_num_experts):
            w13_q[expert_idx, :, :], w13_scales[expert_idx] = ops.scaled_fp8_quant(
                layer.w13_weight[expert_idx, :, :]
            )
            w2_q[expert_idx, :, :], w2_scales[expert_idx] = ops.scaled_fp8_quant(
                layer.w2_weight[expert_idx, :, :]
            )

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer,
            w13_q,
            w2_q,
            w13_scales,
            w2_scales,
            w13_input_scale=layer.w13_input_scale,
            w2_input_scale=layer.w2_input_scale,
        )

        # Prevent duplicate processing (e.g., during weight reload)
        layer._already_called_process_weights_after_loading = True
||||
|
||||
class Fp8PerBlockOnlineMoEMethod(_Fp8OnlineMoEBase):
    """Online blockwise FP8 MoE quantization.

    Loads fp16/bf16 weights and quantizes them per-block during loading.
    Uses a fixed [128, 128] weight block size."""

    def __init__(
        self,
        *,
        layer: torch.nn.Module,
    ):
        super().__init__(
            weight_block_size=[128, 128],
            layer=layer,
        )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize each expert's weights per 128x128 block and build the
        fused-MoE kernel. Idempotent across weight reloads."""
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        fp8_dtype = current_platform.fp8_dtype()
        w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
        w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)

        block_size = self.weight_block_size
        assert block_size is not None
        block_n, block_k = block_size

        # Create block-shaped scales (computed here rather than in
        # create_weights because online quant doesn't need them until now).
        # NOTE(review): unlike the per-tensor sibling, scales here are sized
        # by local_num_experts — confirm the intended asymmetry.
        num_experts = layer.local_num_experts
        _, w13_out, w13_in = layer.w13_weight.shape
        _, w2_out, w2_in = layer.w2_weight.shape

        # Scale grids use ceil-division so partial edge blocks get a scale.
        w13_scale = torch.ones(
            num_experts,
            (w13_out + block_n - 1) // block_n,
            (w13_in + block_k - 1) // block_k,
            dtype=torch.float32,
            device=w13.device,
        )
        w2_scale = torch.ones(
            num_experts,
            (w2_out + block_n - 1) // block_n,
            (w2_in + block_k - 1) // block_k,
            dtype=torch.float32,
            device=w2.device,
        )

        # Quantize each local expert independently.
        for expert in range(num_experts):
            w13[expert], w13_scale[expert] = per_block_cast_to_fp8(
                layer.w13_weight[expert],
                block_size=block_size,
                use_ue8m0=False,
            )
            w2[expert], w2_scale[expert] = per_block_cast_to_fp8(
                layer.w2_weight[expert],
                block_size=block_size,
                use_ue8m0=False,
            )

        # Record the block size on the layer for downstream consumers.
        layer.weight_block_size = block_size

        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer,
            w13,
            w2,
            w13_scale,
            w2_scale,
            layer.w13_input_scale,
            layer.w2_input_scale,
        )

        # Prevent duplicate processing (e.g., during weight reload)
        layer._already_called_process_weights_after_loading = True
||||
@@ -296,6 +296,13 @@ def get_quant_config(
|
||||
)
|
||||
|
||||
if hf_quant_config is not None:
|
||||
if model_config.quantization_config is not None:
|
||||
raise ValueError(
|
||||
"Setting `quantization_config` for online "
|
||||
"quantization when the model checkpoint already "
|
||||
"has a `quantization_config` is not supported"
|
||||
)
|
||||
|
||||
# For modelopt_mixed, config.json's quantization_config may or may
|
||||
# not contain the per-layer quantized_layers map. Newer checkpoints
|
||||
# embed it directly; older ones keep it only in hf_quant_config.json.
|
||||
@@ -319,6 +326,12 @@ def get_quant_config(
|
||||
quantization_config_file = hf_overrides.get("quantization_config_file", None)
|
||||
if quantization_config_file is not None:
|
||||
if hasattr(quant_cls, "from_config_file"):
|
||||
if model_config.quantization_config is not None:
|
||||
raise ValueError(
|
||||
"Setting `quantization_config` for online "
|
||||
"quantization when the model checkpoint already "
|
||||
"has a `quantization_config` is not supported"
|
||||
)
|
||||
return quant_cls.from_config_file(quantization_config_file)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
@@ -329,6 +342,12 @@ def get_quant_config(
|
||||
quantization_config_json = hf_overrides.get("quantization_config_dict_json", None)
|
||||
if quantization_config_json is not None:
|
||||
if hasattr(quant_cls, "from_config_dict_json"):
|
||||
if model_config.quantization_config is not None:
|
||||
raise ValueError(
|
||||
"Setting `quantization_config` for online "
|
||||
"quantization when the model checkpoint already "
|
||||
"has a `quantization_config` is not supported"
|
||||
)
|
||||
return quant_cls.from_config_dict_json(quantization_config_json)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
@@ -337,6 +356,19 @@ def get_quant_config(
|
||||
f"{quant_cls}"
|
||||
)
|
||||
|
||||
# Online quantization doesn't read from checkpoint configs — it quantizes
|
||||
# fp16/bf16 weights on the fly during loading.
|
||||
if model_config.quantization_config is not None:
|
||||
from vllm.config.quantization import OnlineQuantizationConfigArgs
|
||||
from vllm.model_executor.layers.quantization.online.base import (
|
||||
OnlineQuantizationConfig,
|
||||
)
|
||||
|
||||
assert isinstance(
|
||||
model_config.quantization_config, OnlineQuantizationConfigArgs
|
||||
)
|
||||
return OnlineQuantizationConfig(args=model_config.quantization_config)
|
||||
|
||||
# Inflight BNB quantization
|
||||
if model_config.quantization == "bitsandbytes":
|
||||
return quant_cls.from_config({})
|
||||
|
||||
Reference in New Issue
Block a user