[Platform] Add current_platform.num_compute_units interface (#35042)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com> Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
This commit is contained in:
@@ -9,6 +9,7 @@ import torch
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
|
||||
|
||||
def cal_diff(
|
||||
@@ -124,8 +125,7 @@ def test_cutlass_mla_decode(
|
||||
q_pe = q_pe_padded
|
||||
|
||||
kv_cache_flat = blocked_k.squeeze(2)
|
||||
device_properties = torch.cuda.get_device_properties(torch.device("cuda:0"))
|
||||
sm_count = device_properties.multi_processor_count
|
||||
sm_count = num_compute_units(device.index)
|
||||
workspace_size = ops.sm100_cutlass_mla_get_workspace_size(
|
||||
max_seqlen * block_size, b, sm_count, num_kv_splits=1
|
||||
)
|
||||
|
||||
@@ -13,6 +13,7 @@ from vllm.model_executor.layers.quantization.utils.allspark_utils import (
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_weights
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
|
||||
|
||||
def is_gptq_allspark_supported(min_capability: int, max_capability: int) -> bool:
|
||||
@@ -78,7 +79,7 @@ def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype):
|
||||
if has_zp:
|
||||
zp = zp.to(dtype)
|
||||
properties = torch.cuda.get_device_properties(qw.device.index)
|
||||
sm_count = properties.multi_processor_count
|
||||
sm_count = num_compute_units(qw.device.index)
|
||||
sm_version = properties.major * 10 + properties.minor
|
||||
|
||||
n_32align = (n + 32 - 1) // 32 * 32
|
||||
|
||||
@@ -9,7 +9,7 @@ import vllm._custom_ops as ops
|
||||
from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.rocm import on_gfx950
|
||||
from vllm.utils.platform_utils import get_cu_count
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
|
||||
DTYPES = [torch.bfloat16, torch.float16]
|
||||
BIAS_MODES = [0, 1, 2]
|
||||
@@ -121,7 +121,7 @@ def pad_fp8(weight):
|
||||
@pytest.mark.skipif(not on_gfx950(), reason="only meant for gfx950")
|
||||
def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode):
|
||||
torch.manual_seed(seed)
|
||||
cu_count = get_cu_count()
|
||||
cu_count = num_compute_units()
|
||||
|
||||
# Next ^2 of n
|
||||
N_p2 = 1 << (n - 1).bit_length()
|
||||
@@ -186,7 +186,7 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
|
||||
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
|
||||
def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
cu_count = get_cu_count()
|
||||
cu_count = num_compute_units()
|
||||
|
||||
A = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5
|
||||
B = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5
|
||||
@@ -203,7 +203,7 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
|
||||
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
|
||||
def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
cu_count = get_cu_count()
|
||||
cu_count = num_compute_units()
|
||||
|
||||
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
|
||||
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
|
||||
@@ -222,7 +222,7 @@ def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
|
||||
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
|
||||
def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
cu_count = get_cu_count()
|
||||
cu_count = num_compute_units()
|
||||
|
||||
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
|
||||
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
|
||||
@@ -267,7 +267,7 @@ def test_rocm_wvsplitk_fp8_kernel(
|
||||
ref_out = torch._scaled_mm(
|
||||
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS
|
||||
)
|
||||
out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, get_cu_count(), BIAS)
|
||||
out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, num_compute_units(), BIAS)
|
||||
|
||||
if xnorm:
|
||||
torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-8)
|
||||
|
||||
@@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.allspark_utils import (
|
||||
check_allspark_supported_dtype_shape,
|
||||
)
|
||||
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
|
||||
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
|
||||
|
||||
@@ -45,7 +46,7 @@ class AllSparkLinearKernel(MPLinearKernel):
|
||||
|
||||
# prepare the parameters required for the kernel
|
||||
properties = torch.cuda.get_device_properties(device.index)
|
||||
sm_count = properties.multi_processor_count
|
||||
sm_count = num_compute_units(device.index)
|
||||
sm_version = properties.major * 10 + properties.minor
|
||||
gemm_args = {}
|
||||
gemm_args["sm_count"] = sm_count
|
||||
|
||||
@@ -7,7 +7,7 @@ import torch
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.platform_utils import get_cu_count
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from .ScaledMMLinearKernel import (
|
||||
@@ -36,7 +36,7 @@ def rocm_per_tensor_float_w8a8_scaled_mm_impl(
|
||||
out_dtype,
|
||||
As,
|
||||
Bs,
|
||||
get_cu_count(),
|
||||
num_compute_units(),
|
||||
bias,
|
||||
)
|
||||
# Fallback
|
||||
|
||||
@@ -9,6 +9,7 @@ import torch
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
@@ -147,7 +148,7 @@ def matmul_persistent(
|
||||
assert bias is None or bias.dim() == 1, (
|
||||
"Currently assuming bias is 1D, let Horace know if you run into this"
|
||||
)
|
||||
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
|
||||
NUM_SMS = num_compute_units(a.device.index)
|
||||
M, K = a.shape
|
||||
K, N = b.shape
|
||||
dtype = a.dtype
|
||||
|
||||
@@ -13,8 +13,6 @@
|
||||
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
|
||||
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
|
||||
|
||||
from functools import lru_cache
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@@ -22,6 +20,7 @@ from einops import rearrange
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.math_utils import cdiv, next_power_of_2
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
|
||||
from .utils import input_guard
|
||||
|
||||
@@ -162,15 +161,8 @@ def layer_norm_fwd_kernel(
|
||||
tl.store(Y_base, y, mask=mask)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def _get_sm_count(device: torch.device) -> int:
|
||||
"""Get and cache the SM count for a given device."""
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
return props.multi_processor_count
|
||||
|
||||
|
||||
def calc_rows_per_block(M: int, device: torch.device) -> int:
|
||||
sm_count = _get_sm_count(device)
|
||||
sm_count = num_compute_units(device.index)
|
||||
rows_per_block = next_power_of_2(cdiv(M, 2 * sm_count))
|
||||
rows_per_block = min(rows_per_block, 4)
|
||||
return rows_per_block
|
||||
|
||||
@@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
|
||||
from .quant_utils import pack_cols, unpack_cols
|
||||
|
||||
@@ -271,7 +272,7 @@ def marlin_make_workspace_new(
|
||||
) -> torch.Tensor:
|
||||
# In the new marlin kernel, we use the num of threadblocks as workspace
|
||||
# size. The num of threadblocks is sms_count * max_blocks_per_sm.
|
||||
sms = torch.cuda.get_device_properties(device).multi_processor_count
|
||||
sms = num_compute_units(device.index)
|
||||
return torch.zeros(
|
||||
sms * max_blocks_per_sm, dtype=torch.int, device=device, requires_grad=False
|
||||
)
|
||||
|
||||
@@ -11,7 +11,7 @@ from vllm import envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
from vllm.utils.platform_utils import get_cu_count
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -149,7 +149,7 @@ def rocm_unquantized_gemm_impl(
|
||||
m = weight.shape[0]
|
||||
k = weight.shape[1]
|
||||
|
||||
cu_count = get_cu_count()
|
||||
cu_count = num_compute_units()
|
||||
if use_aiter_triton_gemm(n, m, k, x.dtype):
|
||||
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
|
||||
|
||||
@@ -199,7 +199,7 @@ def rocm_unquantized_gemm_impl(
|
||||
|
||||
x_view = x.reshape(-1, x.size(-1))
|
||||
if m > 8 and 0 < n <= 4:
|
||||
cu_count = get_cu_count()
|
||||
cu_count = num_compute_units()
|
||||
out = ops.wvSplitK(weight, x_view, cu_count, bias)
|
||||
return out.reshape(*x.shape[:-1], weight.shape[0])
|
||||
elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
|
||||
|
||||
@@ -26,6 +26,7 @@ from vllm.utils.deep_gemm import (
|
||||
m_grouped_fp8_gemm_nt_contiguous,
|
||||
)
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
|
||||
|
||||
def _generate_optimal_warmup_m_values(
|
||||
@@ -44,7 +45,7 @@ def _generate_optimal_warmup_m_values(
|
||||
# DeepGEMM's possible block sizes
|
||||
block_ms = [64, 128, 256]
|
||||
block_ns = list(range(16, min(257, n + 1), 16))
|
||||
num_sms = torch.cuda.get_device_properties(device).multi_processor_count
|
||||
num_sms = num_compute_units(device.index)
|
||||
|
||||
m_values = set()
|
||||
|
||||
|
||||
@@ -538,6 +538,10 @@ class CudaPlatformBase(Platform):
|
||||
def support_static_graph_mode(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def num_compute_units(cls, device_id=0):
|
||||
return torch.cuda.get_device_properties(device_id).multi_processor_count
|
||||
|
||||
|
||||
# NVML utils
|
||||
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
|
||||
|
||||
@@ -692,6 +692,16 @@ class Platform:
|
||||
"""
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def num_compute_units(cls, device_id: int = 0) -> int:
|
||||
"""
|
||||
Get the number of compute units for the current platform.
|
||||
(NVIDIA SM / AMD CU / Intel EU)
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"num_compute_units is not implemented for the current platform."
|
||||
)
|
||||
|
||||
|
||||
class UnspecifiedPlatform(Platform):
|
||||
_enum = PlatformEnum.UNSPECIFIED
|
||||
|
||||
@@ -682,3 +682,7 @@ class RocmPlatform(Platform):
|
||||
@classmethod
|
||||
def support_static_graph_mode(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def num_compute_units(cls, device_id=0):
|
||||
return torch.cuda.get_device_properties(device_id).multi_processor_count
|
||||
|
||||
@@ -277,3 +277,7 @@ class XPUPlatform(Platform):
|
||||
"""Copy blocks from XPU to host (CPU)."""
|
||||
_src_cache = src_cache[:, src_block_indices]
|
||||
dst_cache[:, dst_block_indices] = _src_cache.cpu()
|
||||
|
||||
@classmethod
|
||||
def num_compute_units(cls, device_id: int = 0) -> int:
|
||||
return torch.xpu.get_device_properties(device_id).max_compute_units
|
||||
|
||||
@@ -24,11 +24,6 @@ def xpu_is_initialized() -> bool:
|
||||
return torch.xpu.is_initialized()
|
||||
|
||||
|
||||
def get_cu_count(device_id: int = 0) -> int:
|
||||
"""Returns the total number of compute units (CU) on single GPU."""
|
||||
return torch.cuda.get_device_properties(device_id).multi_processor_count
|
||||
|
||||
|
||||
def cuda_get_device_properties(
|
||||
device, names: Sequence[str], init_cuda=False
|
||||
) -> tuple[Any, ...]:
|
||||
@@ -57,3 +52,11 @@ def is_uva_available() -> bool:
|
||||
# UVA requires pinned memory.
|
||||
# TODO: Add more requirements for UVA if needed.
|
||||
return is_pin_memory_available()
|
||||
|
||||
|
||||
@cache
|
||||
def num_compute_units(device_id: int = 0) -> int:
|
||||
"""Get the number of compute units of the current device."""
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
return current_platform.num_compute_units(device_id)
|
||||
|
||||
@@ -16,6 +16,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
|
||||
MLACommonMetadataBuilder,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionCGSupport,
|
||||
AttentionLayer,
|
||||
@@ -74,8 +75,7 @@ class SM100Workspace:
|
||||
|
||||
# Pre-compute sm_count to avoid recomputing it. Use device 0 as a proxy
|
||||
# (assumes all devices are similar)
|
||||
properties = torch.cuda.get_device_properties(torch.device("cuda:0"))
|
||||
self._sm_count = properties.multi_processor_count
|
||||
self._sm_count = num_compute_units(0)
|
||||
|
||||
def get_buf(self):
|
||||
return self._workspace_buf
|
||||
|
||||
@@ -21,6 +21,7 @@ from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionCGSupport,
|
||||
AttentionLayer,
|
||||
@@ -130,8 +131,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
self.cg_buf_num_splits = None
|
||||
self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8")
|
||||
|
||||
device_properties = torch.cuda.get_device_properties(self.device)
|
||||
num_sms = device_properties.multi_processor_count
|
||||
num_sms = num_compute_units(self.device.index)
|
||||
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
self.cg_buf_tile_scheduler_metadata = torch.zeros(
|
||||
|
||||
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
@@ -237,8 +238,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
|
||||
# DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports next_n <= 2)
|
||||
self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)
|
||||
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
sm_count = props.multi_processor_count
|
||||
sm_count = num_compute_units(device.index)
|
||||
|
||||
self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
|
||||
self.mla_dims = get_mla_dims(self.model_config)
|
||||
|
||||
@@ -9,6 +9,7 @@ from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, has_deep_gemm
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
@@ -219,8 +220,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
||||
)
|
||||
self.reorder_batch_threshold += self.num_speculative_tokens
|
||||
|
||||
props = torch.cuda.get_device_properties(self.device)
|
||||
sm_count = props.multi_processor_count
|
||||
sm_count = num_compute_units(self.device.index)
|
||||
self.num_sms = sm_count
|
||||
|
||||
self.decode_lens_buffer = torch.empty(
|
||||
|
||||
@@ -13,7 +13,7 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.platform_utils import get_cu_count
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
@@ -38,7 +38,7 @@ if current_platform.is_rocm():
|
||||
return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))
|
||||
|
||||
def num_programs(total_tokens):
|
||||
return min(total_tokens, get_cu_count())
|
||||
return min(total_tokens, num_compute_units())
|
||||
|
||||
@triton.jit
|
||||
def cp_mha_gather_cache_kernel(
|
||||
|
||||
@@ -13,6 +13,7 @@ import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.math_utils import next_power_of_2
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
|
||||
_TRITON_TABLE_CACHE: dict[tuple[torch.device], tuple[torch.Tensor, torch.Tensor]] = {}
|
||||
_TRITON_BUFFER_CACHE: dict[tuple[torch.device, torch.dtype, int], torch.Tensor] = {}
|
||||
@@ -988,7 +989,7 @@ def apply_top_k_top_p_triton(
|
||||
else:
|
||||
p_ptr = logits # Dummy pointer (won't be read)
|
||||
|
||||
num_sm = torch.cuda.get_device_properties(logits.device).multi_processor_count
|
||||
num_sm = num_compute_units(logits.device.index)
|
||||
NUM_PROGRAMS = min(num_sm, batch_size)
|
||||
|
||||
# Cache per-Triton Program buffer on each device.
|
||||
|
||||
@@ -98,7 +98,7 @@ from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||
from vllm.utils.math_utils import cdiv, round_up
|
||||
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
|
||||
from vllm.utils.nvtx_pytorch_hooks import PytHooks
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.utils.platform_utils import is_pin_memory_available, num_compute_units
|
||||
from vllm.utils.torch_utils import (
|
||||
get_dtype_size,
|
||||
kv_cache_dtype_str_to_dtype,
|
||||
@@ -909,8 +909,8 @@ class GPUModelRunner(
|
||||
# Note: used for model runner override.
|
||||
def _init_device_properties(self) -> None:
|
||||
"""Initialize attributes from torch.cuda.get_device_properties"""
|
||||
self.device_properties = torch.cuda.get_device_properties(self.device)
|
||||
self.num_sms = self.device_properties.multi_processor_count
|
||||
|
||||
self.num_sms = num_compute_units(self.device.index)
|
||||
|
||||
# Note: used for model runner override.
|
||||
def _sync_device(self) -> None:
|
||||
|
||||
@@ -23,6 +23,7 @@ from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.import_utils import has_deep_gemm
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -72,8 +73,7 @@ class SMControlContextManager:
|
||||
"SM control is currently only supported on CUDA"
|
||||
)
|
||||
|
||||
props = torch.cuda.get_device_properties(torch.cuda.current_device())
|
||||
total_sms = props.multi_processor_count
|
||||
total_sms = num_compute_units(torch.cuda.current_device().index)
|
||||
|
||||
assert comm_sms < total_sms
|
||||
self.total_sms = total_sms
|
||||
|
||||
@@ -28,9 +28,6 @@ class XPUModelRunner(GPUModelRunner):
|
||||
# FIXME: To be verified.
|
||||
self.cascade_attn_enabled = False
|
||||
|
||||
def _init_device_properties(self) -> None:
|
||||
self.num_sms = None
|
||||
|
||||
def _sync_device(self) -> None:
|
||||
torch.xpu.synchronize()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user