Support FP8 block quant for CompressedTensorsW8A16Fp8 (#33280)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2026-01-30 11:15:20 -05:00
committed by GitHub
parent f857a03f6b
commit fd0e377244
4 changed files with 74 additions and 64 deletions

View File

@@ -651,7 +651,7 @@ class CompressedTensorsConfig(QuantizationConfig):
# note: input_quant will be present for converted models;
# will be ignored during inference post loading
return CompressedTensorsW8A16Fp8(
strategy=weight_quant.strategy,
weight_quant=weight_quant,
is_static_input_scheme=not input_quant.dynamic,
)
@@ -659,7 +659,7 @@ class CompressedTensorsConfig(QuantizationConfig):
if self._is_fp8_w8a16(weight_quant, input_quant):
is_static_input_scheme = input_quant and not input_quant.dynamic
return CompressedTensorsW8A16Fp8(
strategy=weight_quant.strategy,
weight_quant=weight_quant,
is_static_input_scheme=is_static_input_scheme,
)

View File

@@ -4,11 +4,17 @@
from collections.abc import Callable
import torch
from compressed_tensors.quantization import QuantizationStrategy
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
create_fp8_scale_parameter,
create_fp8_weight_parameter,
process_fp8_weight_block_strategy,
validate_fp8_block_shape,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear,
prepare_fp8_layer_for_marlin,
@@ -17,57 +23,40 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise,
)
from vllm.model_executor.parameter import (
BlockQuantScaleParameter,
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter,
)
from vllm.model_executor.utils import replace_parameter
__all__ = ["CompressedTensorsW8A16Fp8"]
SUPPORTED_STRATEGIES = [QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR]
# Lookup from weight quantization strategy to the scale-parameter class that
# is passed to create_fp8_scale_parameter when the weight scale is created.
strategy_to_parameter_type = {
QuantizationStrategy.BLOCK: BlockQuantScaleParameter,
QuantizationStrategy.CHANNEL: ChannelQuantScaleParameter,
QuantizationStrategy.TENSOR: PerTensorScaleParameter,
}
class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
def __init__(self, strategy: str, is_static_input_scheme: bool):
self.strategy = strategy
def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool):
    """Store the weight quantization config and the fields read from it.

    Args:
        weight_quant: Weight quantization arguments. Its ``strategy`` and
            ``block_structure`` attributes are copied onto the instance as
            ``strategy`` and ``weight_block_size``.
        is_static_input_scheme: Stored on the instance unchanged.
    """
    self.is_static_input_scheme = is_static_input_scheme
    self.weight_quant = weight_quant
    # Cache the two fields the rest of the scheme reads directly.
    self.weight_block_size = weight_quant.block_structure
    self.strategy = weight_quant.strategy
@classmethod
def get_min_capability(cls) -> int:
    """Return the minimum capability value this scheme requires.

    The value 75 corresponds to Turing and newer.
    """
    turing_and_up = 75
    return turing_and_up
# W8A8-Fp8 kernels support only per-tensor and per-channel cases.
# So if we have a fused module (QKV, MLP) with per tensor scales,
# we expand each scale to its shard's channels.
def process_weights_after_loading(self, layer) -> None:
if self.strategy == QuantizationStrategy.TENSOR:
ws_channelwise = convert_to_channelwise(
layer.weight_scale, layer.logical_widths
)
layer.weight_scale = torch.nn.Parameter(ws_channelwise, requires_grad=False)
else:
# required by torch.compile to be torch.nn.Parameter
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.data, requires_grad=False
)
# Weights must be transposed for marlin
layer.weight = torch.nn.Parameter(layer.weight.t(), requires_grad=False)
if self.is_static_input_scheme:
# required by torch.compile to be torch.nn.Parameter
layer.input_scale = torch.nn.Parameter(
layer.input_scale.data, requires_grad=False
)
prepare_fp8_layer_for_marlin(layer)
def create_weights(
self,
layer: torch.nn.Module,
input_size: int,
output_partition_sizes: list[int],
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
weight_loader: Callable,
**kwargs,
@@ -79,38 +68,33 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
layer.orig_dtype = params_dtype
layer.weight_block_size = None
# WEIGHT
weight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
if self.strategy == QuantizationStrategy.BLOCK:
assert self.weight_block_size is not None
layer.weight_block_size = self.weight_block_size
# Validate block quantization shapes
validate_fp8_block_shape(
layer,
input_size,
output_size,
input_size_per_partition,
dtype=torch.float8_e4m3fn,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
output_partition_sizes,
self.weight_block_size,
)
# WEIGHT
weight = create_fp8_weight_parameter(
output_size_per_partition, input_size_per_partition, weight_loader
)
layer.register_parameter("weight", weight)
# WEIGHT SCALE
if self.strategy == QuantizationStrategy.CHANNEL:
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader,
)
elif self.strategy == QuantizationStrategy.TENSOR:
weight_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
else:
raise ValueError(
f"Unsupported weight strategy={self.strategy}, "
f"supported strategies are {SUPPORTED_STRATEGIES}"
)
weight_scale[:] = torch.finfo(torch.float32).min
weight_scale = create_fp8_scale_parameter(
strategy_to_parameter_type[self.strategy],
output_partition_sizes,
input_size_per_partition,
layer.weight_block_size,
weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE (to deal with converted checkpoints)
@@ -121,6 +105,33 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
)
layer.register_parameter("input_scale", input_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Prepare the loaded FP8 weight and weight scale for the Marlin kernel.

    For the block strategy the weight and scale are passed through
    ``process_fp8_weight_block_strategy`` and Marlin preparation is invoked
    with ``size_k_first=False``. For every other strategy the weight is
    transposed, and per-tensor scales are additionally expanded with
    ``convert_to_channelwise`` using ``layer.logical_widths``.

    Args:
        layer: Module holding the ``weight`` and ``weight_scale``
            parameters. Both are updated through ``replace_parameter``
            before ``prepare_fp8_layer_for_marlin`` is called on the layer.

    Raises:
        AssertionError: If the block strategy is combined with a static
            input scheme.
    """
    weight = layer.weight
    weight_scale = layer.weight_scale
    size_k_first = True

    # TODO(rob): refactor block quant into separate class.
    if self.strategy == QuantizationStrategy.BLOCK:
        # Test truthiness rather than identity with False: the flag is
        # computed by the config as `input_quant and not input_quant.dynamic`,
        # so it is None (not False) when there is no input quantization
        # config, and that case must not be rejected.
        assert not self.is_static_input_scheme, (
            "Block-quantized FP8 weights do not support a static input scheme"
        )
        size_k_first = False
        weight, weight_scale = process_fp8_weight_block_strategy(
            weight, weight_scale
        )
    else:
        # Weights must be transposed for marlin
        weight = weight.t()
        if self.strategy == QuantizationStrategy.TENSOR:
            # If we have a fused module (QKV, MLP) with per tensor scales,
            # we expand each scale to its shard's channels.
            weight_scale = convert_to_channelwise(
                weight_scale, layer.logical_widths
            )

    # Update layer with new values
    replace_parameter(layer, "weight", weight.data)
    replace_parameter(layer, "weight_scale", weight_scale.data)
    prepare_fp8_layer_for_marlin(layer, size_k_first=size_k_first)
def apply_weights(
self,
layer: torch.nn.Module,

View File

@@ -400,7 +400,6 @@ class Fp8LinearMethod(LinearMethodBase):
None,
weight_loader,
)
set_weight_attrs(scale, {"scale_type": "weight_scale"})
layer.register_parameter("weight_scale", scale)
else:
assert not self.act_q_static
@@ -412,7 +411,6 @@ class Fp8LinearMethod(LinearMethodBase):
self.weight_block_size,
weight_loader,
)
set_weight_attrs(scale, {"scale_type": "weight_scale"})
# The weight_scale_inv name is intentional for deepseekv3
layer.register_parameter("weight_scale_inv", scale)

View File

@@ -29,7 +29,7 @@ from vllm.model_executor.parameter import (
ChannelQuantScaleParameter,
PerTensorScaleParameter,
)
from vllm.model_executor.utils import replace_parameter
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import (
@@ -1520,6 +1520,7 @@ def create_fp8_scale_parameter(
raise ValueError(f"Unknown parameter type: {parameter_type}")
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"scale_type": "weight_scale"})
return scale