Add minimum capability requirement for AWQ (#1064)
This commit is contained in:
@@ -68,6 +68,14 @@ def get_model(model_config: ModelConfig) -> nn.Module:
|
||||
quant_config = get_quant_config(model_config.quantization,
|
||||
model_config.model,
|
||||
model_config.download_dir)
|
||||
capability = torch.cuda.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
if capability < quant_config.get_min_capability():
|
||||
raise ValueError(
|
||||
f"The quantization method {model_config.quantization} is not "
|
||||
"supported for the current GPU. "
|
||||
f"Minimum capability: {quant_config.get_min_capability()}. "
|
||||
f"Current capability: {capability}.")
|
||||
supported_dtypes = quant_config.get_supported_act_dtypes()
|
||||
if model_config.dtype not in supported_dtypes:
|
||||
raise ValueError(
|
||||
|
||||
@@ -40,6 +40,11 @@ class AWQConfig(QuantizationConfig):
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# The AWQ kernel only supports Ampere or newer GPUs.
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return [
|
||||
|
||||
@@ -15,6 +15,16 @@ class QuantizationConfig:
|
||||
"""List of supported activation dtypes."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
"""Minimum GPU capability to support the quantization method.
|
||||
|
||||
E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
|
||||
This requirement is due to the custom CUDA kernels used by the
|
||||
quantization method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
"""List of filenames to search for in the model directory."""
|
||||
|
||||
Reference in New Issue
Block a user