dynamic dispatch of fp8 kernels (#14245)
Signed-off-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
@@ -32,11 +32,8 @@ def scaled_mm_torch(a: torch.Tensor,
|
||||
|
||||
def get_8bit_types():
    """Return the 8-bit dtypes to use on the current platform.

    The list always starts with ``torch.int8``. When
    ``current_platform.supports_fp8()`` is true, the platform's own fp8
    dtype (``current_platform.fp8_dtype()``) is appended exactly once.

    Returns:
        A list of ``torch`` dtypes: ``[torch.int8]`` or
        ``[torch.int8, <platform fp8 dtype>]``.
    """
    types = [torch.int8]
    # Let the platform object pick the fp8 flavour rather than hard-coding
    # one per backend here (previously float8_e4m3fnuz on ROCm and
    # float8_e4m3fn on CUDA, both gated on device capability 89).
    if current_platform.supports_fp8():
        types.append(current_platform.fp8_dtype())
    return types
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user