dynamic distpatch of fp8 kernels (#14245)

Signed-off-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
Jeff Daily
2025-03-11 07:54:56 -07:00
committed by GitHub
parent 08a1a1121d
commit a1c8f3796c
25 changed files with 292 additions and 159 deletions

View File

@@ -231,3 +231,20 @@ class RocmPlatform(Platform):
@classmethod
def get_device_communicator_cls(cls) -> str:
return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa
@classmethod
def supports_fp8(cls) -> bool:
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
return any(gfx in gcn_arch for gfx in ['gfx94', 'gfx95', 'gfx12'])
@classmethod
def is_fp8_fnuz(cls) -> bool:
# only device 0 is checked, this assumes MI300 platforms are homogeneous
return 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName
@classmethod
def fp8_dtype(cls) -> torch.dtype:
if cls.is_fp8_fnuz():
return torch.float8_e4m3fnuz
else:
return torch.float8_e4m3fn