[Bugfix] Fix the CUDA version check for FP8 support in the CUTLASS kernels (#5715)

This commit is contained in:
Tyler Michael Smith
2024-06-20 14:36:10 -04:00
committed by GitHub
parent a7dcc62086
commit 3f3b6b2150
5 changed files with 29 additions and 12 deletions

View File

@@ -20,19 +20,8 @@ logger = init_logger(__name__)
 def cutlass_fp8_supported() -> bool:
     capability = torch.cuda.get_device_capability()
     capability = capability[0] * 10 + capability[1]
-    major, minor = torch.version.cuda.split(".")
-    version = int(major) * 10 + int(minor)
-    # CUTLASS FP8 kernels need at least
-    # CUDA 12.0 on SM90 systems (Hopper)
-    # CUDA 12.4 on SM89 systems (Lovelace)
-    gpu_is_supported = False
-    if capability >= 90:
-        gpu_is_supported = version > 120
-    elif capability >= 89:
-        gpu_is_supported = version > 124
-    return gpu_is_supported
+    return ops.cutlass_scaled_mm_supports_fp8(capability)
class Fp8Config(QuantizationConfig):