[Bugfix] Fix gptq failure on T4s (#7264)

This commit is contained in:
Lucas Wilkinson
2024-08-07 16:05:37 -04:00
committed by GitHub
parent 469b3bc538
commit 311f743831
3 changed files with 14 additions and 15 deletions

View File

@@ -26,12 +26,13 @@ USE_FP32_REDUCE_DEFAULT = True
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so its closer to the actual impl
def query_marlin_supported_quant_types(has_zp: bool,
min_capability: Optional[int] = None):
if min_capability is None:
device_capability: Optional[int] = None
):
if device_capability is None:
major, minor = current_platform.get_device_capability()
min_capability = major * 10 + minor
device_capability = major * 10 + minor
if min_capability < 80:
if device_capability < 80:
return []
if has_zp:
@@ -48,20 +49,20 @@ def _check_marlin_supported(
quant_type: ScalarType,
group_size: Optional[int],
has_zp: bool,
min_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
if min_capability is None:
if device_capability is None:
major, minor = current_platform.get_device_capability()
min_capability = major * 10 + minor
device_capability = major * 10 + minor
supported_types = query_marlin_supported_quant_types(
has_zp, min_capability)
has_zp, device_capability)
if quant_type not in supported_types:
return (False, f"Marlin does not support weight_bits = {quant_type}. "
f"Only types = {supported_types} "
f"are supported (for group_size = {group_size}, "
f"min_capability = {min_capability}, zp = {has_zp}).")
f"device_capability = {device_capability}, zp = {has_zp}).")
if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES):
return (False, f"Marlin does not support group_size = {group_size}. "
f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
@@ -73,9 +74,9 @@ def _check_marlin_supported(
def check_marlin_supported(quant_type: ScalarType,
group_size: int,
has_zp: bool = False,
min_capability: Optional[int] = None) -> bool:
device_capability: Optional[int] = None) -> bool:
cond, _ = _check_marlin_supported(quant_type, group_size, has_zp,
min_capability)
device_capability)
return cond