Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Authored by Harry Mellor on 2025-10-05 15:06:22 +01:00, committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions
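
Most of the churn below is mechanical. ruff format is a Black-compatible formatter, so the visible changes follow Black's rules: continuation lines that yapf aligned under an opening parenthesis become 4-space hanging indents, single-quoted strings become double-quoted, long boolean chains are broken one clause per line, and a "magic" trailing comma forces a collection or import list onto one line per element; ruff's isort-compatible rules (the I group) replace isort itself. A minimal sketch of the import rule, using stdlib names rather than lines from this diff:

# yapf + isort style: continuation lines aligned under the opening parenthesis.
from os.path import (abspath, basename, dirname, exists, getsize, isdir,
                     isfile, join)

# ruff format style: the magic trailing comma keeps one name per line.
from os.path import (
    abspath,
    basename,
    dirname,
    exists,
    getsize,
    isdir,
    isfile,
    join,
)

# Both forms import the same names; only the layout differs.
print(isdir(dirname(abspath(__file__))))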

@@ -25,9 +25,14 @@ else:
 logger = init_logger(__name__)
 try:
-    from amdsmi import (AmdSmiException, amdsmi_get_gpu_asic_info,
-                        amdsmi_get_processor_handles, amdsmi_init,
-                        amdsmi_shut_down, amdsmi_topo_get_link_type)
+    from amdsmi import (
+        AmdSmiException,
+        amdsmi_get_gpu_asic_info,
+        amdsmi_get_processor_handles,
+        amdsmi_init,
+        amdsmi_shut_down,
+        amdsmi_topo_get_link_type,
+    )
 except ImportError as e:
     logger.warning("Failed to import from amdsmi with %r", e)
@@ -47,24 +52,24 @@ _ROCM_UNSUPPORTED_MODELS: list[str] = []
 # Models partially supported by ROCm.
 # Architecture -> Reason.
-_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
-                    "Triton flash attention. For half-precision SWA support, "
-                    "please use CK flash attention by setting "
-                    "`VLLM_USE_TRITON_FLASH_ATTN=0`")
+_ROCM_SWA_REASON = (
+    "Sliding window attention (SWA) is not yet supported in "
+    "Triton flash attention. For half-precision SWA support, "
+    "please use CK flash attention by setting "
+    "`VLLM_USE_TRITON_FLASH_ATTN=0`"
+)
 _ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {
-    "Qwen2ForCausalLM":
-    _ROCM_SWA_REASON,
-    "MistralForCausalLM":
-    _ROCM_SWA_REASON,
-    "MixtralForCausalLM":
-    _ROCM_SWA_REASON,
-    "PaliGemmaForConditionalGeneration":
-    ("ROCm flash attention does not yet "
-     "fully support 32-bit precision on PaliGemma"),
-    "Phi3VForCausalLM":
-    ("ROCm Triton flash attention may run into compilation errors due to "
-     "excessive use of shared memory. If this happens, disable Triton FA "
-     "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
+    "Qwen2ForCausalLM": _ROCM_SWA_REASON,
+    "MistralForCausalLM": _ROCM_SWA_REASON,
+    "MixtralForCausalLM": _ROCM_SWA_REASON,
+    "PaliGemmaForConditionalGeneration": (
+        "ROCm flash attention does not yet fully support 32-bit precision on PaliGemma"
+    ),
+    "Phi3VForCausalLM": (
+        "ROCm Triton flash attention may run into compilation errors due to "
+        "excessive use of shared memory. If this happens, disable Triton FA "
+        "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`"
+    ),
 }
 _ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
     "0x74a0": "AMD_Instinct_MI300A",
@@ -91,7 +96,6 @@ if "HIP_VISIBLE_DEVICES" in os.environ:
 def with_amdsmi_context(fn):
-
     @wraps(fn)
     def wrapper(*args, **kwargs):
         amdsmi_init()
@@ -129,16 +133,16 @@ def on_gfx950() -> bool:
 @cache
 def use_rocm_custom_paged_attention(
-        qtype: torch.dtype,
-        head_size: int,
-        block_size: int,
-        gqa_ratio: int,
-        max_seq_len: int,
-        sliding_window: int,
-        kv_cache_dtype: str,
-        alibi_slopes: Optional[torch.Tensor] = None,
-        sinks: Optional[torch.Tensor] = None) -> bool:
+    qtype: torch.dtype,
+    head_size: int,
+    block_size: int,
+    gqa_ratio: int,
+    max_seq_len: int,
+    sliding_window: int,
+    kv_cache_dtype: str,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    sinks: Optional[torch.Tensor] = None,
+) -> bool:
     GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
     ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
     ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
@@ -146,26 +150,36 @@ def use_rocm_custom_paged_attention(
     # custom paged attn always supported on V0. On V1, requires sliding window
     # disabled due to observed numerical discrepancy.
     if ON_GFX9:
-        return ((not envs.VLLM_USE_V1 or sliding_window == 0
-                 or sliding_window == (-1, -1))
-                and (qtype == torch.half or qtype == torch.bfloat16)
-                and (head_size == 64 or head_size == 128)
-                and (block_size == 16 or block_size == 32)
-                and (gqa_ratio >= 1 and gqa_ratio <= 16)
-                and max_seq_len <= 128 * 1024
-                and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
-                and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
-                         and envs.VLLM_ROCM_USE_AITER) and sinks is None)
+        return (
+            (not envs.VLLM_USE_V1 or sliding_window == 0 or sliding_window == (-1, -1))
+            and (qtype == torch.half or qtype == torch.bfloat16)
+            and (head_size == 64 or head_size == 128)
+            and (block_size == 16 or block_size == 32)
+            and (gqa_ratio >= 1 and gqa_ratio <= 16)
+            and max_seq_len <= 128 * 1024
+            and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
+            and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN and envs.VLLM_ROCM_USE_AITER)
+            and sinks is None
+        )
     else:
-        return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0
-                                    or sliding_window == (-1, -1))
-                and (qtype == torch.half or qtype == torch.bfloat16)
-                and head_size == 128 and block_size == 16
-                and (gqa_ratio >= 3 and gqa_ratio <= 16)
-                and max_seq_len <= 128 * 1024 and alibi_slopes is None
-                and kv_cache_dtype == "auto"
-                and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN and sinks is None)
+        return (
+            ON_GFX11_GFX12
+            and (
+                not envs.VLLM_USE_V1
+                or sliding_window == 0
+                or sliding_window == (-1, -1)
+            )
+            and (qtype == torch.half or qtype == torch.bfloat16)
+            and head_size == 128
+            and block_size == 16
+            and (gqa_ratio >= 3 and gqa_ratio <= 16)
+            and max_seq_len <= 128 * 1024
+            and alibi_slopes is None
+            and kv_cache_dtype == "auto"
+            and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
+            and sinks is None
+        )
 class RocmPlatform(Platform):
@@ -179,86 +193,112 @@ class RocmPlatform(Platform):
     device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
     supported_quantization: list[str] = [
-        "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
-        "quark", "ptpc_fp8", "mxfp4", "petit_nvfp4", "torchao"
+        "awq",
+        "gptq",
+        "fp8",
+        "compressed-tensors",
+        "fbgemm_fp8",
+        "gguf",
+        "quark",
+        "ptpc_fp8",
+        "mxfp4",
+        "petit_nvfp4",
+        "torchao",
     ]
     @classmethod
-    def get_vit_attn_backend(cls, head_size: int,
-                             dtype: torch.dtype) -> "_Backend":
+    def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
         from vllm.attention.backends.registry import _Backend
-        if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA
-                and on_gfx9()):
+        if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
             return _Backend.ROCM_AITER_FA
         if on_gfx9():
             return _Backend.FLASH_ATTN
         return _Backend.TORCH_SDPA
     @classmethod
-    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
-                             kv_cache_dtype, block_size, use_v1, use_mla,
-                             has_sink, use_sparse) -> str:
+    def get_attn_backend_cls(
+        cls,
+        selected_backend,
+        head_size,
+        dtype,
+        kv_cache_dtype,
+        block_size,
+        use_v1,
+        use_mla,
+        has_sink,
+        use_sparse,
+    ) -> str:
         from vllm.attention.backends.registry import _Backend
         if use_sparse:
-            raise NotImplementedError(
-                "Sparse Attention is not supported on ROCm.")
+            raise NotImplementedError("Sparse Attention is not supported on ROCm.")
         if use_mla:
             if not use_v1:
                 raise RuntimeError(
                     "MLA attention backends require the V1 engine. "
-                    "Set VLLM_USE_V1=1 to enable them.")
+                    "Set VLLM_USE_V1=1 to enable them."
+                )
             from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
-                is_aiter_mla_enabled)
+                is_aiter_mla_enabled,
+            )
             if selected_backend is None:
-                selected_backend = (_Backend.ROCM_AITER_MLA if
-                                    is_aiter_mla_enabled() or block_size == 1
-                                    else _Backend.TRITON_MLA)
+                selected_backend = (
+                    _Backend.ROCM_AITER_MLA
+                    if is_aiter_mla_enabled() or block_size == 1
+                    else _Backend.TRITON_MLA
+                )
             if selected_backend == _Backend.TRITON_MLA:
                 if block_size != 1:
                     logger.info_once("Using Triton MLA backend on V1 engine.")
-                    return ("vllm.v1.attention.backends.mla."
-                            "triton_mla.TritonMLABackend")
-                raise ValueError(
-                    f" The selected backend, {selected_backend.name},"
-                    f"does not support block size {block_size}.")
-            if selected_backend == _Backend.ROCM_AITER_MLA:
-                if block_size == 1:
-                    logger.info("Using AITER MLA backend on V1 engine.")
-                    return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"  # noqa: E501
-                raise ValueError(
-                    f" The selected backend, {selected_backend.name},"
-                    f"does not support block size {block_size}."
-                    "(currently only supports block size 1)")
+                    return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
+                raise ValueError(
+                    f" The selected backend, {selected_backend.name},"
+                    f"does not support block size {block_size}."
+                )
+            if selected_backend == _Backend.ROCM_AITER_MLA:
+                if block_size == 1:
+                    logger.info("Using AITER MLA backend on V1 engine.")
+                    return (
+                        "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"  # noqa: E501
+                    )
+                raise ValueError(
+                    f" The selected backend, {selected_backend.name},"
+                    f"does not support block size {block_size}."
+                    "(currently only supports block size 1)"
+                )
             raise ValueError(
                 f" The selected backend, {selected_backend.name},"
-                f"is not MLA type while requested for MLA backend.")
+                f"is not MLA type while requested for MLA backend."
+            )
         if envs.VLLM_USE_V1:
-            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA \
-                    and on_gfx9():
+            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
                 logger.info("Using Flash Attention backend on V1 engine.")
-                return ("vllm.v1.attention.backends."
-                        "rocm_aiter_fa.AiterFlashAttentionBackend")
-            elif (envs.VLLM_ROCM_USE_AITER and
-                  envs.VLLM_USE_AITER_UNIFIED_ATTENTION) or \
-                    envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION or \
-                    selected_backend == _Backend.ROCM_ATTN:
+                return (
+                    "vllm.v1.attention.backends."
+                    "rocm_aiter_fa.AiterFlashAttentionBackend"
+                )
+            elif (
+                (envs.VLLM_ROCM_USE_AITER and envs.VLLM_USE_AITER_UNIFIED_ATTENTION)
+                or envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
+                or selected_backend == _Backend.ROCM_ATTN
+            ):
                 # rocm specific backend, with aiter and/or
                 # triton prefix-prefill
                 logger.info("Using Rocm/Aiter Attention backend on V1 engine.")
-                return ("vllm.v1.attention.backends."
-                        "rocm_attn.RocmAttentionBackend")
+                return "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
             else:
                 # default case, using triton unified attention
                 logger.info("Using Triton Attention backend on V1 engine.")
-                return ("vllm.v1.attention.backends."
-                        "triton_attn.TritonAttentionBackend")
+                return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
         raise RuntimeError(
             "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
-            "to select a supported backend.")
+            "to select a supported backend."
+        )
     @classmethod
     def set_device(cls, device: torch.device) -> None:
@@ -269,9 +309,7 @@ class RocmPlatform(Platform):
     @classmethod
     @lru_cache(maxsize=8)
-    def get_device_capability(cls,
-                              device_id: int = 0
-                              ) -> Optional[DeviceCapability]:
+    def get_device_capability(cls, device_id: int = 0) -> Optional[DeviceCapability]:
         major, minor = torch.cuda.get_device_capability(device_id)
         return DeviceCapability(major=major, minor=minor)
@@ -281,21 +319,17 @@ class RocmPlatform(Platform):
         """
        Query if the set of gpus are fully connected by xgmi (1 hop)
        """
-        handles = [
-            amdsmi_get_processor_handles()[i] for i in physical_device_ids
-        ]
+        handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]
         for i, handle in enumerate(handles):
             for j, peer_handle in enumerate(handles):
                 if i < j:
                     try:
-                        link_type = amdsmi_topo_get_link_type(
-                            handle, peer_handle)
+                        link_type = amdsmi_topo_get_link_type(handle, peer_handle)
                         # type is 2 for XGMI
                         if link_type["hops"] != 1 or link_type["type"] != 2:
                             return False
                     except AmdSmiException as error:
-                        logger.error("AMD 1 hop XGMI detection failed.",
-                                     exc_info=error)
+                        logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
                         return False
         return True
@@ -326,8 +360,9 @@ class RocmPlatform(Platform):
         is_eager_execution = compilation_config == CUDAGraphMode.NONE
         use_v1 = envs.VLLM_USE_V1
-        use_aiter_rms_norm = envs.VLLM_ROCM_USE_AITER and \
-            envs.VLLM_ROCM_USE_AITER_RMSNORM
+        use_aiter_rms_norm = (
+            envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_RMSNORM
+        )
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
@@ -335,21 +370,28 @@ class RocmPlatform(Platform):
         if parallel_config.worker_cls == "auto":
             parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
         # Aiter rms norm perform best when CUDA Graph capture is enabled.
-        if (use_v1 and use_aiter_rms_norm and not is_eager_execution
-                and "-rms_norm" not in compilation_config.custom_ops):
+        if (
+            use_v1
+            and use_aiter_rms_norm
+            and not is_eager_execution
+            and "-rms_norm" not in compilation_config.custom_ops
+        ):
             compilation_config.custom_ops.append("+rms_norm")
     @classmethod
     def verify_model_arch(cls, model_arch: str) -> None:
         if model_arch in _ROCM_UNSUPPORTED_MODELS:
-            raise ValueError(f"Model architecture '{model_arch}' is not "
-                             "supported by ROCm for now.")
+            raise ValueError(
+                f"Model architecture '{model_arch}' is not supported by ROCm for now."
+            )
         if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
             msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
             logger.warning(
-                "Model architecture '%s' is partially "
-                "supported by ROCm: %s", model_arch, msg)
+                "Model architecture '%s' is partially supported by ROCm: %s",
+                model_arch,
+                msg,
+            )
     @classmethod
     def verify_quantization(cls, quant: str) -> None:
@@ -357,7 +399,8 @@
         if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
             logger.warning(
                 "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
-                " is not set, enabling VLLM_USE_TRITON_AWQ.")
+                " is not set, enabling VLLM_USE_TRITON_AWQ."
+            )
             envs.VLLM_USE_TRITON_AWQ = True
     @classmethod
@@ -365,16 +408,17 @@
         return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
     @classmethod
-    def get_current_memory_usage(cls,
-                                 device: Optional[torch.types.Device] = None
-                                 ) -> float:
+    def get_current_memory_usage(
+        cls, device: Optional[torch.types.Device] = None
+    ) -> float:
         torch.cuda.reset_peak_memory_stats(device)
-        return torch.cuda.mem_get_info(device)[1] - torch.cuda.mem_get_info(
-            device)[0]
+        return torch.cuda.mem_get_info(device)[1] - torch.cuda.mem_get_info(device)[0]
     @classmethod
     def get_device_communicator_cls(cls) -> str:
-        return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
+        return (
+            "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
+        )
     @classmethod
     def supports_mx(cls) -> bool:
@@ -384,12 +428,12 @@
     @classmethod
     def supports_fp8(cls) -> bool:
         gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
-        return any(gfx in gcn_arch for gfx in ['gfx94', 'gfx95', 'gfx12'])
+        return any(gfx in gcn_arch for gfx in ["gfx94", "gfx95", "gfx12"])
     @classmethod
     def is_fp8_fnuz(cls) -> bool:
         # only device 0 is checked, this assumes MI300 platforms are homogeneous
-        return 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName
+        return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
     @classmethod
     def fp8_dtype(cls) -> torch.dtype:
@@ -402,7 +446,7 @@
     def use_custom_allreduce(cls) -> bool:
         # We only enable custom allreduce for MI300 series
         gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
-        supported_archs = ['gfx94', 'gfx95']
+        supported_archs = ["gfx94", "gfx95"]
         return any(gfx in gcn_arch for gfx in supported_archs)
     @classmethod
@@ -411,12 +455,11 @@
     @classmethod
     def get_cu_count(cls, device_id: int = 0) -> int:
-        return torch.cuda.get_device_properties(
-            device_id).multi_processor_count
+        return torch.cuda.get_device_properties(device_id).multi_processor_count
     @classmethod
     def is_navi(cls) -> bool:
-        return 'gfx1' in torch.cuda.get_device_properties(0).gcnArchName
+        return "gfx1" in torch.cuda.get_device_properties(0).gcnArchName
     @classmethod
     def get_static_graph_wrapper_cls(cls) -> str:
@@ -442,8 +485,9 @@
         backend_options = ProcessGroupNCCL.Options()
         backend_options._timeout = timeout
-        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
-                                         backend_options)
+        backend_class = ProcessGroupNCCL(
+            prefix_store, group_rank, group_size, backend_options
+        )
         backend_type = ProcessGroup.BackendType.NCCL
         device = torch.device("cuda")
         pg._set_default_backend(backend_type)
@@ -457,8 +501,9 @@
         return cuda_device_count_stateless()
     @classmethod
-    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
-                                    model_config: "ModelConfig") -> bool:
+    def is_kv_cache_dtype_supported(
+        cls, kv_cache_dtype: str, model_config: "ModelConfig"
+    ) -> bool:
         return True
     @classmethod
@@ -479,7 +524,8 @@
                 "with compute capability of at least 8.0. "
                 f"Your {gpu_name} GPU {compute_str}. "
                 "You can use float16 instead by explicitly setting the "
-                "`dtype` flag in CLI, for example: --dtype=half.")
+                "`dtype` flag in CLI, for example: --dtype=half."
+            )
     @classmethod
     def support_hybrid_kv_cache(cls) -> bool: