[Quantization][FP8] Add support for FP8 models with input_scale for output projection and QK quantization (#15734)

Signed-off-by: Randall Smith <Randall.Smith@amd.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
This commit is contained in:
rasmith
2025-04-25 02:45:02 -05:00
committed by GitHub
parent 6aae216b4e
commit a41351f363
8 changed files with 105 additions and 20 deletions

View File

@@ -140,6 +140,11 @@ class Fp8Config(QuantizationConfig):
return name.replace(".k_proj.output_scale", ".attn.k_scale")
if name.endswith(".output_scale") and ".v_proj" in name:
return name.replace(".v_proj.output_scale", ".attn.v_scale")
if name.endswith(".output_scale") and ".q_proj" in name:
return name.replace(".q_proj.output_scale", ".attn.q_scale")
if name.endswith("self_attn.prob_output_scale"):
return name.replace(".prob_output_scale", ".attn.prob_scale")
# If no matches, return None
return None

View File

@@ -38,6 +38,9 @@ class BaseKVCacheMethod(QuantizeMethodBase):
requires_grad=False)
layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0),
requires_grad=False)
# Initialize P = softmax(QK^T) scales
layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0),
requires_grad=False)
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
raise RuntimeError(
@@ -97,5 +100,38 @@ class BaseKVCacheMethod(QuantizeMethodBase):
"may cause accuracy issues. Please make sure k/v_scale "
"scaling factors are available in the fp8 checkpoint.")
if layer.q_scale > 0.0:
q_scale = layer.q_scale
if current_platform.is_fp8_fnuz():
q_scale *= 2
layer.calculate_kv_scales = False
else:
q_scale = 1.0
if layer.prob_scale > 0.0:
prob_scale = layer.prob_scale
if current_platform.is_fp8_fnuz():
prob_scale *= 2
else:
prob_scale = 1.0
is_singleton_float = lambda x: isinstance(x, float) or isinstance(
x, torch.Tensor) and x.numel() == 1 and x.is_floating_point()
if not is_singleton_float(q_scale) or not is_singleton_float(
prob_scale):
raise ValueError("Only support per-tensor scaling factor"
"for fp8-quantized Q/prob")
# These are used in the final Attention.forward()
layer._q_scale.copy_(q_scale)
layer._prob_scale.copy_(prob_scale)
if q_scale == 1.0 or prob_scale == 1.0:
logger.warning_once(
f"Using Q scale {q_scale} and prob scale {prob_scale} "
"with fp8 attention. This may cause accuracy issues. "
"Please make sure Q/prob scaling factors are "
"available in the fp8 checkpoint.")
del layer.k_scale
del layer.v_scale
del layer.q_scale
del layer.prob_scale

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import fnmatch
import re
from typing import Any, Dict, List, Optional, cast
import torch
@@ -125,6 +124,13 @@ class QuarkConfig(QuantizationConfig):
for q_config in q_configs:
q_config["output_tensors"] = None
# In case q_proj output is also quantized, remove the configuration
# to keep qkv consistency.
q_proj_q_config = cast(Dict[str, Any],
layer_quant_config.get("*q_proj"))
if q_proj_q_config is not None:
q_proj_q_config["output_tensors"] = None
return cls(quant_config=config,
kv_cache_group=kv_cache_group,
kv_cache_config=kv_cache_config,
@@ -289,29 +295,30 @@ class QuarkConfig(QuantizationConfig):
:param name: param name
:return: matching param name for KV cache scale in vLLM
"""
if self.kv_cache_group is None or len(self.kv_cache_group) == 0:
return None
kv_proj_names = [
re.split(r"[*.]", kv_cache)[-1] for kv_cache in self.kv_cache_group
]
if name.endswith(".output_scale"):
if len(kv_proj_names) == 1 and kv_proj_names[0] in name:
kv_output_scale_name = "." + kv_proj_names[0] + ".output_scale"
return name.replace(kv_output_scale_name, ".attn.k_scale")
elif len(kv_proj_names) == 2:
for kv_proj_name in kv_proj_names:
if kv_proj_name in name and kv_proj_name == "k_proj":
return name.replace(".k_proj.output_scale",
".attn.k_scale")
elif kv_proj_name in name and kv_proj_name == "v_proj":
return name.replace(".v_proj.output_scale",
".attn.v_scale")
if name.endswith(".output_scale") and ".k_proj" in name:
return name.replace(".k_proj.output_scale", ".attn.k_scale")
if name.endswith(".output_scale") and ".v_proj" in name:
return name.replace(".v_proj.output_scale", ".attn.v_scale")
if name.endswith(".output_scale") and ".q_proj" in name:
return name.replace(".q_proj.output_scale", ".attn.q_scale")
if name.endswith("self_attn.prob_output_scale"):
return name.replace(".prob_output_scale", ".attn.prob_scale")
# If no matches, return None
return None
def has_fp8_layer_weights(self):
layer_quant_config = self.quant_config.get("layer_quant_config")
to_dict = lambda obj: cast(Dict[str, Any], obj) or {}
return any([
'fp8' in cast(
str,
to_dict(
to_dict(to_dict(layer_quant_config).get(layer_name)).get(
"weight")).get("dtype"))
for layer_name in ["*v_proj", "*k_proj", "*q_proj"]
])
class QuarkLinearMethod(LinearMethodBase):