[ROCm][Bugfix] Resolve Dynamo tracing crash from amdsmi calls in on_gfx* arch detection (#34108)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-02-09 22:50:20 -06:00
committed by GitHub
parent 047a457fa4
commit 4cde2e0159

View File

@@ -101,12 +101,10 @@ def _query_gcn_arch_from_amdsmi() -> str:
raise RuntimeError("amdsmi did not return valid GCN arch") raise RuntimeError("amdsmi did not return valid GCN arch")
@cache def _get_gcn_arch() -> str:
def _get_gcn_arch_via_amdsmi() -> str:
""" """
Get the GCN architecture name using amdsmi instead of torch.cuda. Get GCN arch via amdsmi (no CUDA init), falling back to torch.cuda.
This avoids initializing CUDA, which is important for Ray workers Called once at module level; result stored in _GCN_ARCH.
that need to set CUDA_VISIBLE_DEVICES after importing vLLM.
""" """
try: try:
return _query_gcn_arch_from_amdsmi() return _query_gcn_arch_from_amdsmi()
@@ -121,34 +119,36 @@ def _get_gcn_arch_via_amdsmi() -> str:
return torch.cuda.get_device_properties("cuda").gcnArchName return torch.cuda.get_device_properties("cuda").gcnArchName
@cache # Resolve once at module load. Uses amdsmi (no CUDA init) so Ray workers
# can still set CUDA_VISIBLE_DEVICES after import.
# _GCN_ARCH is a plain str and the _ON_* flags are plain Python bools — fully torch.compile/Dynamo safe.
_GCN_ARCH = _get_gcn_arch()
_ON_GFX1X = any(arch in _GCN_ARCH for arch in ["gfx11", "gfx12"])
_ON_MI3XX = any(arch in _GCN_ARCH for arch in ["gfx942", "gfx950"])
_ON_GFX9 = any(arch in _GCN_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
_ON_GFX942 = "gfx942" in _GCN_ARCH
_ON_GFX950 = "gfx950" in _GCN_ARCH
def on_gfx1x() -> bool: def on_gfx1x() -> bool:
GPU_ARCH = _get_gcn_arch_via_amdsmi() return _ON_GFX1X
return any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
@cache
def on_mi3xx() -> bool: def on_mi3xx() -> bool:
GPU_ARCH = _get_gcn_arch_via_amdsmi() return _ON_MI3XX
return any(arch in GPU_ARCH for arch in ["gfx942", "gfx950"])
@cache
def on_gfx9() -> bool: def on_gfx9() -> bool:
GPU_ARCH = _get_gcn_arch_via_amdsmi() return _ON_GFX9
return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
@cache
def on_gfx942() -> bool: def on_gfx942() -> bool:
GPU_ARCH = _get_gcn_arch_via_amdsmi() return _ON_GFX942
return any(arch in GPU_ARCH for arch in ["gfx942"])
@cache
def on_gfx950() -> bool: def on_gfx950() -> bool:
GPU_ARCH = _get_gcn_arch_via_amdsmi() return _ON_GFX950
return any(arch in GPU_ARCH for arch in ["gfx950"])
@cache @cache
@@ -163,13 +163,9 @@ def use_rocm_custom_paged_attention(
alibi_slopes: torch.Tensor | None = None, alibi_slopes: torch.Tensor | None = None,
sinks: torch.Tensor | None = None, sinks: torch.Tensor | None = None,
) -> bool: ) -> bool:
GPU_ARCH = _get_gcn_arch_via_amdsmi()
ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
# custom paged attn always supported on V0. On V1, requires sliding window # custom paged attn always supported on V0. On V1, requires sliding window
# disabled due to observed numerical discrepancy. # disabled due to observed numerical discrepancy.
if ON_GFX9: if _ON_GFX9:
return ( return (
(sliding_window == 0 or sliding_window == (-1, -1)) (sliding_window == 0 or sliding_window == (-1, -1))
and (qtype == torch.half or qtype == torch.bfloat16) and (qtype == torch.half or qtype == torch.bfloat16)
@@ -183,7 +179,7 @@ def use_rocm_custom_paged_attention(
else: else:
return ( return (
ON_GFX11_GFX12 _ON_GFX1X
and (sliding_window == 0 or sliding_window == (-1, -1)) and (sliding_window == 0 or sliding_window == (-1, -1))
and (qtype == torch.half or qtype == torch.bfloat16) and (qtype == torch.half or qtype == torch.bfloat16)
and head_size == 128 and head_size == 128
@@ -611,18 +607,16 @@ class RocmPlatform(Platform):
@classmethod @classmethod
def supports_mx(cls) -> bool: def supports_mx(cls) -> bool:
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName return any(gfx in _GCN_ARCH for gfx in ["gfx95"])
return any(gfx in gcn_arch for gfx in ["gfx95"])
@classmethod @classmethod
def supports_fp8(cls) -> bool: def supports_fp8(cls) -> bool:
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName return any(gfx in _GCN_ARCH for gfx in ["gfx94", "gfx95", "gfx12"])
return any(gfx in gcn_arch for gfx in ["gfx94", "gfx95", "gfx12"])
@classmethod @classmethod
def is_fp8_fnuz(cls) -> bool: def is_fp8_fnuz(cls) -> bool:
# only device 0 is checked, this assumes MI300 platforms are homogeneous # only device 0 is checked, this assumes MI300 platforms are homogeneous
return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName return "gfx94" in _GCN_ARCH
@classmethod @classmethod
def fp8_dtype(cls) -> torch.dtype: def fp8_dtype(cls) -> torch.dtype:
@@ -634,9 +628,7 @@ class RocmPlatform(Platform):
@classmethod @classmethod
def use_custom_allreduce(cls) -> bool: def use_custom_allreduce(cls) -> bool:
# We only enable custom allreduce for MI300 series # We only enable custom allreduce for MI300 series
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName return any(gfx in _GCN_ARCH for gfx in ["gfx94", "gfx95"])
supported_archs = ["gfx94", "gfx95"]
return any(gfx in gcn_arch for gfx in supported_archs)
@classmethod @classmethod
def opaque_attention_op(cls) -> bool: def opaque_attention_op(cls) -> bool:
@@ -644,7 +636,7 @@ class RocmPlatform(Platform):
@classmethod @classmethod
def is_navi(cls) -> bool: def is_navi(cls) -> bool:
return "gfx1" in torch.cuda.get_device_properties(0).gcnArchName return "gfx1" in _GCN_ARCH
@classmethod @classmethod
def get_static_graph_wrapper_cls(cls) -> str: def get_static_graph_wrapper_cls(cls) -> str: