Improve configs - ModelConfig (#17130)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
@@ -85,7 +86,7 @@ class GPTQMarlin24Config(QuantizationConfig):
|
||||
self.quant_type, self.group_size)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "gptq_marlin_24"
|
||||
|
||||
@classmethod
|
||||
@@ -108,8 +109,8 @@ class GPTQMarlin24Config(QuantizationConfig):
|
||||
return cls(weight_bits, group_size)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg,
|
||||
user_quant) -> Optional[str]:
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
is_marlin_24_format = (
|
||||
hf_quant_cfg.get("checkpoint_format") == "marlin_24")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user