Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -25,9 +25,14 @@ else:
|
||||
logger = init_logger(__name__)
|
||||
|
||||
try:
|
||||
from amdsmi import (AmdSmiException, amdsmi_get_gpu_asic_info,
|
||||
amdsmi_get_processor_handles, amdsmi_init,
|
||||
amdsmi_shut_down, amdsmi_topo_get_link_type)
|
||||
from amdsmi import (
|
||||
AmdSmiException,
|
||||
amdsmi_get_gpu_asic_info,
|
||||
amdsmi_get_processor_handles,
|
||||
amdsmi_init,
|
||||
amdsmi_shut_down,
|
||||
amdsmi_topo_get_link_type,
|
||||
)
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import from amdsmi with %r", e)
|
||||
|
||||
@@ -47,24 +52,24 @@ _ROCM_UNSUPPORTED_MODELS: list[str] = []
|
||||
|
||||
# Models partially supported by ROCm.
|
||||
# Architecture -> Reason.
|
||||
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
|
||||
"Triton flash attention. For half-precision SWA support, "
|
||||
"please use CK flash attention by setting "
|
||||
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
|
||||
_ROCM_SWA_REASON = (
|
||||
"Sliding window attention (SWA) is not yet supported in "
|
||||
"Triton flash attention. For half-precision SWA support, "
|
||||
"please use CK flash attention by setting "
|
||||
"`VLLM_USE_TRITON_FLASH_ATTN=0`"
|
||||
)
|
||||
_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {
|
||||
"Qwen2ForCausalLM":
|
||||
_ROCM_SWA_REASON,
|
||||
"MistralForCausalLM":
|
||||
_ROCM_SWA_REASON,
|
||||
"MixtralForCausalLM":
|
||||
_ROCM_SWA_REASON,
|
||||
"PaliGemmaForConditionalGeneration":
|
||||
("ROCm flash attention does not yet "
|
||||
"fully support 32-bit precision on PaliGemma"),
|
||||
"Phi3VForCausalLM":
|
||||
("ROCm Triton flash attention may run into compilation errors due to "
|
||||
"excessive use of shared memory. If this happens, disable Triton FA "
|
||||
"by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
|
||||
"Qwen2ForCausalLM": _ROCM_SWA_REASON,
|
||||
"MistralForCausalLM": _ROCM_SWA_REASON,
|
||||
"MixtralForCausalLM": _ROCM_SWA_REASON,
|
||||
"PaliGemmaForConditionalGeneration": (
|
||||
"ROCm flash attention does not yet fully support 32-bit precision on PaliGemma"
|
||||
),
|
||||
"Phi3VForCausalLM": (
|
||||
"ROCm Triton flash attention may run into compilation errors due to "
|
||||
"excessive use of shared memory. If this happens, disable Triton FA "
|
||||
"by setting `VLLM_USE_TRITON_FLASH_ATTN=0`"
|
||||
),
|
||||
}
|
||||
_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
|
||||
"0x74a0": "AMD_Instinct_MI300A",
|
||||
@@ -91,7 +96,6 @@ if "HIP_VISIBLE_DEVICES" in os.environ:
|
||||
|
||||
|
||||
def with_amdsmi_context(fn):
|
||||
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
amdsmi_init()
|
||||
@@ -129,16 +133,16 @@ def on_gfx950() -> bool:
|
||||
|
||||
@cache
|
||||
def use_rocm_custom_paged_attention(
|
||||
qtype: torch.dtype,
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
gqa_ratio: int,
|
||||
max_seq_len: int,
|
||||
sliding_window: int,
|
||||
kv_cache_dtype: str,
|
||||
alibi_slopes: Optional[torch.Tensor] = None,
|
||||
sinks: Optional[torch.Tensor] = None) -> bool:
|
||||
|
||||
qtype: torch.dtype,
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
gqa_ratio: int,
|
||||
max_seq_len: int,
|
||||
sliding_window: int,
|
||||
kv_cache_dtype: str,
|
||||
alibi_slopes: Optional[torch.Tensor] = None,
|
||||
sinks: Optional[torch.Tensor] = None,
|
||||
) -> bool:
|
||||
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
|
||||
ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
|
||||
ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
|
||||
@@ -146,26 +150,36 @@ def use_rocm_custom_paged_attention(
|
||||
# custom paged attn always supported on V0. On V1, requires sliding window
|
||||
# disabled due to observed numerical discrepancy.
|
||||
if ON_GFX9:
|
||||
return ((not envs.VLLM_USE_V1 or sliding_window == 0
|
||||
or sliding_window == (-1, -1))
|
||||
and (qtype == torch.half or qtype == torch.bfloat16)
|
||||
and (head_size == 64 or head_size == 128)
|
||||
and (block_size == 16 or block_size == 32)
|
||||
and (gqa_ratio >= 1 and gqa_ratio <= 16)
|
||||
and max_seq_len <= 128 * 1024
|
||||
and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
|
||||
and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
|
||||
and envs.VLLM_ROCM_USE_AITER) and sinks is None)
|
||||
return (
|
||||
(not envs.VLLM_USE_V1 or sliding_window == 0 or sliding_window == (-1, -1))
|
||||
and (qtype == torch.half or qtype == torch.bfloat16)
|
||||
and (head_size == 64 or head_size == 128)
|
||||
and (block_size == 16 or block_size == 32)
|
||||
and (gqa_ratio >= 1 and gqa_ratio <= 16)
|
||||
and max_seq_len <= 128 * 1024
|
||||
and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
|
||||
and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN and envs.VLLM_ROCM_USE_AITER)
|
||||
and sinks is None
|
||||
)
|
||||
|
||||
else:
|
||||
return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0
|
||||
or sliding_window == (-1, -1))
|
||||
and (qtype == torch.half or qtype == torch.bfloat16)
|
||||
and head_size == 128 and block_size == 16
|
||||
and (gqa_ratio >= 3 and gqa_ratio <= 16)
|
||||
and max_seq_len <= 128 * 1024 and alibi_slopes is None
|
||||
and kv_cache_dtype == "auto"
|
||||
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN and sinks is None)
|
||||
return (
|
||||
ON_GFX11_GFX12
|
||||
and (
|
||||
not envs.VLLM_USE_V1
|
||||
or sliding_window == 0
|
||||
or sliding_window == (-1, -1)
|
||||
)
|
||||
and (qtype == torch.half or qtype == torch.bfloat16)
|
||||
and head_size == 128
|
||||
and block_size == 16
|
||||
and (gqa_ratio >= 3 and gqa_ratio <= 16)
|
||||
and max_seq_len <= 128 * 1024
|
||||
and alibi_slopes is None
|
||||
and kv_cache_dtype == "auto"
|
||||
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
|
||||
and sinks is None
|
||||
)
|
||||
|
||||
|
||||
class RocmPlatform(Platform):
|
||||
@@ -179,86 +193,112 @@ class RocmPlatform(Platform):
|
||||
device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
|
||||
|
||||
supported_quantization: list[str] = [
|
||||
"awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
|
||||
"quark", "ptpc_fp8", "mxfp4", "petit_nvfp4", "torchao"
|
||||
"awq",
|
||||
"gptq",
|
||||
"fp8",
|
||||
"compressed-tensors",
|
||||
"fbgemm_fp8",
|
||||
"gguf",
|
||||
"quark",
|
||||
"ptpc_fp8",
|
||||
"mxfp4",
|
||||
"petit_nvfp4",
|
||||
"torchao",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_vit_attn_backend(cls, head_size: int,
|
||||
dtype: torch.dtype) -> "_Backend":
|
||||
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA
|
||||
and on_gfx9()):
|
||||
|
||||
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
|
||||
return _Backend.ROCM_AITER_FA
|
||||
if on_gfx9():
|
||||
return _Backend.FLASH_ATTN
|
||||
return _Backend.TORCH_SDPA
|
||||
|
||||
@classmethod
|
||||
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
|
||||
kv_cache_dtype, block_size, use_v1, use_mla,
|
||||
has_sink, use_sparse) -> str:
|
||||
def get_attn_backend_cls(
|
||||
cls,
|
||||
selected_backend,
|
||||
head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
use_v1,
|
||||
use_mla,
|
||||
has_sink,
|
||||
use_sparse,
|
||||
) -> str:
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
|
||||
if use_sparse:
|
||||
raise NotImplementedError(
|
||||
"Sparse Attention is not supported on ROCm.")
|
||||
raise NotImplementedError("Sparse Attention is not supported on ROCm.")
|
||||
if use_mla:
|
||||
if not use_v1:
|
||||
raise RuntimeError(
|
||||
"MLA attention backends require the V1 engine. "
|
||||
"Set VLLM_USE_V1=1 to enable them.")
|
||||
"Set VLLM_USE_V1=1 to enable them."
|
||||
)
|
||||
|
||||
from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
|
||||
is_aiter_mla_enabled)
|
||||
is_aiter_mla_enabled,
|
||||
)
|
||||
|
||||
if selected_backend is None:
|
||||
selected_backend = (_Backend.ROCM_AITER_MLA if
|
||||
is_aiter_mla_enabled() or block_size == 1
|
||||
else _Backend.TRITON_MLA)
|
||||
selected_backend = (
|
||||
_Backend.ROCM_AITER_MLA
|
||||
if is_aiter_mla_enabled() or block_size == 1
|
||||
else _Backend.TRITON_MLA
|
||||
)
|
||||
|
||||
if selected_backend == _Backend.TRITON_MLA:
|
||||
if block_size != 1:
|
||||
logger.info_once("Using Triton MLA backend on V1 engine.")
|
||||
return ("vllm.v1.attention.backends.mla."
|
||||
"triton_mla.TritonMLABackend")
|
||||
raise ValueError(
|
||||
f" The selected backend, {selected_backend.name},"
|
||||
f"does not support block size {block_size}.")
|
||||
if selected_backend == _Backend.ROCM_AITER_MLA:
|
||||
if block_size == 1:
|
||||
logger.info("Using AITER MLA backend on V1 engine.")
|
||||
return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
|
||||
return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
|
||||
raise ValueError(
|
||||
f" The selected backend, {selected_backend.name},"
|
||||
f"does not support block size {block_size}."
|
||||
"(currently only supports block size 1)")
|
||||
)
|
||||
if selected_backend == _Backend.ROCM_AITER_MLA:
|
||||
if block_size == 1:
|
||||
logger.info("Using AITER MLA backend on V1 engine.")
|
||||
return (
|
||||
"vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
|
||||
)
|
||||
raise ValueError(
|
||||
f" The selected backend, {selected_backend.name},"
|
||||
f"does not support block size {block_size}."
|
||||
"(currently only supports block size 1)"
|
||||
)
|
||||
raise ValueError(
|
||||
f" The selected backend, {selected_backend.name},"
|
||||
f"is not MLA type while requested for MLA backend.")
|
||||
f"is not MLA type while requested for MLA backend."
|
||||
)
|
||||
|
||||
if envs.VLLM_USE_V1:
|
||||
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA \
|
||||
and on_gfx9():
|
||||
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
|
||||
logger.info("Using Flash Attention backend on V1 engine.")
|
||||
return ("vllm.v1.attention.backends."
|
||||
"rocm_aiter_fa.AiterFlashAttentionBackend")
|
||||
elif (envs.VLLM_ROCM_USE_AITER and
|
||||
envs.VLLM_USE_AITER_UNIFIED_ATTENTION) or \
|
||||
envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION or \
|
||||
selected_backend == _Backend.ROCM_ATTN:
|
||||
return (
|
||||
"vllm.v1.attention.backends."
|
||||
"rocm_aiter_fa.AiterFlashAttentionBackend"
|
||||
)
|
||||
elif (
|
||||
(envs.VLLM_ROCM_USE_AITER and envs.VLLM_USE_AITER_UNIFIED_ATTENTION)
|
||||
or envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
|
||||
or selected_backend == _Backend.ROCM_ATTN
|
||||
):
|
||||
# rocm specific backend, with aiter and/or
|
||||
# triton prefix-prefill
|
||||
logger.info("Using Rocm/Aiter Attention backend on V1 engine.")
|
||||
return ("vllm.v1.attention.backends."
|
||||
"rocm_attn.RocmAttentionBackend")
|
||||
return "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
|
||||
else:
|
||||
# default case, using triton unified attention
|
||||
logger.info("Using Triton Attention backend on V1 engine.")
|
||||
return ("vllm.v1.attention.backends."
|
||||
"triton_attn.TritonAttentionBackend")
|
||||
return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
|
||||
raise RuntimeError(
|
||||
"V0 attention backends have been removed. Set VLLM_USE_V1=1 "
|
||||
"to select a supported backend.")
|
||||
"to select a supported backend."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def set_device(cls, device: torch.device) -> None:
|
||||
@@ -269,9 +309,7 @@ class RocmPlatform(Platform):
|
||||
|
||||
@classmethod
|
||||
@lru_cache(maxsize=8)
|
||||
def get_device_capability(cls,
|
||||
device_id: int = 0
|
||||
) -> Optional[DeviceCapability]:
|
||||
def get_device_capability(cls, device_id: int = 0) -> Optional[DeviceCapability]:
|
||||
major, minor = torch.cuda.get_device_capability(device_id)
|
||||
return DeviceCapability(major=major, minor=minor)
|
||||
|
||||
@@ -281,21 +319,17 @@ class RocmPlatform(Platform):
|
||||
"""
|
||||
Query whether the set of GPUs is fully connected by XGMI (1 hop)
|
||||
"""
|
||||
handles = [
|
||||
amdsmi_get_processor_handles()[i] for i in physical_device_ids
|
||||
]
|
||||
handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]
|
||||
for i, handle in enumerate(handles):
|
||||
for j, peer_handle in enumerate(handles):
|
||||
if i < j:
|
||||
try:
|
||||
link_type = amdsmi_topo_get_link_type(
|
||||
handle, peer_handle)
|
||||
link_type = amdsmi_topo_get_link_type(handle, peer_handle)
|
||||
# type is 2 for XGMI
|
||||
if link_type["hops"] != 1 or link_type["type"] != 2:
|
||||
return False
|
||||
except AmdSmiException as error:
|
||||
logger.error("AMD 1 hop XGMI detection failed.",
|
||||
exc_info=error)
|
||||
logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -326,8 +360,9 @@ class RocmPlatform(Platform):
|
||||
is_eager_execution = compilation_config == CUDAGraphMode.NONE
|
||||
|
||||
use_v1 = envs.VLLM_USE_V1
|
||||
use_aiter_rms_norm = envs.VLLM_ROCM_USE_AITER and \
|
||||
envs.VLLM_ROCM_USE_AITER_RMSNORM
|
||||
use_aiter_rms_norm = (
|
||||
envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_RMSNORM
|
||||
)
|
||||
|
||||
if cache_config and cache_config.block_size is None:
|
||||
cache_config.block_size = 16
|
||||
@@ -335,21 +370,28 @@ class RocmPlatform(Platform):
|
||||
if parallel_config.worker_cls == "auto":
|
||||
parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
|
||||
# Aiter RMS norm performs best when CUDA Graph capture is enabled.
|
||||
if (use_v1 and use_aiter_rms_norm and not is_eager_execution
|
||||
and "-rms_norm" not in compilation_config.custom_ops):
|
||||
if (
|
||||
use_v1
|
||||
and use_aiter_rms_norm
|
||||
and not is_eager_execution
|
||||
and "-rms_norm" not in compilation_config.custom_ops
|
||||
):
|
||||
compilation_config.custom_ops.append("+rms_norm")
|
||||
|
||||
@classmethod
|
||||
def verify_model_arch(cls, model_arch: str) -> None:
|
||||
if model_arch in _ROCM_UNSUPPORTED_MODELS:
|
||||
raise ValueError(f"Model architecture '{model_arch}' is not "
|
||||
"supported by ROCm for now.")
|
||||
raise ValueError(
|
||||
f"Model architecture '{model_arch}' is not supported by ROCm for now."
|
||||
)
|
||||
|
||||
if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
|
||||
msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
|
||||
logger.warning(
|
||||
"Model architecture '%s' is partially "
|
||||
"supported by ROCm: %s", model_arch, msg)
|
||||
"Model architecture '%s' is partially supported by ROCm: %s",
|
||||
model_arch,
|
||||
msg,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def verify_quantization(cls, quant: str) -> None:
|
||||
@@ -357,7 +399,8 @@ class RocmPlatform(Platform):
|
||||
if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
|
||||
logger.warning(
|
||||
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
|
||||
" is not set, enabling VLLM_USE_TRITON_AWQ.")
|
||||
" is not set, enabling VLLM_USE_TRITON_AWQ."
|
||||
)
|
||||
envs.VLLM_USE_TRITON_AWQ = True
|
||||
|
||||
@classmethod
|
||||
@@ -365,16 +408,17 @@ class RocmPlatform(Platform):
|
||||
return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
|
||||
|
||||
@classmethod
|
||||
def get_current_memory_usage(cls,
|
||||
device: Optional[torch.types.Device] = None
|
||||
) -> float:
|
||||
def get_current_memory_usage(
|
||||
cls, device: Optional[torch.types.Device] = None
|
||||
) -> float:
|
||||
torch.cuda.reset_peak_memory_stats(device)
|
||||
return torch.cuda.mem_get_info(device)[1] - torch.cuda.mem_get_info(
|
||||
device)[0]
|
||||
return torch.cuda.mem_get_info(device)[1] - torch.cuda.mem_get_info(device)[0]
|
||||
|
||||
@classmethod
|
||||
def get_device_communicator_cls(cls) -> str:
|
||||
return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa
|
||||
return (
|
||||
"vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def supports_mx(cls) -> bool:
|
||||
@@ -384,12 +428,12 @@ class RocmPlatform(Platform):
|
||||
@classmethod
|
||||
def supports_fp8(cls) -> bool:
|
||||
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
|
||||
return any(gfx in gcn_arch for gfx in ['gfx94', 'gfx95', 'gfx12'])
|
||||
return any(gfx in gcn_arch for gfx in ["gfx94", "gfx95", "gfx12"])
|
||||
|
||||
@classmethod
|
||||
def is_fp8_fnuz(cls) -> bool:
|
||||
# Only device 0 is checked; this assumes MI300 platforms are homogeneous.
|
||||
return 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName
|
||||
return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
|
||||
|
||||
@classmethod
|
||||
def fp8_dtype(cls) -> torch.dtype:
|
||||
@@ -402,7 +446,7 @@ class RocmPlatform(Platform):
|
||||
def use_custom_allreduce(cls) -> bool:
|
||||
# We only enable custom allreduce for MI300 series
|
||||
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
|
||||
supported_archs = ['gfx94', 'gfx95']
|
||||
supported_archs = ["gfx94", "gfx95"]
|
||||
return any(gfx in gcn_arch for gfx in supported_archs)
|
||||
|
||||
@classmethod
|
||||
@@ -411,12 +455,11 @@ class RocmPlatform(Platform):
|
||||
|
||||
@classmethod
|
||||
def get_cu_count(cls, device_id: int = 0) -> int:
|
||||
return torch.cuda.get_device_properties(
|
||||
device_id).multi_processor_count
|
||||
return torch.cuda.get_device_properties(device_id).multi_processor_count
|
||||
|
||||
@classmethod
|
||||
def is_navi(cls) -> bool:
|
||||
return 'gfx1' in torch.cuda.get_device_properties(0).gcnArchName
|
||||
return "gfx1" in torch.cuda.get_device_properties(0).gcnArchName
|
||||
|
||||
@classmethod
|
||||
def get_static_graph_wrapper_cls(cls) -> str:
|
||||
@@ -442,8 +485,9 @@ class RocmPlatform(Platform):
|
||||
backend_options = ProcessGroupNCCL.Options()
|
||||
backend_options._timeout = timeout
|
||||
|
||||
backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
|
||||
backend_options)
|
||||
backend_class = ProcessGroupNCCL(
|
||||
prefix_store, group_rank, group_size, backend_options
|
||||
)
|
||||
backend_type = ProcessGroup.BackendType.NCCL
|
||||
device = torch.device("cuda")
|
||||
pg._set_default_backend(backend_type)
|
||||
@@ -457,8 +501,9 @@ class RocmPlatform(Platform):
|
||||
return cuda_device_count_stateless()
|
||||
|
||||
@classmethod
|
||||
def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
|
||||
model_config: "ModelConfig") -> bool:
|
||||
def is_kv_cache_dtype_supported(
|
||||
cls, kv_cache_dtype: str, model_config: "ModelConfig"
|
||||
) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
@@ -479,7 +524,8 @@ class RocmPlatform(Platform):
|
||||
"with compute capability of at least 8.0. "
|
||||
f"Your {gpu_name} GPU {compute_str}. "
|
||||
"You can use float16 instead by explicitly setting the "
|
||||
"`dtype` flag in CLI, for example: --dtype=half.")
|
||||
"`dtype` flag in CLI, for example: --dtype=half."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def support_hybrid_kv_cache(cls) -> bool:
|
||||
|
||||
Reference in New Issue
Block a user