[Frontend] new online quantization frontend (#38138)

Signed-off-by: Vasiliy Kuznetsov <vasiliy@meta.com>
This commit is contained in:
Vasiliy Kuznetsov
2026-04-03 11:58:39 -04:00
committed by GitHub
parent 97f92c6b47
commit 7b1a7423be
13 changed files with 1205 additions and 0 deletions

View File

@@ -33,6 +33,13 @@ QuantizationMethods = Literal[
"mxfp8",
"petit_nvfp4",
"cpu_awq",
"online",
# Below are values of the OnlineQuantScheme enum, specified as strings to
# avoid circular import issues. This is here to provide a shortcut where
# the user can specify "LLM(..., quantization='fp8_per_tensor')" as
# shorthand for creating a more complicated online quant config object
"fp8_per_tensor",
"fp8_per_block",
]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
@@ -103,6 +110,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
raise ValueError(f"Invalid quantization method: {quantization}")
# lazy import to avoid triggering `torch.compile` too early
from vllm.config.quantization import OnlineQuantScheme
from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
from .awq import AWQConfig
@@ -129,6 +137,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .moe_wna16 import MoeWNA16Config
from .mxfp4 import Mxfp4Config
from .mxfp8 import Mxfp8Config
from .online.base import OnlineQuantizationConfig
from .petit import PetitNvFp4Config
from .torchao import TorchAOConfig
@@ -157,7 +166,20 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"mxfp8": Mxfp8Config,
"petit_nvfp4": PetitNvFp4Config,
"cpu_awq": CPUAWQConfig,
"online": OnlineQuantizationConfig,
}
# Below are values of the OnlineQuantScheme enum. This is here to provide
# a shortcut where the user can specify
# "LLM(..., quantization='fp8_per_tensor')" as shorthand for creating a
# more complicated online quant config object
for scheme in OnlineQuantScheme:
assert scheme.value not in method_to_config, (
f"Online quant scheme {scheme.value!r} conflicts with an "
f"existing quantization method"
)
method_to_config[scheme.value] = OnlineQuantizationConfig
# Update the `method_to_config` with customized quantization methods.
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)

View File

@@ -497,6 +497,8 @@ class Fp8LinearMethod(LinearMethodBase):
return self.fp8_linear.apply_weights(layer, x, bias)
# TODO(future PR): remove this class in favor of
# online/fp8.py::Fp8PerTensorOnlineLinearMethod
class Fp8OnlineLinearMethod(Fp8LinearMethod):
"""Online version of Fp8LinearMethod which loads a full precision checkpoint
and quantizes weights during loading."""
@@ -919,6 +921,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
# TODO(future PR): remove this class in favor of
# online/fp8.py::Fp8PerTensorOnlineMoEMethod
class Fp8OnlineMoEMethod(Fp8MoEMethod):
"""MoE method for online FP8 quantization.
Supports loading quantized FP16/BF16 model checkpoints with dynamic

View File

@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

View File

@@ -0,0 +1,116 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
from vllm.config.quantization import (
OnlineQuantizationConfigArgs,
OnlineQuantScheme,
)
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
)
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.linear import (
LinearBase,
UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
should_ignore_layer,
)
from vllm.model_executor.layers.quantization.online.fp8 import (
Fp8PerBlockOnlineLinearMethod,
Fp8PerBlockOnlineMoEMethod,
Fp8PerTensorOnlineLinearMethod,
Fp8PerTensorOnlineMoEMethod,
)
class OnlineQuantizationConfig(QuantizationConfig):
    """Model-level config class for online quantization (quantize fp16/bf16 weights
    during model loading, without requiring a pre-quantized checkpoint)."""

    def __init__(
        self,
        args: OnlineQuantizationConfigArgs,
    ) -> None:
        super().__init__()
        # Reject a config with no scheme anywhere: nothing would be
        # quantized, which is almost certainly a user mistake.
        no_scheme_given = (
            args.global_scheme is None
            and args.linear_scheme_override is None
            and args.moe_scheme_override is None
        )
        if no_scheme_given:
            raise ValueError(
                "OnlineQuantizationConfig requires at least one of "
                "global_scheme, linear_scheme_override, or "
                "moe_scheme_override to be set."
            )
        self.args = args
        self.quant_scheme = args.global_scheme
        self.ignored_layers: list[str] = args.ignore

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "online"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # Note: as more online quant schemes will be added, this
        # value will become the minimum across all supported schemes.
        return 75

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        # Online quantization has no checkpoint-side config file.
        return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "OnlineQuantizationConfig":
        raise NotImplementedError(
            "OnlineQuantizationConfig does not support loading from a "
            "checkpoint config. Use quantization_config or "
            "quantization='fp8_per_tensor'/'fp8_per_block' instead."
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Pick the per-layer quant method, honoring layer-type overrides
        and the ignore list; returns None for unrecognized layer types."""

        def layer_is_ignored() -> bool:
            return should_ignore_layer(
                prefix,
                ignore=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            )

        if isinstance(layer, LinearBase):
            if layer_is_ignored():
                return UnquantizedLinearMethod()
            # Per-layer-type override wins over the global scheme.
            scheme = self.args.linear_scheme_override or self.args.global_scheme
            if scheme == OnlineQuantScheme.FP8_PER_BLOCK:
                return Fp8PerBlockOnlineLinearMethod()
            return Fp8PerTensorOnlineLinearMethod()

        if isinstance(layer, FusedMoE):
            if layer_is_ignored():
                return UnquantizedFusedMoEMethod(layer.moe_config)
            scheme = self.args.moe_scheme_override or self.args.global_scheme
            if scheme == OnlineQuantScheme.FP8_PER_BLOCK:
                return Fp8PerBlockOnlineMoEMethod(layer=layer)
            return Fp8PerTensorOnlineMoEMethod(layer=layer)

        return None

View File

@@ -0,0 +1,632 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
import torch
from torch.nn import Module
if TYPE_CHECKING:
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.oracle.fp8 import Fp8MoeBackend
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.kernels.linear import init_fp8_linear_kernel
from vllm.model_executor.layers.fused_moe import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
select_fp8_moe_backend,
)
from vllm.model_executor.layers.linear import (
LinearMethodBase,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
maybe_post_process_fp8_weight_block,
process_fp8_weight_block_strategy,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
kFp8Dynamic128Sym,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8Static128BlockSym,
kFp8StaticTensorSym,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported,
cutlass_fp8_supported,
)
from vllm.model_executor.model_loader.reload.layerwise import (
initialize_online_processing,
)
from vllm.model_executor.parameter import ModelWeightParameter
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import is_deep_gemm_supported, per_block_cast_to_fp8
# ---------------------------------------------------------------------------
# Online FP8 Linear Methods
# ---------------------------------------------------------------------------
class _Fp8OnlineLinearBase(LinearMethodBase):
    """Shared base for online FP8 linear methods. Loads fp16/bf16 checkpoint
    weights onto meta device and materializes them just-in-time."""

    uses_meta_device: bool = True

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        total_output_size = sum(output_partition_sizes)

        # Bookkeeping read later by process_weights_after_loading and the
        # runtime kernels.
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = total_output_size
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # Allocate on the meta device: the real tensor is only materialized
        # (and then quantized) while the checkpoint is being loaded.
        meta_weight = torch.empty(
            total_output_size,
            input_size_per_partition,
            device="meta",
            dtype=params_dtype,
        )
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=meta_weight,
                input_dim=1,
                output_dim=0,
                weight_loader=extra_weight_attrs.get("weight_loader"),
            ),
        )
        initialize_online_processing(layer)
class Fp8PerTensorOnlineLinearMethod(_Fp8OnlineLinearBase):
    """Online tensorwise FP8 linear quantization.
    Loads fp16/bf16 weights and quantizes them per-tensor during loading."""

    def __init__(self):
        self.out_dtype = torch.get_default_dtype()
        # Use per-token quantization for better perf if dynamic and cutlass
        if cutlass_fp8_supported():
            activation_quant_key = kFp8DynamicTokenSym
        else:
            activation_quant_key = kFp8DynamicTensorSym
        self.fp8_linear = init_fp8_linear_kernel(
            activation_quant_key=activation_quant_key,
            weight_quant_key=kFp8StaticTensorSym,
            out_dtype=torch.get_default_dtype(),
            module_name=self.__class__.__name__,
        )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the just-materialized fp16/bf16 weight to per-tensor FP8."""
        # Idempotent guard: weight reloads may trigger a second call.
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return
        # Dynamic activation quantization -> no static input scale.
        layer.input_scale = None
        # scale=None lets the kernel compute a single per-tensor scale.
        qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
        # Update layer with new values.
        # Weight is stored transposed for the fp8 GEMM kernel.
        replace_parameter(layer, "weight", qweight.t().data)
        replace_parameter(layer, "weight_scale", weight_scale.data)
        # Prevent duplicate processing (e.g., during weight reload)
        layer._already_called_process_weights_after_loading = True

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # if batch invariant mode is enabled, use BF16 dequant
        if envs.VLLM_BATCH_INVARIANT:
            weight_fp8 = layer.weight.to(torch.bfloat16)
            weight_scale = layer.weight_scale.to(torch.bfloat16)
            if weight_scale.numel() == 1:
                # Per-tensor: simple scalar multiplication
                weight_bf16 = weight_fp8 * weight_scale
            else:
                # Multiple scales (fused modules like QKV)
                # NOTE(review): process_weights_after_loading above produces a
                # scalar per-tensor scale, so these multi-scale branches look
                # unreachable for this method — confirm whether another code
                # path can install a non-scalar weight_scale here.
                if (
                    weight_scale.dim() == 1
                    and weight_scale.shape[0] == weight_fp8.shape[0]
                ):
                    # Per-row scaling
                    # NOTE(review): the stored weight is transposed, so dim 0
                    # here is the input dim — verify this check matches the
                    # intended per-output-row layout.
                    weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1)
                else:
                    # Fallback
                    weight_bf16 = weight_fp8 * weight_scale
            return torch.nn.functional.linear(x, weight_bf16.t(), bias)
        return self.fp8_linear.apply_weights(layer, x, bias)
class Fp8PerBlockOnlineLinearMethod(_Fp8OnlineLinearBase):
    """Online blockwise FP8 linear quantization.
    Loads fp16/bf16 weights and quantizes them per-block during loading."""

    def __init__(self):
        self.out_dtype = torch.get_default_dtype()
        self.weight_block_size = [128, 128]
        block_n, block_k = self.weight_block_size
        self.use_deep_gemm = is_deep_gemm_supported()
        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        # Activations are quantized in (1, 128) groups; with the square
        # 128x128 weight block both block dims give the same group size.
        self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(block_n, block_k),
            act_quant_group_shape=GroupShape(1, block_n),
            cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
            use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            use_deep_gemm=self.use_deep_gemm,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # Meta-device allocation and bookkeeping live in the shared base.
        super().create_weights(
            layer,
            input_size_per_partition,
            output_partition_sizes,
            input_size,
            output_size,
            params_dtype,
            **extra_weight_attrs,
        )
        # Override the base's None so downstream code sees blockwise layout.
        layer.weight_block_size = self.weight_block_size

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the just-materialized fp16/bf16 weight to blockwise FP8."""
        # Idempotent: a second call (e.g. during weight reload) is a no-op.
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return
        layer.input_scale = None
        qweight, scale_inv = per_block_cast_to_fp8(
            layer.weight, block_size=self.weight_block_size, use_ue8m0=False
        )
        qweight, scale_inv = process_fp8_weight_block_strategy(qweight, scale_inv)
        replace_parameter(layer, "weight", qweight.data)
        replace_parameter(layer, "weight_scale_inv", scale_inv.data)
        maybe_post_process_fp8_weight_block(layer)
        # Prevent duplicate processing (e.g., during weight reload)
        layer._already_called_process_weights_after_loading = True

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        assert self.weight_block_size is not None
        # Note: batch invariance already handled in the function below
        return self.w8a8_block_fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale_inv,
            input_scale=layer.input_scale,
            bias=bias,
        )
# ---------------------------------------------------------------------------
# Online FP8 MoE Methods
# ---------------------------------------------------------------------------
class _Fp8OnlineMoEBase(FusedMoEMethodBase):
    """Shared base for online FP8 MoE methods. Loads fp16/bf16 checkpoint
    weights onto meta device and materializes them just-in-time."""

    uses_meta_device: bool = True
    # Declared here for mypy; actual values are set in __init__.
    fp8_backend: "Fp8MoeBackend"
    experts_cls: "type[mk.FusedMoEExperts] | None"
    weight_scale_name: str
    weight_block_size: list[int] | None
    moe: "FusedMoEConfig"
    is_monolithic: bool
    moe_quant_config: "FusedMoEQuantConfig | None"
    moe_kernel: "mk.FusedMoEKernel | None"

    def __init__(
        self,
        *,
        weight_block_size: list[int] | None,
        layer: torch.nn.Module,
    ):
        """Configure block vs. tensor quantization and pick the FP8 backend.

        weight_block_size is None for per-tensor subclasses and [128, 128]
        for the per-block subclass.
        """
        super().__init__(layer.moe_config)
        self.weight_block_size = weight_block_size
        self.block_quant: bool = self.weight_block_size is not None
        # Blockwise scales use the "weight_scale_inv" parameter name; the
        # tensorwise path uses plain "weight_scale".
        self.weight_scale_name = (
            "weight_scale_inv" if self.block_quant else "weight_scale"
        )
        # Set weight key and activation key for kernel compatibility
        if self.block_quant:
            weight_key = kFp8Static128BlockSym
            activation_key = kFp8Dynamic128Sym
        else:
            weight_key = kFp8StaticTensorSym
            activation_key = kFp8DynamicTensorSym
        # Select Fp8 MoE backend
        self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
            config=self.moe,
            weight_key=weight_key,
            activation_key=activation_key,
            allow_vllm_cutlass=False,
        )

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register meta-device expert weights (and optional biases).

        Tensors are materialized and quantized during checkpoint loading;
        see process_weights_after_loading in the subclasses.
        """
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None
        # WEIGHTS
        # w13: fused gate/up projection, shape (E, 2*I, H).
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                device="meta",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)
        # w2: down projection, shape (E, H, I).
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                device="meta",  # materialized and processed during loading
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)
        # BIASES (for models like GPT-OSS that have biased MoE)
        if self.moe.has_bias:
            w13_bias = torch.nn.Parameter(
                torch.zeros(
                    num_experts,
                    2 * intermediate_size_per_partition,
                    device="meta",  # materialized and processed during loading
                    dtype=layer.orig_dtype,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_bias", w13_bias)
            set_weight_attrs(w13_bias, extra_weight_attrs)
            w2_bias = torch.nn.Parameter(
                torch.zeros(
                    num_experts,
                    hidden_size,
                    device="meta",  # materialized and processed during loading
                    dtype=layer.orig_dtype,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w2_bias", w2_bias)
            set_weight_attrs(w2_bias, extra_weight_attrs)
        # Activations are quantized dynamically -> no static input scales.
        layer.w13_input_scale = None
        layer.w2_input_scale = None
        initialize_online_processing(layer)

    def _setup_kernel(
        self,
        layer: "FusedMoE",
        w13: torch.Tensor,
        w2: torch.Tensor,
        w13_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        w13_input_scale: torch.Tensor | None,
        w2_input_scale: torch.Tensor | None,
    ) -> None:
        """Convert quantized weights to the backend's runtime layout and
        build the fused MoE kernel. Called by the subclasses at the end of
        process_weights_after_loading."""
        # Local import: these helpers are only needed post-load.
        from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
            convert_to_fp8_moe_kernel_format,
            make_fp8_moe_kernel,
        )

        # Shuffle weights to runtime format.
        w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
            fp8_backend=self.fp8_backend,
            layer=layer,
            w13=w13,
            w2=w2,
            w13_scale=w13_scale,
            w2_scale=w2_scale,
            w13_input_scale=w13_input_scale,
            w2_input_scale=w2_input_scale,
        )
        # Replace parameters with updated versions. Note that this helper
        # function ensures the replacement is compatible with RL weight reloads.
        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
        replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            assert self.experts_cls is not None
            self.moe_kernel = make_fp8_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,
                experts_cls=self.experts_cls,
                routing_tables=layer._maybe_init_expert_routing_tables(),
                shared_experts=layer.shared_experts,
            )

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> "mk.FusedMoEPrepareAndFinalizeModular | None":
        # Deliberately unsupported: kernel construction happens in
        # _setup_kernel via make_fp8_moe_kernel instead.
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel "
            "initialization logic. This function should not be called."
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> "FusedMoEQuantConfig":
        """Build the FusedMoEQuantConfig from the layer's (already replaced)
        scale parameters, injecting biases when the model has them."""
        from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
            make_fp8_moe_quant_config,
        )

        # Scale parameter names depend on block vs. tensor quantization.
        w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
        w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
        a1_scale = layer.w13_input_scale
        a2_scale = layer.w2_input_scale
        quant_config = make_fp8_moe_quant_config(
            fp8_backend=self.fp8_backend,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            block_shape=self.weight_block_size,
        )
        # Inject biases into the quant config if the model has them
        # (e.g. GPT-OSS biased MoE)
        if quant_config is not None and self.moe.has_bias:
            w13_bias = getattr(layer, "w13_bias", None)
            w2_bias = getattr(layer, "w2_bias", None)
            if w13_bias is not None:
                quant_config._w1.bias = w13_bias
            if w2_bias is not None:
                quant_config._w2.bias = w2_bias
        return quant_config

    @property
    def supports_eplb(self) -> bool:
        # Expert-parallel load balancing is supported by this method family.
        return True

    def apply_monolithic(
        self,
        layer: "FusedMoE",
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Monolithic path: routing and expert computation fused in one call."""
        assert self.is_monolithic
        assert self.moe_kernel is not None
        return self.moe_kernel.apply_monolithic(
            x,
            layer.w13_weight,
            layer.w2_weight,
            router_logits,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            num_expert_group=layer.num_expert_group,
            topk_group=layer.topk_group,
            e_score_correction_bias=layer.e_score_correction_bias,
            routed_scaling_factor=layer.routed_scaling_factor,
        )

    def apply(
        self,
        layer: "FusedMoE",
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Non-monolithic path: routing already done by the caller."""
        assert not self.is_monolithic
        assert self.moe_kernel is not None
        return self.moe_kernel.apply(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            shared_experts_input=shared_experts_input,
        )
class Fp8PerTensorOnlineMoEMethod(_Fp8OnlineMoEBase):
    """Online tensorwise FP8 MoE quantization.
    Loads fp16/bf16 weights and quantizes them per-tensor during loading."""

    def __init__(
        self,
        *,
        layer: torch.nn.Module,
    ):
        super().__init__(
            weight_block_size=None,  # None selects the per-tensor path
            layer=layer,
        )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize each expert's materialized weights to per-tensor FP8
        and hand them to the kernel setup."""
        # TODO(@ksayers): inplace fp8 quant kernel, initialize scales with ones
        # Idempotent guard: weight reloads may trigger a second call.
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return
        # Quantize each expert into freshly allocated fp8 tensors (one
        # per-tensor scale per expert).
        fp8_dtype = current_platform.fp8_dtype()
        w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
        w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
        # NOTE(review): scales are allocated with layer.num_experts but only
        # the first layer.local_num_experts entries are filled below; the
        # per-block sibling sizes everything by local_num_experts. Confirm
        # these counts always match here.
        w13_scale = torch.ones(
            layer.num_experts, device=w13.device, dtype=torch.float32
        )
        w2_scale = torch.ones(layer.num_experts, device=w2.device, dtype=torch.float32)
        # Dynamic activation quantization -> no static input scales.
        layer.w13_input_scale = None
        layer.w2_input_scale = None
        for expert in range(layer.local_num_experts):
            w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant(
                layer.w13_weight[expert, :, :]
            )
            w2[expert, :, :], w2_scale[expert] = ops.scaled_fp8_quant(
                layer.w2_weight[expert, :, :]
            )
        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer,
            w13,
            w2,
            w13_scale,
            w2_scale,
            w13_input_scale=layer.w13_input_scale,
            w2_input_scale=layer.w2_input_scale,
        )
        # Prevent duplicate processing (e.g., during weight reload)
        layer._already_called_process_weights_after_loading = True
class Fp8PerBlockOnlineMoEMethod(_Fp8OnlineMoEBase):
    """Online blockwise FP8 MoE quantization.
    Loads fp16/bf16 weights and quantizes them per-block during loading."""

    def __init__(
        self,
        *,
        layer: torch.nn.Module,
    ):
        super().__init__(
            weight_block_size=[128, 128],
            layer=layer,
        )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize each local expert's materialized weights to blockwise
        FP8 and hand them to the kernel setup."""
        # Idempotent: a second call (e.g. RL weight reload) is a no-op.
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        block_size = self.weight_block_size
        assert block_size is not None
        block_n, block_k = block_size
        fp8_dtype = current_platform.fp8_dtype()
        num_experts = layer.local_num_experts

        def _ceil_div(a: int, b: int) -> int:
            return (a + b - 1) // b

        def _alloc(weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
            # fp8 destination plus its block-shaped scale tensor. Scales are
            # created here rather than in create_weights because online
            # quant doesn't need them until now.
            _, rows, cols = weight.shape
            quantized = torch.empty_like(weight, dtype=fp8_dtype)
            scales = torch.ones(
                num_experts,
                _ceil_div(rows, block_n),
                _ceil_div(cols, block_k),
                dtype=torch.float32,
                device=weight.device,
            )
            return quantized, scales

        w13, w13_scale = _alloc(layer.w13_weight)
        w2, w2_scale = _alloc(layer.w2_weight)

        for e in range(num_experts):
            w13[e], w13_scale[e] = per_block_cast_to_fp8(
                layer.w13_weight[e],
                block_size=block_size,
                use_ue8m0=False,
            )
            w2[e], w2_scale[e] = per_block_cast_to_fp8(
                layer.w2_weight[e],
                block_size=block_size,
                use_ue8m0=False,
            )

        layer.weight_block_size = block_size
        # Shuffle weights to runtime format and setup kernel.
        self._setup_kernel(
            layer,
            w13,
            w2,
            w13_scale,
            w2_scale,
            layer.w13_input_scale,
            layer.w2_input_scale,
        )
        # Prevent duplicate processing (e.g., during weight reload)
        layer._already_called_process_weights_after_loading = True

View File

@@ -296,6 +296,13 @@ def get_quant_config(
)
if hf_quant_config is not None:
if model_config.quantization_config is not None:
raise ValueError(
"Setting `quantization_config` for online "
"quantization when the model checkpoint already "
"has a `quantization_config` is not supported"
)
# For modelopt_mixed, config.json's quantization_config may or may
# not contain the per-layer quantized_layers map. Newer checkpoints
# embed it directly; older ones keep it only in hf_quant_config.json.
@@ -319,6 +326,12 @@ def get_quant_config(
quantization_config_file = hf_overrides.get("quantization_config_file", None)
if quantization_config_file is not None:
if hasattr(quant_cls, "from_config_file"):
if model_config.quantization_config is not None:
raise ValueError(
"Setting `quantization_config` for online "
"quantization when the model checkpoint already "
"has a `quantization_config` is not supported"
)
return quant_cls.from_config_file(quantization_config_file)
else:
raise NotImplementedError(
@@ -329,6 +342,12 @@ def get_quant_config(
quantization_config_json = hf_overrides.get("quantization_config_dict_json", None)
if quantization_config_json is not None:
if hasattr(quant_cls, "from_config_dict_json"):
if model_config.quantization_config is not None:
raise ValueError(
"Setting `quantization_config` for online "
"quantization when the model checkpoint already "
"has a `quantization_config` is not supported"
)
return quant_cls.from_config_dict_json(quantization_config_json)
else:
raise NotImplementedError(
@@ -337,6 +356,19 @@ def get_quant_config(
f"{quant_cls}"
)
# Online quantization doesn't read from checkpoint configs — it quantizes
# fp16/bf16 weights on the fly during loading.
if model_config.quantization_config is not None:
from vllm.config.quantization import OnlineQuantizationConfigArgs
from vllm.model_executor.layers.quantization.online.base import (
OnlineQuantizationConfig,
)
assert isinstance(
model_config.quantization_config, OnlineQuantizationConfigArgs
)
return OnlineQuantizationConfig(args=model_config.quantization_config)
# Inflight BNB quantization
if model_config.quantization == "bitsandbytes":
return quant_cls.from_config({})