[ROCm][Bugfix] Resolve Dynamo tracing crash from amdsmi calls in on_gfx* arch detection (#34108)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -101,12 +101,10 @@ def _query_gcn_arch_from_amdsmi() -> str:
|
|||||||
raise RuntimeError("amdsmi did not return valid GCN arch")
|
raise RuntimeError("amdsmi did not return valid GCN arch")
|
||||||
|
|
||||||
|
|
||||||
@cache
|
def _get_gcn_arch() -> str:
|
||||||
def _get_gcn_arch_via_amdsmi() -> str:
|
|
||||||
"""
|
"""
|
||||||
Get the GCN architecture name using amdsmi instead of torch.cuda.
|
Get GCN arch via amdsmi (no CUDA init), fallback to torch.cuda.
|
||||||
This avoids initializing CUDA, which is important for Ray workers
|
Called once at module level; result stored in _GCN_ARCH.
|
||||||
that need to set CUDA_VISIBLE_DEVICES after importing vLLM.
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return _query_gcn_arch_from_amdsmi()
|
return _query_gcn_arch_from_amdsmi()
|
||||||
@@ -121,34 +119,36 @@ def _get_gcn_arch_via_amdsmi() -> str:
|
|||||||
return torch.cuda.get_device_properties("cuda").gcnArchName
|
return torch.cuda.get_device_properties("cuda").gcnArchName
|
||||||
|
|
||||||
|
|
||||||
@cache
|
# Resolve once at module load. Uses amdsmi (no CUDA init) so Ray workers
|
||||||
|
# can still set CUDA_VISIBLE_DEVICES after import.
|
||||||
|
# These are plain Python bools — fully torch.compile/Dynamo safe.
|
||||||
|
_GCN_ARCH = _get_gcn_arch()
|
||||||
|
|
||||||
|
_ON_GFX1X = any(arch in _GCN_ARCH for arch in ["gfx11", "gfx12"])
|
||||||
|
_ON_MI3XX = any(arch in _GCN_ARCH for arch in ["gfx942", "gfx950"])
|
||||||
|
_ON_GFX9 = any(arch in _GCN_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
|
||||||
|
_ON_GFX942 = "gfx942" in _GCN_ARCH
|
||||||
|
_ON_GFX950 = "gfx950" in _GCN_ARCH
|
||||||
|
|
||||||
|
|
||||||
def on_gfx1x() -> bool:
|
def on_gfx1x() -> bool:
|
||||||
GPU_ARCH = _get_gcn_arch_via_amdsmi()
|
return _ON_GFX1X
|
||||||
return any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
|
|
||||||
|
|
||||||
|
|
||||||
@cache
|
|
||||||
def on_mi3xx() -> bool:
|
def on_mi3xx() -> bool:
|
||||||
GPU_ARCH = _get_gcn_arch_via_amdsmi()
|
return _ON_MI3XX
|
||||||
return any(arch in GPU_ARCH for arch in ["gfx942", "gfx950"])
|
|
||||||
|
|
||||||
|
|
||||||
@cache
|
|
||||||
def on_gfx9() -> bool:
|
def on_gfx9() -> bool:
|
||||||
GPU_ARCH = _get_gcn_arch_via_amdsmi()
|
return _ON_GFX9
|
||||||
return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
|
|
||||||
|
|
||||||
|
|
||||||
@cache
|
|
||||||
def on_gfx942() -> bool:
|
def on_gfx942() -> bool:
|
||||||
GPU_ARCH = _get_gcn_arch_via_amdsmi()
|
return _ON_GFX942
|
||||||
return any(arch in GPU_ARCH for arch in ["gfx942"])
|
|
||||||
|
|
||||||
|
|
||||||
@cache
|
|
||||||
def on_gfx950() -> bool:
|
def on_gfx950() -> bool:
|
||||||
GPU_ARCH = _get_gcn_arch_via_amdsmi()
|
return _ON_GFX950
|
||||||
return any(arch in GPU_ARCH for arch in ["gfx950"])
|
|
||||||
|
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
@@ -163,13 +163,9 @@ def use_rocm_custom_paged_attention(
|
|||||||
alibi_slopes: torch.Tensor | None = None,
|
alibi_slopes: torch.Tensor | None = None,
|
||||||
sinks: torch.Tensor | None = None,
|
sinks: torch.Tensor | None = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
GPU_ARCH = _get_gcn_arch_via_amdsmi()
|
|
||||||
ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
|
|
||||||
ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
|
|
||||||
|
|
||||||
# custom paged attn always supported on V0. On V1, requires sliding window
|
# custom paged attn always supported on V0. On V1, requires sliding window
|
||||||
# disabled due to observed numerical discrepancy.
|
# disabled due to observed numerical discrepancy.
|
||||||
if ON_GFX9:
|
if _ON_GFX9:
|
||||||
return (
|
return (
|
||||||
(sliding_window == 0 or sliding_window == (-1, -1))
|
(sliding_window == 0 or sliding_window == (-1, -1))
|
||||||
and (qtype == torch.half or qtype == torch.bfloat16)
|
and (qtype == torch.half or qtype == torch.bfloat16)
|
||||||
@@ -183,7 +179,7 @@ def use_rocm_custom_paged_attention(
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
return (
|
return (
|
||||||
ON_GFX11_GFX12
|
_ON_GFX1X
|
||||||
and (sliding_window == 0 or sliding_window == (-1, -1))
|
and (sliding_window == 0 or sliding_window == (-1, -1))
|
||||||
and (qtype == torch.half or qtype == torch.bfloat16)
|
and (qtype == torch.half or qtype == torch.bfloat16)
|
||||||
and head_size == 128
|
and head_size == 128
|
||||||
@@ -611,18 +607,16 @@ class RocmPlatform(Platform):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def supports_mx(cls) -> bool:
|
def supports_mx(cls) -> bool:
|
||||||
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
|
return any(gfx in _GCN_ARCH for gfx in ["gfx95"])
|
||||||
return any(gfx in gcn_arch for gfx in ["gfx95"])
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def supports_fp8(cls) -> bool:
|
def supports_fp8(cls) -> bool:
|
||||||
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
|
return any(gfx in _GCN_ARCH for gfx in ["gfx94", "gfx95", "gfx12"])
|
||||||
return any(gfx in gcn_arch for gfx in ["gfx94", "gfx95", "gfx12"])
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_fp8_fnuz(cls) -> bool:
|
def is_fp8_fnuz(cls) -> bool:
|
||||||
# only device 0 is checked, this assumes MI300 platforms are homogeneous
|
# only device 0 is checked, this assumes MI300 platforms are homogeneous
|
||||||
return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
|
return "gfx94" in _GCN_ARCH
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def fp8_dtype(cls) -> torch.dtype:
|
def fp8_dtype(cls) -> torch.dtype:
|
||||||
@@ -634,9 +628,7 @@ class RocmPlatform(Platform):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def use_custom_allreduce(cls) -> bool:
|
def use_custom_allreduce(cls) -> bool:
|
||||||
# We only enable custom allreduce for MI300 series
|
# We only enable custom allreduce for MI300 series
|
||||||
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
|
return any(gfx in _GCN_ARCH for gfx in ["gfx94", "gfx95"])
|
||||||
supported_archs = ["gfx94", "gfx95"]
|
|
||||||
return any(gfx in gcn_arch for gfx in supported_archs)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def opaque_attention_op(cls) -> bool:
|
def opaque_attention_op(cls) -> bool:
|
||||||
@@ -644,7 +636,7 @@ class RocmPlatform(Platform):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_navi(cls) -> bool:
|
def is_navi(cls) -> bool:
|
||||||
return "gfx1" in torch.cuda.get_device_properties(0).gcnArchName
|
return "gfx1" in _GCN_ARCH
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_static_graph_wrapper_cls(cls) -> str:
|
def get_static_graph_wrapper_cls(cls) -> str:
|
||||||
|
|||||||
Reference in New Issue
Block a user