[ROCm][Bugfix] Resolve Dynamo tracing crash from amdsmi calls in on_gfx* arch detection (#34108)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-02-09 22:50:20 -06:00
committed by GitHub
parent 047a457fa4
commit 4cde2e0159

View File

@@ -101,12 +101,10 @@ def _query_gcn_arch_from_amdsmi() -> str:
raise RuntimeError("amdsmi did not return valid GCN arch")
@cache
def _get_gcn_arch_via_amdsmi() -> str:
def _get_gcn_arch() -> str:
"""
Get the GCN architecture name using amdsmi instead of torch.cuda.
This avoids initializing CUDA, which is important for Ray workers
that need to set CUDA_VISIBLE_DEVICES after importing vLLM.
Get GCN arch via amdsmi (no CUDA init), fallback to torch.cuda.
Called once at module level; result stored in _GCN_ARCH.
"""
try:
return _query_gcn_arch_from_amdsmi()
@@ -121,34 +119,36 @@ def _get_gcn_arch_via_amdsmi() -> str:
return torch.cuda.get_device_properties("cuda").gcnArchName
@cache
# Resolve once at module load. Uses amdsmi (no CUDA init) so Ray workers
# can still set CUDA_VISIBLE_DEVICES after import.
# These are plain Python bools — fully torch.compile/Dynamo safe.
_GCN_ARCH = _get_gcn_arch()
_ON_GFX1X = any(arch in _GCN_ARCH for arch in ["gfx11", "gfx12"])
_ON_MI3XX = any(arch in _GCN_ARCH for arch in ["gfx942", "gfx950"])
_ON_GFX9 = any(arch in _GCN_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
_ON_GFX942 = "gfx942" in _GCN_ARCH
_ON_GFX950 = "gfx950" in _GCN_ARCH
def on_gfx1x() -> bool:
GPU_ARCH = _get_gcn_arch_via_amdsmi()
return any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
return _ON_GFX1X
@cache
def on_mi3xx() -> bool:
GPU_ARCH = _get_gcn_arch_via_amdsmi()
return any(arch in GPU_ARCH for arch in ["gfx942", "gfx950"])
return _ON_MI3XX
@cache
def on_gfx9() -> bool:
GPU_ARCH = _get_gcn_arch_via_amdsmi()
return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
return _ON_GFX9
@cache
def on_gfx942() -> bool:
GPU_ARCH = _get_gcn_arch_via_amdsmi()
return any(arch in GPU_ARCH for arch in ["gfx942"])
return _ON_GFX942
@cache
def on_gfx950() -> bool:
GPU_ARCH = _get_gcn_arch_via_amdsmi()
return any(arch in GPU_ARCH for arch in ["gfx950"])
return _ON_GFX950
@cache
@@ -163,13 +163,9 @@ def use_rocm_custom_paged_attention(
alibi_slopes: torch.Tensor | None = None,
sinks: torch.Tensor | None = None,
) -> bool:
GPU_ARCH = _get_gcn_arch_via_amdsmi()
ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
# custom paged attn always supported on V0. On V1, requires sliding window
# disabled due to observed numerical discrepancy.
if ON_GFX9:
if _ON_GFX9:
return (
(sliding_window == 0 or sliding_window == (-1, -1))
and (qtype == torch.half or qtype == torch.bfloat16)
@@ -183,7 +179,7 @@ def use_rocm_custom_paged_attention(
else:
return (
ON_GFX11_GFX12
_ON_GFX1X
and (sliding_window == 0 or sliding_window == (-1, -1))
and (qtype == torch.half or qtype == torch.bfloat16)
and head_size == 128
@@ -611,18 +607,16 @@ class RocmPlatform(Platform):
@classmethod
def supports_mx(cls) -> bool:
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
return any(gfx in gcn_arch for gfx in ["gfx95"])
return any(gfx in _GCN_ARCH for gfx in ["gfx95"])
@classmethod
def supports_fp8(cls) -> bool:
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
return any(gfx in gcn_arch for gfx in ["gfx94", "gfx95", "gfx12"])
return any(gfx in _GCN_ARCH for gfx in ["gfx94", "gfx95", "gfx12"])
@classmethod
def is_fp8_fnuz(cls) -> bool:
# only device 0 is checked, this assumes MI300 platforms are homogeneous
return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
return "gfx94" in _GCN_ARCH
@classmethod
def fp8_dtype(cls) -> torch.dtype:
@@ -634,9 +628,7 @@ class RocmPlatform(Platform):
@classmethod
def use_custom_allreduce(cls) -> bool:
# We only enable custom allreduce for MI300 series
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
supported_archs = ["gfx94", "gfx95"]
return any(gfx in gcn_arch for gfx in supported_archs)
return any(gfx in _GCN_ARCH for gfx in ["gfx94", "gfx95"])
@classmethod
def opaque_attention_op(cls) -> bool:
@@ -644,7 +636,7 @@ class RocmPlatform(Platform):
@classmethod
def is_navi(cls) -> bool:
return "gfx1" in torch.cuda.get_device_properties(0).gcnArchName
return "gfx1" in _GCN_ARCH
@classmethod
def get_static_graph_wrapper_cls(cls) -> str: