[Bugfix] fix modelopt exclude_modules name mapping (#24178)
Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
@@ -45,6 +45,9 @@ from vllm.utils import next_power_of_2
|
||||
from vllm.utils.flashinfer import (flashinfer_scaled_fp4_mm, has_flashinfer,
|
||||
has_flashinfer_moe)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
QUANT_ALGOS = ["FP8", "NVFP4"]
|
||||
@@ -63,7 +66,7 @@ class ModelOptFp8Config(QuantizationConfig):
|
||||
super().__init__()
|
||||
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
|
||||
self.kv_cache_quant_method = kv_cache_quant_method
|
||||
self.exclude_modules = exclude_modules
|
||||
self.exclude_modules = exclude_modules or []
|
||||
if is_checkpoint_fp8_serialized:
|
||||
logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
|
||||
" the format is experimental and could change.")
|
||||
@@ -84,6 +87,11 @@ class ModelOptFp8Config(QuantizationConfig):
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["hf_quant_config.json"]
|
||||
|
||||
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
|
||||
if self.exclude_modules is not None:
|
||||
self.exclude_modules = hf_to_vllm_mapper.apply_list(
|
||||
self.exclude_modules)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
@@ -170,7 +178,9 @@ class ModelOptFp8Config(QuantizationConfig):
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.attention.layer import Attention # Avoid circular import
|
||||
if isinstance(layer, LinearBase):
|
||||
if self.is_layer_excluded(prefix):
|
||||
if (is_layer_skipped(prefix, self.exclude_modules,
|
||||
self.packed_modules_mapping)
|
||||
or self.is_layer_excluded(prefix)):
|
||||
return UnquantizedLinearMethod()
|
||||
return ModelOptFp8LinearMethod(self)
|
||||
elif isinstance(layer, Attention):
|
||||
@@ -615,6 +625,11 @@ class ModelOptNvFp4Config(QuantizationConfig):
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["hf_quant_config.json"]
|
||||
|
||||
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
|
||||
if self.exclude_modules is not None:
|
||||
self.exclude_modules = hf_to_vllm_mapper.apply_list(
|
||||
self.exclude_modules)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
@@ -763,7 +778,8 @@ class ModelOptNvFp4Config(QuantizationConfig):
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.attention.layer import Attention # Avoid circular import
|
||||
if isinstance(layer, LinearBase):
|
||||
if (is_layer_skipped(prefix, self.exclude_modules)
|
||||
if (is_layer_skipped(prefix, self.exclude_modules,
|
||||
self.packed_modules_mapping)
|
||||
or self.is_layer_excluded(prefix, self.exclude_modules)):
|
||||
return UnquantizedLinearMethod()
|
||||
return ModelOptNvFp4LinearMethod(self)
|
||||
|
||||
Reference in New Issue
Block a user