[Bugfix] Fix the CUDA version check for FP8 support in the CUTLASS kernels (#5715)
This commit is contained in:
committed by
GitHub
parent
a7dcc62086
commit
3f3b6b2150
@@ -20,19 +20,8 @@ logger = init_logger(__name__)
|
||||
def cutlass_fp8_supported() -> bool:
    """Return True if the CUTLASS scaled_mm FP8 kernels support this GPU.

    The check is delegated to the compiled extension
    (``ops.cutlass_scaled_mm_supports_fp8``), which knows the CUDA toolkit
    version the kernels were actually built against. The previous pure-Python
    version derived the CUDA version from ``torch.version.cuda`` and compared
    with strict ``>`` (``version > 120`` on SM90, ``version > 124`` on SM89),
    which contradicted its own comment ("need at least CUDA 12.0 / 12.4") and
    wrongly rejected exactly CUDA 12.0 on Hopper and CUDA 12.4 on Lovelace.

    Returns:
        bool: whether the CUTLASS FP8 kernels can run on the current device.

    NOTE(review): requires a visible CUDA device — ``torch.cuda`` calls will
    fail on CPU-only hosts; callers presumably only invoke this when CUDA is
    available. Confirm against call sites.
    """
    # Encode the (major, minor) compute capability as a single int,
    # e.g. (9, 0) -> 90, matching the convention the extension expects.
    capability = torch.cuda.get_device_capability()
    capability = capability[0] * 10 + capability[1]

    return ops.cutlass_scaled_mm_supports_fp8(capability)
|
||||
|
||||
|
||||
class Fp8Config(QuantizationConfig):
|
||||
|
||||
Reference in New Issue
Block a user