[Distributed] [ROCM] Fix custom allreduce enable checks (#16010)

Signed-off-by: ilmarkov <imarkov@redhat.com>
Co-authored-by: ilmarkov <imarkov@redhat.com>
This commit is contained in:
Ilya Markov
2025-04-04 18:39:08 +02:00
committed by GitHub
parent 2386803f2a
commit ef608c37a7
4 changed files with 21 additions and 4 deletions

View File

@@ -302,3 +302,10 @@ class RocmPlatform(Platform):
def supports_v1(cls, model_config: ModelConfig) -> bool:
# V1 support on AMD gpus is experimental
return True
@classmethod
def use_custom_allreduce(cls) -> bool:
# We only enable custom allreduce for MI300 series
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
supported_archs = ['gfx94']
return any(gfx in gcn_arch for gfx in supported_archs)