[Hardware][AMD][CI][Bugfix] Fix AMD Quantization test group (#31713)

Signed-off-by: Matthew Wong <Matthew.Wong2@amd.com>
Matt
2026-01-11 01:19:46 -06:00
committed by GitHub
parent 9103ed1696
commit bde57ab2ed
12 changed files with 114 additions and 52 deletions

View File

@@ -5,6 +5,7 @@ from typing import Literal, get_args
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
@@ -98,6 +99,9 @@ def register_quantization_config(quantization: str):
             )
         else:
             QUANTIZATION_METHODS.append(quantization)
+            # Automatically assume the custom quantization config is supported
+            if sq := current_platform.supported_quantization:
+                sq.append(quantization)
 
         if not issubclass(quant_config_cls, QuantizationConfig):
             raise ValueError(
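
For context on the hunk above: besides registering the method name in QUANTIZATION_METHODS, the added lines also append it to current_platform.supported_quantization, so a decorator-registered config is not rejected by platform validation. A minimal usage sketch, assuming only what the hunk shows; the name "my_quant" and the MyQuantConfig class are illustrative, and the abstract QuantizationConfig methods are elided:

    from vllm.model_executor.layers.quantization import register_quantization_config
    from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

    @register_quantization_config("my_quant")  # hypothetical method name
    class MyQuantConfig(QuantizationConfig):
        # The abstract QuantizationConfig interface (get_name, from_config,
        # get_quant_method, ...) would be implemented here; elided for brevity.
        ...

    # After the decorator runs, "my_quant" is listed in QUANTIZATION_METHODS and,
    # with this change, in current_platform.supported_quantization as well.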

View File

@@ -9,6 +9,9 @@ from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm
     triton_scaled_mm,
 )
 from vllm.model_executor.layers.quantization.utils import replace_parameter
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    convert_to_channelwise,
+)
 from vllm.platforms import current_platform
 
 from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
@@ -37,6 +40,20 @@ class TritonScaledMMLinearKernel(ScaledMMLinearKernel):
             torch.nn.Parameter(weight.t().data, requires_grad=False),
         )
 
+        # WEIGHT SCALE
+        # Triton kernel supports only per-tensor and per-channel.
+        # If we have a fused module (QKV, MLP) with per tensor scales (thus N
+        # scales being passed to the kernel), convert to the per-channel case.
+        is_fused_module = len(layer.logical_widths) > 1
+        weight_scale = getattr(layer, self.w_s_name)
+        if is_fused_module and not self.config.is_channelwise:
+            weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
+        replace_parameter(
+            layer,
+            self.w_s_name,
+            torch.nn.Parameter(weight_scale.data, requires_grad=False),
+        )
+
         # INPUT SCALE
         if self.config.is_static_input_scheme:
             input_scale = getattr(layer, self.i_s_name)
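
To illustrate why the conversion above is needed: a fused QKV or gate/up projection carries one per-tensor scale per logical shard, while the Triton kernel only understands a single scale or one scale per output channel. A rough sketch of the expansion, assuming convert_to_channelwise simply repeats each shard's scale across that shard's output rows; the helper below is illustrative, not the vLLM implementation:

    import torch

    def to_channelwise(scales: torch.Tensor, logical_widths: list[int]) -> torch.Tensor:
        # Repeat each fused shard's per-tensor scale over that shard's output rows,
        # yielding one scale per output channel: shape (sum(logical_widths), 1).
        repeats = torch.tensor(logical_widths)
        return torch.repeat_interleave(scales.flatten(), repeats).unsqueeze(-1)

    # Example: fused QKV with shard widths 1024/256/256 and three per-tensor scales.
    scales = torch.tensor([0.02, 0.03, 0.01])
    channelwise = to_channelwise(scales, [1024, 256, 256])
    assert channelwise.shape == (1536, 1)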

View File

@@ -103,21 +103,25 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
         )
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
-        assert layer.weight.data.dtype == torch.bfloat16, (
-            f"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. {str(layer.weight.data.dtype)} is specified."  # noqa: E501
-        )
-        # Quantize the weights.
-        qweight, weight_scale = ops.scaled_fp8_quant(
-            layer.weight, scale=None, use_per_token_if_dynamic=True
+        assert layer.weight.data.dtype not in (torch.float16, torch.float32), (
+            "Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support "
+            f"output dtype of bfloat16. {layer.weight.data.dtype} is specified."
         )
-        # Update the layer with the new values.
-        layer.weight = Parameter(
-            qweight.t(), requires_grad=False
-        )  # Pretranspose the weight
-        layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+        if layer.weight.data.dtype == torch.bfloat16:
+            # Quantize the weights.
+            qweight, weight_scale = ops.scaled_fp8_quant(
+                layer.weight, scale=None, use_per_token_if_dynamic=True
+            )
+            # Update the layer with the new values.
+            layer.weight = Parameter(
+                qweight.t(), requires_grad=False
+            )  # Pretranspose the weight
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+        else:
+            assert layer.weight.data.dtype == current_platform.fp8_dtype()
+            assert getattr(layer, "weight_scale", None) is not None
         layer.input_scale = None
 
     def apply(
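
For context on the reworked branch: bf16 checkpoints are still quantized on the fly with one scale per output row (the "per-channel" half of PTPC), while checkpoints whose weights are already in the platform's fp8 dtype keep the weight_scale they shipped with. A rough, framework-agnostic sketch of the rowwise quantization step, assuming an e4m3-style fp8 range; the helper and the fp8_max constant are illustrative, and the real path goes through ops.scaled_fp8_quant and torch._scaled_mm:

    import torch

    def rowwise_fp8_quant(weight: torch.Tensor, fp8_max: float = 448.0):
        # One dynamic scale per output row (per channel); fp8_max depends on the
        # fp8 format actually used (e.g. e4m3fn vs e4m3fnuz on some AMD GPUs).
        scale = weight.abs().amax(dim=1, keepdim=True).float() / fp8_max
        scale = scale.clamp(min=1e-12)  # avoid division by zero for all-zero rows
        qweight = (weight.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
        return qweight.t(), scale  # weight is pretransposed, as in the diff above

    w = torch.randn(1536, 4096, dtype=torch.bfloat16)
    qw_t, w_scale = rowwise_fp8_quant(w)
    assert qw_t.shape == (4096, 1536) and w_scale.shape == (1536, 1)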