[ Misc ] Improve Min Capability Checking in compressed-tensors (#6522)

This commit is contained in:
Robert Shaw
2024-07-18 10:39:12 -04:00
committed by GitHub
parent 4634c8728b
commit 58ca663224
7 changed files with 41 additions and 8 deletions

View File

@@ -37,7 +37,7 @@ class CompressedTensorsConfig(QuantizationConfig):
     @classmethod
     def get_min_capability(cls) -> int:
-        return 75
+        return 70

     def get_name(self) -> str:
         return "compressed_tensors"
@@ -85,13 +85,14 @@ class CompressedTensorsConfig(QuantizationConfig):
     def get_config_filenames(cls) -> List[str]:
         return []

-    def _check_gptq_and_marlin_can_run(self):
+    def _check_scheme_supported(self, min_capability: int):
         capability = current_platform.get_device_capability()
         capability = capability[0] * 10 + capability[1]
-        if capability < 80:
-            raise RuntimeError("The quantization config is not supported for ",
-                               "the current GPU. Minimum capability: 80. ",
-                               f"Current capability: {capability}.")
+        if capability < min_capability:
+            raise RuntimeError(
+                "Quantization scheme is not supported for ",
+                f"the current GPU. Min capability: {min_capability}. ",
+                f"Current capability: {capability}.")

     def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
                                input_quant: BaseModel) -> bool:
@@ -171,7 +172,6 @@ class CompressedTensorsConfig(QuantizationConfig):
         # Detect If Mixed Precision
         if self._is_wNa16_group_channel(weight_quant, input_quant):
-            self._check_gptq_and_marlin_can_run()
             if (self.quant_format == CompressionFormat.marlin_24.value
                     and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
                 return CompressedTensorsW4A16Sparse24(
@@ -222,10 +222,16 @@ class CompressedTensorsConfig(QuantizationConfig):
             raise ValueError(
                 f"Could not find quantization details for {layer}.")

-        return self._get_schema(
+        scheme = self._get_schema(
             weight_quant=layer_quant_details["weights"],
             input_quant=layer_quant_details["input_activations"])

+        # Raise error if device does not support the scheme
+        # (e.g. fp8 needs ada lovelace)
+        self._check_scheme_supported(scheme.get_min_capability())
+
+        return scheme
+

 class CompressedTensorsLinearMethod(LinearMethodBase):

View File

@@ -12,6 +12,13 @@ class CompressedTensorsScheme(ABC):
     of different quantization schemes supported by CompressedTensors.
     """

+    @abstractmethod
+    def get_min_capability(self) -> int:
+        """
+        Get minimum device capability.
+        """
+        raise NotImplementedError
+
     @abstractmethod
     def create_weights(self, *args, **kwargs):
         """

View File

@@ -18,6 +18,10 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
     in a linear transformation.
     """

+    def get_min_capability(self) -> int:
+        # volta and up
+        return 70
+
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         pass

View File

@@ -29,6 +29,10 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
                 raise ValueError(
                     "group_size must be given when using strategy group")

+    def get_min_capability(self) -> int:
+        # ampere + up
+        return 80
+
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         pass

View File

@@ -33,6 +33,10 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
                     "Consider quantizing with per tensor scales or upgrading "
                     "to Hopper.")

+    def get_min_capability(self) -> int:
+        # lovelace and up
+        return 89
+
     def process_weights_after_loading(self, layer) -> None:
         # If per tensor, when we have a fused module (e.g. QKV) with per
         # tensor scales (thus N scales being passed to the kernel),

View File

@@ -19,6 +19,10 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
         self.strategy = strategy
         self.is_static_input_scheme = is_static_input_scheme

+    def get_min_capability(self) -> int:
+        # turing and up
+        return 75
+
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # WEIGHT
         # Cutlass kernels need transposed weight.

View File

@@ -42,6 +42,10 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
             group_size=self.group_size,
             is_sym=True)

+    def get_min_capability(self) -> int:
+        # ampere and up
+        return 80
+
     def create_weights(self, layer: torch.nn.Module, input_size: int,
                        output_partition_sizes: List[int],
                        input_size_per_partition: int,