From 4cde2e015944495e6bd650a4415cfb342bd73cfb Mon Sep 17 00:00:00 2001
From: Andreas Karatzas
Date: Mon, 9 Feb 2026 22:50:20 -0600
Subject: [PATCH] [ROCm][Bugfix] Resolve Dynamo tracing crash from amdsmi calls in on_gfx* arch detection (#34108)

Signed-off-by: Andreas Karatzas
---
 vllm/platforms/rocm.py | 62 ++++++++++++++++++------------------------
 1 file changed, 27 insertions(+), 35 deletions(-)

diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 2545e4620..b463c80a1 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -101,12 +101,10 @@ def _query_gcn_arch_from_amdsmi() -> str:
     raise RuntimeError("amdsmi did not return valid GCN arch")
 
 
-@cache
-def _get_gcn_arch_via_amdsmi() -> str:
+def _get_gcn_arch() -> str:
     """
-    Get the GCN architecture name using amdsmi instead of torch.cuda.
-    This avoids initializing CUDA, which is important for Ray workers
-    that need to set CUDA_VISIBLE_DEVICES after importing vLLM.
+    Get GCN arch via amdsmi (no CUDA init), fallback to torch.cuda.
+    Called once at module level; result stored in _GCN_ARCH.
     """
     try:
         return _query_gcn_arch_from_amdsmi()
@@ -121,34 +119,36 @@ def _get_gcn_arch_via_amdsmi() -> str:
     return torch.cuda.get_device_properties("cuda").gcnArchName
 
 
-@cache
+# Resolve once at module load. Uses amdsmi (no CUDA init) so Ray workers
+# can still set CUDA_VISIBLE_DEVICES after import.
+# These are plain Python bools — fully torch.compile/Dynamo safe.
+_GCN_ARCH = _get_gcn_arch()
+
+_ON_GFX1X = any(arch in _GCN_ARCH for arch in ["gfx11", "gfx12"])
+_ON_MI3XX = any(arch in _GCN_ARCH for arch in ["gfx942", "gfx950"])
+_ON_GFX9 = any(arch in _GCN_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
+_ON_GFX942 = "gfx942" in _GCN_ARCH
+_ON_GFX950 = "gfx950" in _GCN_ARCH
+
+
 def on_gfx1x() -> bool:
-    GPU_ARCH = _get_gcn_arch_via_amdsmi()
-    return any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
+    return _ON_GFX1X
 
 
-@cache
 def on_mi3xx() -> bool:
-    GPU_ARCH = _get_gcn_arch_via_amdsmi()
-    return any(arch in GPU_ARCH for arch in ["gfx942", "gfx950"])
+    return _ON_MI3XX
 
 
-@cache
 def on_gfx9() -> bool:
-    GPU_ARCH = _get_gcn_arch_via_amdsmi()
-    return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
+    return _ON_GFX9
 
 
-@cache
 def on_gfx942() -> bool:
-    GPU_ARCH = _get_gcn_arch_via_amdsmi()
-    return any(arch in GPU_ARCH for arch in ["gfx942"])
+    return _ON_GFX942
 
 
-@cache
 def on_gfx950() -> bool:
-    GPU_ARCH = _get_gcn_arch_via_amdsmi()
-    return any(arch in GPU_ARCH for arch in ["gfx950"])
+    return _ON_GFX950
 
 
 @cache
@@ -163,13 +163,9 @@ def use_rocm_custom_paged_attention(
     alibi_slopes: torch.Tensor | None = None,
     sinks: torch.Tensor | None = None,
 ) -> bool:
-    GPU_ARCH = _get_gcn_arch_via_amdsmi()
-    ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
-    ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
-
     # custom paged attn always supported on V0. On V1, requires sliding window
     # disabled due to observed numerical discrepancy.
-    if ON_GFX9:
+    if _ON_GFX9:
         return (
             (sliding_window == 0 or sliding_window == (-1, -1))
             and (qtype == torch.half or qtype == torch.bfloat16)
@@ -183,7 +179,7 @@ def use_rocm_custom_paged_attention(
 
     else:
         return (
-            ON_GFX11_GFX12
+            _ON_GFX1X
             and (sliding_window == 0 or sliding_window == (-1, -1))
             and (qtype == torch.half or qtype == torch.bfloat16)
             and head_size == 128
@@ -611,18 +607,16 @@ class RocmPlatform(Platform):
 
     @classmethod
     def supports_mx(cls) -> bool:
-        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
-        return any(gfx in gcn_arch for gfx in ["gfx95"])
+        return any(gfx in _GCN_ARCH for gfx in ["gfx95"])
 
     @classmethod
     def supports_fp8(cls) -> bool:
-        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
-        return any(gfx in gcn_arch for gfx in ["gfx94", "gfx95", "gfx12"])
+        return any(gfx in _GCN_ARCH for gfx in ["gfx94", "gfx95", "gfx12"])
 
     @classmethod
     def is_fp8_fnuz(cls) -> bool:
         # only device 0 is checked, this assumes MI300 platforms are homogeneous
-        return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
+        return "gfx94" in _GCN_ARCH
 
     @classmethod
     def fp8_dtype(cls) -> torch.dtype:
@@ -634,9 +628,7 @@ class RocmPlatform(Platform):
     @classmethod
     def use_custom_allreduce(cls) -> bool:
         # We only enable custom allreduce for MI300 series
-        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
-        supported_archs = ["gfx94", "gfx95"]
-        return any(gfx in gcn_arch for gfx in supported_archs)
+        return any(gfx in _GCN_ARCH for gfx in ["gfx94", "gfx95"])
 
     @classmethod
     def opaque_attention_op(cls) -> bool:
@@ -644,7 +636,7 @@ class RocmPlatform(Platform):
 
     @classmethod
     def is_navi(cls) -> bool:
-        return "gfx1" in torch.cuda.get_device_properties(0).gcnArchName
+        return "gfx1" in _GCN_ARCH
 
     @classmethod
     def get_static_graph_wrapper_cls(cls) -> str: