dynamic dispatch of fp8 kernels (#14245)
Signed-off-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
@@ -231,3 +231,20 @@ class RocmPlatform(Platform):
|
||||
@classmethod
def get_device_communicator_cls(cls) -> str:
    """Return the dotted import path of the device communicator class.

    The returned path names ``CudaCommunicator`` inside the
    ``cuda_communicator`` module; it is returned as a plain string.
    """
    communicator_path = "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
    return communicator_path
|
||||
|
||||
@classmethod
def supports_fp8(cls) -> bool:
    """Report whether device 0 belongs to a listed FP8 architecture family.

    Returns:
        True when the ``gcnArchName`` of device 0 contains ``gfx94``,
        ``gfx95`` or ``gfx12`` as a substring, otherwise False.
    """
    # Only device 0 is queried here.
    arch_name = torch.cuda.get_device_properties(0).gcnArchName
    for family in ('gfx94', 'gfx95', 'gfx12'):
        if family in arch_name:
            return True
    return False
|
||||
|
||||
@classmethod
def is_fp8_fnuz(cls) -> bool:
    """Return True when device 0's ``gcnArchName`` contains ``gfx94``.

    Only device 0 is checked; this assumes MI300 platforms are
    homogeneous.
    """
    device_props = torch.cuda.get_device_properties(0)
    return 'gfx94' in device_props.gcnArchName
|
||||
|
||||
@classmethod
def fp8_dtype(cls) -> torch.dtype:
    """Pick the FP8 torch dtype for this platform.

    Returns:
        ``torch.float8_e4m3fnuz`` when ``cls.is_fp8_fnuz()`` is true,
        otherwise ``torch.float8_e4m3fn``.
    """
    return (torch.float8_e4m3fnuz
            if cls.is_fp8_fnuz() else torch.float8_e4m3fn)
|
||||
|
||||
Reference in New Issue
Block a user