Support FP8 block quant for CompressedTensorsW8A16Fp8 (#33280)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -651,7 +651,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
# note: input_quant will be present for converted models;
|
||||
# will be ignored during inference post loading
|
||||
return CompressedTensorsW8A16Fp8(
|
||||
strategy=weight_quant.strategy,
|
||||
weight_quant=weight_quant,
|
||||
is_static_input_scheme=not input_quant.dynamic,
|
||||
)
|
||||
|
||||
@@ -659,7 +659,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
if self._is_fp8_w8a16(weight_quant, input_quant):
|
||||
is_static_input_scheme = input_quant and not input_quant.dynamic
|
||||
return CompressedTensorsW8A16Fp8(
|
||||
strategy=weight_quant.strategy,
|
||||
weight_quant=weight_quant,
|
||||
is_static_input_scheme=is_static_input_scheme,
|
||||
)
|
||||
|
||||
|
||||
@@ -4,11 +4,17 @@
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
create_fp8_scale_parameter,
|
||||
create_fp8_weight_parameter,
|
||||
process_fp8_weight_block_strategy,
|
||||
validate_fp8_block_shape,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear,
|
||||
prepare_fp8_layer_for_marlin,
|
||||
@@ -17,57 +23,40 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BlockQuantScaleParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
from vllm.model_executor.utils import replace_parameter
|
||||
|
||||
__all__ = ["CompressedTensorsW8A16Fp8"]
|
||||
|
||||
SUPPORTED_STRATEGIES = [QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR]
|
||||
strategy_to_parameter_type = {
|
||||
QuantizationStrategy.BLOCK: BlockQuantScaleParameter,
|
||||
QuantizationStrategy.CHANNEL: ChannelQuantScaleParameter,
|
||||
QuantizationStrategy.TENSOR: PerTensorScaleParameter,
|
||||
}
|
||||
|
||||
|
||||
class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
|
||||
def __init__(self, strategy: str, is_static_input_scheme: bool):
|
||||
self.strategy = strategy
|
||||
def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool):
|
||||
self.weight_quant = weight_quant
|
||||
self.strategy = weight_quant.strategy
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.weight_block_size = self.weight_quant.block_structure
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# turing and up
|
||||
return 75
|
||||
|
||||
# W8A8-Fp8 kernels support only per-tensor and per-channel cases.
|
||||
# So if we have a fused module (QKV, MLP) with per tensor scales,
|
||||
# we expand each scale to its shard's channels.
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
if self.strategy == QuantizationStrategy.TENSOR:
|
||||
ws_channelwise = convert_to_channelwise(
|
||||
layer.weight_scale, layer.logical_widths
|
||||
)
|
||||
layer.weight_scale = torch.nn.Parameter(ws_channelwise, requires_grad=False)
|
||||
else:
|
||||
# required by torch.compile to be torch.nn.Parameter
|
||||
layer.weight_scale = torch.nn.Parameter(
|
||||
layer.weight_scale.data, requires_grad=False
|
||||
)
|
||||
|
||||
# Weights must be transposed for marlin
|
||||
layer.weight = torch.nn.Parameter(layer.weight.t(), requires_grad=False)
|
||||
|
||||
if self.is_static_input_scheme:
|
||||
# required by torch.compile to be torch.nn.Parameter
|
||||
layer.input_scale = torch.nn.Parameter(
|
||||
layer.input_scale.data, requires_grad=False
|
||||
)
|
||||
prepare_fp8_layer_for_marlin(layer)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
@@ -79,38 +68,33 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
|
||||
layer.orig_dtype = params_dtype
|
||||
layer.weight_block_size = None
|
||||
|
||||
# WEIGHT
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
if self.strategy == QuantizationStrategy.BLOCK:
|
||||
assert self.weight_block_size is not None
|
||||
layer.weight_block_size = self.weight_block_size
|
||||
# Validate block quantization shapes
|
||||
validate_fp8_block_shape(
|
||||
layer,
|
||||
input_size,
|
||||
output_size,
|
||||
input_size_per_partition,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
output_partition_sizes,
|
||||
self.weight_block_size,
|
||||
)
|
||||
|
||||
# WEIGHT
|
||||
weight = create_fp8_weight_parameter(
|
||||
output_size_per_partition, input_size_per_partition, weight_loader
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
if self.strategy == QuantizationStrategy.CHANNEL:
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
elif self.strategy == QuantizationStrategy.TENSOR:
|
||||
weight_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported weight strategy={self.strategy}, "
|
||||
f"supported strategies are {SUPPORTED_STRATEGIES}"
|
||||
)
|
||||
|
||||
weight_scale[:] = torch.finfo(torch.float32).min
|
||||
weight_scale = create_fp8_scale_parameter(
|
||||
strategy_to_parameter_type[self.strategy],
|
||||
output_partition_sizes,
|
||||
input_size_per_partition,
|
||||
layer.weight_block_size,
|
||||
weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE (to deal with converted checkpoints)
|
||||
@@ -121,6 +105,33 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
|
||||
)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
weight = layer.weight
|
||||
weight_scale = layer.weight_scale
|
||||
size_k_first = True
|
||||
# TODO(rob): refactor block quant into separate class.
|
||||
if self.strategy == QuantizationStrategy.BLOCK:
|
||||
assert self.is_static_input_scheme is False
|
||||
size_k_first = False
|
||||
weight, weight_scale = process_fp8_weight_block_strategy(
|
||||
weight, weight_scale
|
||||
)
|
||||
else:
|
||||
# Weights must be transposed for marlin
|
||||
weight = weight.t()
|
||||
if self.strategy == QuantizationStrategy.TENSOR:
|
||||
# If we have a fused module (QKV, MLP) with per tensor scales,
|
||||
# we expand each scale to its shard's channels.
|
||||
weight_scale = convert_to_channelwise(
|
||||
weight_scale, layer.logical_widths
|
||||
)
|
||||
|
||||
# Update layer with new values
|
||||
replace_parameter(layer, "weight", weight.data)
|
||||
replace_parameter(layer, "weight_scale", weight_scale.data)
|
||||
|
||||
prepare_fp8_layer_for_marlin(layer, size_k_first=size_k_first)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
|
||||
@@ -400,7 +400,6 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
None,
|
||||
weight_loader,
|
||||
)
|
||||
set_weight_attrs(scale, {"scale_type": "weight_scale"})
|
||||
layer.register_parameter("weight_scale", scale)
|
||||
else:
|
||||
assert not self.act_q_static
|
||||
@@ -412,7 +411,6 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
self.weight_block_size,
|
||||
weight_loader,
|
||||
)
|
||||
set_weight_attrs(scale, {"scale_type": "weight_scale"})
|
||||
# The weight_scale_inv name is intentional for deepseekv3
|
||||
layer.register_parameter("weight_scale_inv", scale)
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ from vllm.model_executor.parameter import (
|
||||
ChannelQuantScaleParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
from vllm.model_executor.utils import replace_parameter
|
||||
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.deep_gemm import (
|
||||
@@ -1520,6 +1520,7 @@ def create_fp8_scale_parameter(
|
||||
raise ValueError(f"Unknown parameter type: {parameter_type}")
|
||||
|
||||
scale[:] = torch.finfo(torch.float32).min
|
||||
set_weight_attrs(scale, {"scale_type": "weight_scale"})
|
||||
return scale
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user