[ROCm][Quantization] extend AMD Quark to support mixed-precision quantized model (#24239)

Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Co-authored-by: fxmarty-amd <felmarty@amd.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
xuebwang-amd
2025-11-12 01:05:22 +08:00
committed by GitHub
parent 68c09efc37
commit 05576df85c
3 changed files with 127 additions and 8 deletions

View File

@@ -114,7 +114,14 @@ class QuarkConfig(QuantizationConfig):
layer_quant_names = list(layer_quant_config.keys())
layer_quant_set = set(layer_quant_names)
if not kv_cache_set.issubset(layer_quant_set):
if not (
kv_cache_set.issubset(layer_quant_set)
or any(
fnmatch.fnmatchcase(layer_quant, pat)
for layer_quant in list(layer_quant_set)
for pat in list(kv_cache_set)
)
):
raise ValueError(
"The Quark quantized model has the "
"kv_cache_group parameter setting, "
@@ -124,10 +131,15 @@ class QuarkConfig(QuantizationConfig):
)
q_configs = [
cast(dict[str, Any], layer_quant_config.get(name))
for name in kv_cache_group
quant_cfg
for name, quant_cfg in layer_quant_config.items()
if any(fnmatch.fnmatchcase(name, pattern) for pattern in kv_cache_group)
]
if not all(deep_compare(q_config, q_configs[0]) for q_config in q_configs):
if not all(
deep_compare(q_config["output_tensors"], q_configs[0]["output_tensors"])
for q_config in q_configs
):
raise ValueError(
"The quantization method used for kv_cache should "
"be the same, but the quantization method for the "
@@ -312,9 +324,15 @@ class QuarkConfig(QuantizationConfig):
layer_quant_config = cast(
dict[str, Any], self.quant_config.get("layer_quant_config")
)
for name_pattern in layer_quant_config:
if fnmatch.fnmatch(layer_name, name_pattern):
return layer_quant_config[name_pattern]
def _matches_pattern(layer_name, pattern):
if "*" not in pattern:
return layer_name in pattern
return fnmatch.fnmatch(layer_name, pattern)
for name_pattern, config in layer_quant_config.items():
if _matches_pattern(layer_name, name_pattern):
return config
layer_type = cast(str, type(module))
layer_type_quant_config = cast(