[ROCm][Quantization] extend AMD Quark to support mixed-precision quantized model (#24239)
Signed-off-by: xuebwang-amd <xuebwang@amd.com> Co-authored-by: fxmarty-amd <felmarty@amd.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -114,7 +114,14 @@ class QuarkConfig(QuantizationConfig):
|
||||
layer_quant_names = list(layer_quant_config.keys())
|
||||
layer_quant_set = set(layer_quant_names)
|
||||
|
||||
if not kv_cache_set.issubset(layer_quant_set):
|
||||
if not (
|
||||
kv_cache_set.issubset(layer_quant_set)
|
||||
or any(
|
||||
fnmatch.fnmatchcase(layer_quant, pat)
|
||||
for layer_quant in list(layer_quant_set)
|
||||
for pat in list(kv_cache_set)
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
"The Quark quantized model has the "
|
||||
"kv_cache_group parameter setting, "
|
||||
@@ -124,10 +131,15 @@ class QuarkConfig(QuantizationConfig):
|
||||
)
|
||||
|
||||
q_configs = [
|
||||
cast(dict[str, Any], layer_quant_config.get(name))
|
||||
for name in kv_cache_group
|
||||
quant_cfg
|
||||
for name, quant_cfg in layer_quant_config.items()
|
||||
if any(fnmatch.fnmatchcase(name, pattern) for pattern in kv_cache_group)
|
||||
]
|
||||
if not all(deep_compare(q_config, q_configs[0]) for q_config in q_configs):
|
||||
|
||||
if not all(
|
||||
deep_compare(q_config["output_tensors"], q_configs[0]["output_tensors"])
|
||||
for q_config in q_configs
|
||||
):
|
||||
raise ValueError(
|
||||
"The quantization method used for kv_cache should "
|
||||
"be the same, but the quantization method for the "
|
||||
@@ -312,9 +324,15 @@ class QuarkConfig(QuantizationConfig):
|
||||
layer_quant_config = cast(
|
||||
dict[str, Any], self.quant_config.get("layer_quant_config")
|
||||
)
|
||||
for name_pattern in layer_quant_config:
|
||||
if fnmatch.fnmatch(layer_name, name_pattern):
|
||||
return layer_quant_config[name_pattern]
|
||||
|
||||
def _matches_pattern(layer_name, pattern):
|
||||
if "*" not in pattern:
|
||||
return layer_name in pattern
|
||||
return fnmatch.fnmatch(layer_name, pattern)
|
||||
|
||||
for name_pattern, config in layer_quant_config.items():
|
||||
if _matches_pattern(layer_name, name_pattern):
|
||||
return config
|
||||
|
||||
layer_type = cast(str, type(module))
|
||||
layer_type_quant_config = cast(
|
||||
|
||||
Reference in New Issue
Block a user