[ROCm][Bugfix] Resolve Dynamo tracing crash from amdsmi calls in on_gfx* arch detection (#34108)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -101,12 +101,10 @@ def _query_gcn_arch_from_amdsmi() -> str:
|
||||
raise RuntimeError("amdsmi did not return valid GCN arch")
|
||||
|
||||
|
||||
@cache
|
||||
def _get_gcn_arch_via_amdsmi() -> str:
|
||||
def _get_gcn_arch() -> str:
|
||||
"""
|
||||
Get the GCN architecture name using amdsmi instead of torch.cuda.
|
||||
This avoids initializing CUDA, which is important for Ray workers
|
||||
that need to set CUDA_VISIBLE_DEVICES after importing vLLM.
|
||||
Get GCN arch via amdsmi (no CUDA init), fallback to torch.cuda.
|
||||
Called once at module level; result stored in _GCN_ARCH.
|
||||
"""
|
||||
try:
|
||||
return _query_gcn_arch_from_amdsmi()
|
||||
@@ -121,34 +119,36 @@ def _get_gcn_arch_via_amdsmi() -> str:
|
||||
return torch.cuda.get_device_properties("cuda").gcnArchName
|
||||
|
||||
|
||||
@cache
|
||||
# Resolve once at module load. Uses amdsmi (no CUDA init) so Ray workers
|
||||
# can still set CUDA_VISIBLE_DEVICES after import.
|
||||
# These are plain Python bools — fully torch.compile/Dynamo safe.
|
||||
_GCN_ARCH = _get_gcn_arch()
|
||||
|
||||
_ON_GFX1X = any(arch in _GCN_ARCH for arch in ["gfx11", "gfx12"])
|
||||
_ON_MI3XX = any(arch in _GCN_ARCH for arch in ["gfx942", "gfx950"])
|
||||
_ON_GFX9 = any(arch in _GCN_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
|
||||
_ON_GFX942 = "gfx942" in _GCN_ARCH
|
||||
_ON_GFX950 = "gfx950" in _GCN_ARCH
|
||||
|
||||
|
||||
def on_gfx1x() -> bool:
|
||||
GPU_ARCH = _get_gcn_arch_via_amdsmi()
|
||||
return any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
|
||||
return _ON_GFX1X
|
||||
|
||||
|
||||
@cache
|
||||
def on_mi3xx() -> bool:
|
||||
GPU_ARCH = _get_gcn_arch_via_amdsmi()
|
||||
return any(arch in GPU_ARCH for arch in ["gfx942", "gfx950"])
|
||||
return _ON_MI3XX
|
||||
|
||||
|
||||
@cache
|
||||
def on_gfx9() -> bool:
|
||||
GPU_ARCH = _get_gcn_arch_via_amdsmi()
|
||||
return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
|
||||
return _ON_GFX9
|
||||
|
||||
|
||||
@cache
|
||||
def on_gfx942() -> bool:
|
||||
GPU_ARCH = _get_gcn_arch_via_amdsmi()
|
||||
return any(arch in GPU_ARCH for arch in ["gfx942"])
|
||||
return _ON_GFX942
|
||||
|
||||
|
||||
@cache
|
||||
def on_gfx950() -> bool:
|
||||
GPU_ARCH = _get_gcn_arch_via_amdsmi()
|
||||
return any(arch in GPU_ARCH for arch in ["gfx950"])
|
||||
return _ON_GFX950
|
||||
|
||||
|
||||
@cache
|
||||
@@ -163,13 +163,9 @@ def use_rocm_custom_paged_attention(
|
||||
alibi_slopes: torch.Tensor | None = None,
|
||||
sinks: torch.Tensor | None = None,
|
||||
) -> bool:
|
||||
GPU_ARCH = _get_gcn_arch_via_amdsmi()
|
||||
ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
|
||||
ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
|
||||
|
||||
# custom paged attn always supported on V0. On V1, requires sliding window
|
||||
# disabled due to observed numerical discrepancy.
|
||||
if ON_GFX9:
|
||||
if _ON_GFX9:
|
||||
return (
|
||||
(sliding_window == 0 or sliding_window == (-1, -1))
|
||||
and (qtype == torch.half or qtype == torch.bfloat16)
|
||||
@@ -183,7 +179,7 @@ def use_rocm_custom_paged_attention(
|
||||
|
||||
else:
|
||||
return (
|
||||
ON_GFX11_GFX12
|
||||
_ON_GFX1X
|
||||
and (sliding_window == 0 or sliding_window == (-1, -1))
|
||||
and (qtype == torch.half or qtype == torch.bfloat16)
|
||||
and head_size == 128
|
||||
@@ -611,18 +607,16 @@ class RocmPlatform(Platform):
|
||||
|
||||
@classmethod
|
||||
def supports_mx(cls) -> bool:
|
||||
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
|
||||
return any(gfx in gcn_arch for gfx in ["gfx95"])
|
||||
return any(gfx in _GCN_ARCH for gfx in ["gfx95"])
|
||||
|
||||
@classmethod
|
||||
def supports_fp8(cls) -> bool:
|
||||
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
|
||||
return any(gfx in gcn_arch for gfx in ["gfx94", "gfx95", "gfx12"])
|
||||
return any(gfx in _GCN_ARCH for gfx in ["gfx94", "gfx95", "gfx12"])
|
||||
|
||||
@classmethod
|
||||
def is_fp8_fnuz(cls) -> bool:
|
||||
# only device 0 is checked, this assumes MI300 platforms are homogeneous
|
||||
return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
|
||||
return "gfx94" in _GCN_ARCH
|
||||
|
||||
@classmethod
|
||||
def fp8_dtype(cls) -> torch.dtype:
|
||||
@@ -634,9 +628,7 @@ class RocmPlatform(Platform):
|
||||
@classmethod
|
||||
def use_custom_allreduce(cls) -> bool:
|
||||
# We only enable custom allreduce for MI300 series
|
||||
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
|
||||
supported_archs = ["gfx94", "gfx95"]
|
||||
return any(gfx in gcn_arch for gfx in supported_archs)
|
||||
return any(gfx in _GCN_ARCH for gfx in ["gfx94", "gfx95"])
|
||||
|
||||
@classmethod
|
||||
def opaque_attention_op(cls) -> bool:
|
||||
@@ -644,7 +636,7 @@ class RocmPlatform(Platform):
|
||||
|
||||
@classmethod
|
||||
def is_navi(cls) -> bool:
|
||||
return "gfx1" in torch.cuda.get_device_properties(0).gcnArchName
|
||||
return "gfx1" in _GCN_ARCH
|
||||
|
||||
@classmethod
|
||||
def get_static_graph_wrapper_cls(cls) -> str:
|
||||
|
||||
Reference in New Issue
Block a user