[ROCm][Quantization] Add asymmetric INT8 quantization support to TritonInt8ScaledMMLinearKernel (#38501)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
# For hf script, without -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-mmlupro-vllm-baseline.sh -m meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 -l 250 -t 8 -f 5
|
||||
model_name: "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
|
||||
required_gpu_arch:
|
||||
- gfx942
|
||||
- gfx950
|
||||
tasks:
|
||||
- name: "mmlu_pro"
|
||||
metrics:
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
# For vllm script, with -t option (tensor parallel size)
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic -l 1319 -t 1
|
||||
model_name: "RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic"
|
||||
required_gpu_arch:
|
||||
- gfx942
|
||||
- gfx950
|
||||
tasks:
|
||||
- name: "gsm8k"
|
||||
metrics:
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
model_name: "Qwen/Qwen3-235B-A22B-Instruct-2507-FP8"
|
||||
required_gpu_arch:
|
||||
- gfx942
|
||||
- gfx950
|
||||
tasks:
|
||||
- name: "mmlu_pro"
|
||||
metrics:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
Qwen2.5-1.5B-Instruct.yaml
|
||||
Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml
|
||||
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml
|
||||
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
|
||||
Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
|
||||
Qwen1.5-MoE-W4A16-compressed-tensors.yaml
|
||||
|
||||
@@ -13,6 +13,7 @@ import os
|
||||
from contextlib import contextmanager
|
||||
|
||||
import lm_eval
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
@@ -89,9 +90,40 @@ def launch_lm_eval(eval_config, tp_size):
|
||||
return results
|
||||
|
||||
|
||||
def _check_rocm_gpu_arch_requirement(eval_config):
|
||||
"""Skip the test if the model requires a ROCm GPU arch not present.
|
||||
|
||||
Model YAML configs can specify::
|
||||
|
||||
required_gpu_arch:
|
||||
- gfx942
|
||||
- gfx950
|
||||
|
||||
The check only applies on ROCm. On other platforms (e.g. CUDA) the
|
||||
field is ignored so that shared config files work for both NVIDIA and
|
||||
AMD CI pipelines.
|
||||
"""
|
||||
required_archs = eval_config.get("required_gpu_arch")
|
||||
if not required_archs:
|
||||
return
|
||||
|
||||
if not current_platform.is_rocm():
|
||||
return
|
||||
|
||||
from vllm.platforms.rocm import _GCN_ARCH # noqa: E402
|
||||
|
||||
if not any(arch in _GCN_ARCH for arch in required_archs):
|
||||
pytest.skip(
|
||||
f"Model requires GPU arch {required_archs}, "
|
||||
f"but detected arch is '{_GCN_ARCH}'"
|
||||
)
|
||||
|
||||
|
||||
def test_lm_eval_correctness_param(config_filename, tp_size):
|
||||
eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8"))
|
||||
|
||||
_check_rocm_gpu_arch_requirement(eval_config)
|
||||
|
||||
results = launch_lm_eval(eval_config, tp_size)
|
||||
|
||||
rtol = eval_config.get("rtol", DEFAULT_RTOL)
|
||||
|
||||
@@ -2690,6 +2690,24 @@ steps:
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt
|
||||
|
||||
|
||||
- label: LM Eval Small Models (MI325) # TBD
|
||||
timeout_in_minutes: 180
|
||||
mirror_hardwares: [amdexperimental, amdproduction, amdgfx942nightly, amdmi325]
|
||||
agent_pool: mi325_1
|
||||
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
|
||||
source_file_dependencies:
|
||||
- csrc/
|
||||
- vllm/model_executor/layers/quantization
|
||||
- vllm/model_executor/models/
|
||||
- vllm/model_executor/model_loader/
|
||||
- vllm/v1/attention/backends/
|
||||
- vllm/v1/attention/selector.py
|
||||
- vllm/_aiter_ops.py
|
||||
- vllm/platforms/rocm.py
|
||||
commands:
|
||||
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-small-rocm.txt
|
||||
|
||||
|
||||
- label: LM Eval Small Models (B200-MI325) # TBD
|
||||
timeout_in_minutes: 180
|
||||
mirror_hardwares: [amdexperimental, amdproduction, amdgfx942nightly, amdmi325]
|
||||
|
||||
@@ -31,8 +31,6 @@ class TritonInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if not c.input_symmetric:
|
||||
return False, "supports symmetric input only."
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
@@ -62,17 +60,59 @@ class TritonInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
|
||||
# INPUT SCALE
|
||||
if self.config.is_static_input_scheme:
|
||||
assert i_s is not None
|
||||
replace_parameter(
|
||||
layer,
|
||||
i_s_name,
|
||||
torch.nn.Parameter(i_s.max(), requires_grad=False),
|
||||
)
|
||||
setattr(layer, i_zp_name, None)
|
||||
|
||||
if self.config.input_symmetric:
|
||||
replace_parameter(
|
||||
layer,
|
||||
i_s_name,
|
||||
torch.nn.Parameter(i_s.max(), requires_grad=False),
|
||||
)
|
||||
setattr(layer, i_zp_name, None)
|
||||
else:
|
||||
input_zero_point = getattr(layer, i_zp_name)
|
||||
|
||||
# Reconstruct the ranges to find a single scale and azp
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
azps = input_zero_point.to(dtype=torch.int32)
|
||||
range_max = (i_s * (int8_traits.max - azps)).max()
|
||||
range_min = (i_s * (int8_traits.min - azps)).min()
|
||||
|
||||
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
|
||||
replace_parameter(
|
||||
layer,
|
||||
i_s_name,
|
||||
torch.nn.Parameter(scale, requires_grad=False),
|
||||
)
|
||||
|
||||
# AZP loaded as int8 but used as int32
|
||||
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
|
||||
replace_parameter(
|
||||
layer,
|
||||
i_zp_name,
|
||||
torch.nn.Parameter(azp, requires_grad=False),
|
||||
)
|
||||
else:
|
||||
setattr(layer, i_s_name, None)
|
||||
setattr(layer, i_zp_name, None)
|
||||
|
||||
setattr(layer, azp_adj_name, None)
|
||||
# azp_adj is the AZP adjustment term, used to account for weights.
|
||||
# It does not depend on scales or azp, so it is the same for
|
||||
# static and dynamic quantization.
|
||||
# See csrc/quantization/w8a8/cutlass/Epilogues.md for the math.
|
||||
if not self.config.input_symmetric:
|
||||
weight = getattr(layer, w_q_name)
|
||||
# weight is already transposed to [K, N], sum over K (dim=0)
|
||||
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
|
||||
if self.config.is_static_input_scheme:
|
||||
# Fold azp into azp_adj for the per-tensor case
|
||||
azp_adj = getattr(layer, i_zp_name) * azp_adj
|
||||
setattr(
|
||||
layer,
|
||||
azp_adj_name,
|
||||
torch.nn.Parameter(azp_adj, requires_grad=False),
|
||||
)
|
||||
else:
|
||||
setattr(layer, azp_adj_name, None)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
@@ -80,14 +120,33 @@ class TritonInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
w_q, w_s, i_s, i_zp, _ = self._get_layer_params(layer)
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
|
||||
|
||||
symmetric = azp_adj is None
|
||||
x_q, x_s, x_zp = ops.scaled_int8_quant(
|
||||
x.contiguous(), i_s, i_zp, symmetric=True
|
||||
x.contiguous(), i_s, i_zp, symmetric=symmetric
|
||||
)
|
||||
|
||||
assert x_zp is None, "Triton kernel only supports symmetric quantization"
|
||||
|
||||
return triton_scaled_mm(
|
||||
out = triton_scaled_mm(
|
||||
x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
|
||||
)
|
||||
|
||||
if azp_adj is not None:
|
||||
# Asymmetric quantization: subtract the zero-point correction.
|
||||
# D = scale_a * scale_b * (A_q @ B_q - azp * azp_adj) + bias
|
||||
# triton_scaled_mm already computed scale_a * scale_b * (A_q @ B_q) + bias
|
||||
# so we subtract scale_a * scale_b * azp * azp_adj
|
||||
#
|
||||
# x_s: [M, 1] or scalar, w_s: [N, 1] or scalar, azp_adj: [1, N]
|
||||
# Reshape w_s from [N, 1] to [1, N] for proper broadcasting.
|
||||
w_s_row = w_s.view(1, -1) if w_s.dim() > 0 else w_s
|
||||
static = i_zp is not None
|
||||
if not static and x_zp is not None:
|
||||
# Dynamic per-token: azp is per-token, azp_adj is per-channel
|
||||
# x_zp: [M, 1], azp_adj: [1, N]
|
||||
out -= x_s * w_s_row * (x_zp * azp_adj).to(x.dtype)
|
||||
else:
|
||||
# Static per-tensor: azp already folded into azp_adj
|
||||
out -= (x_s * w_s_row * azp_adj).to(x.dtype)
|
||||
|
||||
return out
|
||||
|
||||
Reference in New Issue
Block a user