[ROCm][Quantization] add apply_vllm_mapper in quark config for models like gpt-oss (#28638)
Signed-off-by: xuebwang-amd <xuebwang@amd.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.quantization.quark.utils import (
|
|||||||
deep_compare,
|
deep_compare,
|
||||||
should_ignore_layer,
|
should_ignore_layer,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.models.utils import WeightsMapper
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -57,7 +58,6 @@ class QuarkConfig(QuantizationConfig):
|
|||||||
self.kv_cache_group = kv_cache_group
|
self.kv_cache_group = kv_cache_group
|
||||||
self.kv_cache_config = kv_cache_config
|
self.kv_cache_config = kv_cache_config
|
||||||
self.pack_method = pack_method
|
self.pack_method = pack_method
|
||||||
self.ignore: list[str] = cast(list[str], self.quant_config.get("exclude", []))
|
|
||||||
|
|
||||||
def get_linear_method(self) -> "QuarkLinearMethod":
|
def get_linear_method(self) -> "QuarkLinearMethod":
|
||||||
return QuarkLinearMethod(self)
|
return QuarkLinearMethod(self)
|
||||||
@@ -72,14 +72,42 @@ class QuarkConfig(QuantizationConfig):
|
|||||||
def get_name(self) -> QuantizationMethods:
|
def get_name(self) -> QuantizationMethods:
|
||||||
return "quark"
|
return "quark"
|
||||||
|
|
||||||
|
def apply_vllm_mapper( # noqa: B027
|
||||||
|
self, hf_to_vllm_mapper: "WeightsMapper"
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Interface for models to update module names referenced in
|
||||||
|
quantization configs in order to reflect the vllm model structure
|
||||||
|
|
||||||
|
:param hf_to_vllm_mapper: maps from hf model structure (the assumed
|
||||||
|
structure of the qconfig) to vllm model structure
|
||||||
|
"""
|
||||||
|
quant_config_with_hf_to_vllm_mapper = {}
|
||||||
|
|
||||||
|
for k, v in self.quant_config.items():
|
||||||
|
if isinstance(v, list):
|
||||||
|
quant_config_with_hf_to_vllm_mapper[k] = hf_to_vllm_mapper.apply_list(v)
|
||||||
|
elif isinstance(v, dict):
|
||||||
|
quant_config_with_hf_to_vllm_mapper[k] = hf_to_vllm_mapper.apply_dict(v)
|
||||||
|
else:
|
||||||
|
if isinstance(v, str):
|
||||||
|
mapped_v_list = hf_to_vllm_mapper.apply_list([v])
|
||||||
|
if mapped_v_list:
|
||||||
|
quant_config_with_hf_to_vllm_mapper[k] = mapped_v_list[0]
|
||||||
|
else:
|
||||||
|
quant_config_with_hf_to_vllm_mapper[k] = v
|
||||||
|
|
||||||
|
self.quant_config = quant_config_with_hf_to_vllm_mapper
|
||||||
|
|
||||||
def get_quant_method(
|
def get_quant_method(
|
||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional["QuantizeMethodBase"]:
|
) -> Optional["QuantizeMethodBase"]:
|
||||||
from vllm.attention.layer import Attention # Avoid circular import
|
from vllm.attention.layer import Attention # Avoid circular import
|
||||||
|
|
||||||
# Check if the layer is skipped for quantization.
|
# Check if the layer is skipped for quantization.
|
||||||
|
exclude_layers = cast(list[str], self.quant_config.get("exclude"))
|
||||||
if should_ignore_layer(
|
if should_ignore_layer(
|
||||||
prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
|
prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping
|
||||||
):
|
):
|
||||||
return UnquantizedLinearMethod()
|
return UnquantizedLinearMethod()
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
@@ -93,9 +121,6 @@ class QuarkConfig(QuantizationConfig):
|
|||||||
return QuarkMoEMethod.get_moe_method(self, module=layer, layer_name=prefix)
|
return QuarkMoEMethod.get_moe_method(self, module=layer, layer_name=prefix)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
|
|
||||||
self.ignore = hf_to_vllm_mapper.apply_list(self.ignore)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: dict[str, Any]) -> "QuarkConfig":
|
def from_config(cls, config: dict[str, Any]) -> "QuarkConfig":
|
||||||
export_config = config.get("export")
|
export_config = config.get("export")
|
||||||
|
|||||||
Reference in New Issue
Block a user