diff --git a/tests/kernels/attention/test_cutlass_mla_decode.py b/tests/kernels/attention/test_cutlass_mla_decode.py index 784c16304..1f2fb66b3 100644 --- a/tests/kernels/attention/test_cutlass_mla_decode.py +++ b/tests/kernels/attention/test_cutlass_mla_decode.py @@ -9,6 +9,7 @@ import torch import vllm._custom_ops as ops from vllm.platforms import current_platform from vllm.triton_utils import triton +from vllm.utils.platform_utils import num_compute_units def cal_diff( @@ -124,8 +125,7 @@ def test_cutlass_mla_decode( q_pe = q_pe_padded kv_cache_flat = blocked_k.squeeze(2) - device_properties = torch.cuda.get_device_properties(torch.device("cuda:0")) - sm_count = device_properties.multi_processor_count + sm_count = num_compute_units(device.index) workspace_size = ops.sm100_cutlass_mla_get_workspace_size( max_seqlen * block_size, b, sm_count, num_kv_splits=1 ) diff --git a/tests/kernels/quantization/test_allspark_gemm.py b/tests/kernels/quantization/test_allspark_gemm.py index e5f056f04..7f6adbd52 100644 --- a/tests/kernels/quantization/test_allspark_gemm.py +++ b/tests/kernels/quantization/test_allspark_gemm.py @@ -13,6 +13,7 @@ from vllm.model_executor.layers.quantization.utils.allspark_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_weights from vllm.platforms import current_platform from vllm.scalar_type import scalar_types +from vllm.utils.platform_utils import num_compute_units def is_gptq_allspark_supported(min_capability: int, max_capability: int) -> bool: @@ -78,7 +79,7 @@ def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype): if has_zp: zp = zp.to(dtype) properties = torch.cuda.get_device_properties(qw.device.index) - sm_count = properties.multi_processor_count + sm_count = num_compute_units(qw.device.index) sm_version = properties.major * 10 + properties.minor n_32align = (n + 32 - 1) // 32 * 32 diff --git a/tests/kernels/quantization/test_rocm_skinny_gemms.py 
b/tests/kernels/quantization/test_rocm_skinny_gemms.py index 2564f1829..e67772616 100644 --- a/tests/kernels/quantization/test_rocm_skinny_gemms.py +++ b/tests/kernels/quantization/test_rocm_skinny_gemms.py @@ -9,7 +9,7 @@ import vllm._custom_ops as ops from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant from vllm.platforms import current_platform from vllm.platforms.rocm import on_gfx950 -from vllm.utils.platform_utils import get_cu_count +from vllm.utils.platform_utils import num_compute_units DTYPES = [torch.bfloat16, torch.float16] BIAS_MODES = [0, 1, 2] @@ -121,7 +121,7 @@ def pad_fp8(weight): @pytest.mark.skipif(not on_gfx950(), reason="only meant for gfx950") def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode): torch.manual_seed(seed) - cu_count = get_cu_count() + cu_count = num_compute_units() # Next ^2 of n N_p2 = 1 << (n - 1).bit_length() @@ -186,7 +186,7 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed): @pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed): torch.manual_seed(seed) - cu_count = get_cu_count() + cu_count = num_compute_units() A = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5 B = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5 @@ -203,7 +203,7 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed): @pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed): torch.manual_seed(seed) - cu_count = get_cu_count() + cu_count = num_compute_units() xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier @@ -222,7 +222,7 @@ def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed): @pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed): torch.manual_seed(seed) - 
cu_count = get_cu_count() + cu_count = num_compute_units() xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier @@ -267,7 +267,7 @@ def test_rocm_wvsplitk_fp8_kernel( ref_out = torch._scaled_mm( A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS ) - out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, get_cu_count(), BIAS) + out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, num_compute_units(), BIAS) if xnorm: torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-8) diff --git a/vllm/model_executor/kernels/linear/mixed_precision/allspark.py b/vllm/model_executor/kernels/linear/mixed_precision/allspark.py index 3baef4542..5f31538e4 100644 --- a/vllm/model_executor/kernels/linear/mixed_precision/allspark.py +++ b/vllm/model_executor/kernels/linear/mixed_precision/allspark.py @@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.allspark_utils import ( check_allspark_supported_dtype_shape, ) from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_ +from vllm.utils.platform_utils import num_compute_units from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig @@ -45,7 +46,7 @@ class AllSparkLinearKernel(MPLinearKernel): # prepare the parameters required for the kernel properties = torch.cuda.get_device_properties(device.index) - sm_count = properties.multi_processor_count + sm_count = num_compute_units(device.index) sm_version = properties.major * 10 + properties.minor gemm_args = {} gemm_args["sm_count"] = sm_count diff --git a/vllm/model_executor/kernels/linear/scaled_mm/rocm.py b/vllm/model_executor/kernels/linear/scaled_mm/rocm.py index 7a9529624..c8370dff5 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/rocm.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/rocm.py @@ -7,7 +7,7 @@ import torch import vllm.envs as envs from vllm import _custom_ops as ops from vllm.platforms import 
current_platform -from vllm.utils.platform_utils import get_cu_count +from vllm.utils.platform_utils import num_compute_units from vllm.utils.torch_utils import direct_register_custom_op from .ScaledMMLinearKernel import ( @@ -36,7 +36,7 @@ def rocm_per_tensor_float_w8a8_scaled_mm_impl( out_dtype, As, Bs, - get_cu_count(), + num_compute_units(), bias, ) # Fallback diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py index dbe8e8ef2..9f8b1955e 100644 --- a/vllm/model_executor/layers/batch_invariant.py +++ b/vllm/model_executor/layers/batch_invariant.py @@ -9,6 +9,7 @@ import torch from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.triton_utils import tl, triton +from vllm.utils.platform_utils import num_compute_units from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.v1.attention.backends.registry import AttentionBackendEnum @@ -147,7 +148,7 @@ def matmul_persistent( assert bias is None or bias.dim() == 1, ( "Currently assuming bias is 1D, let Horace know if you run into this" ) - NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + NUM_SMS = num_compute_units(a.device.index) M, K = a.shape K, N = b.shape dtype = a.dtype diff --git a/vllm/model_executor/layers/fla/ops/layernorm_guard.py b/vllm/model_executor/layers/fla/ops/layernorm_guard.py index 89352d12b..74c08e032 100644 --- a/vllm/model_executor/layers/fla/ops/layernorm_guard.py +++ b/vllm/model_executor/layers/fla/ops/layernorm_guard.py @@ -13,8 +13,6 @@ # This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling. # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. 
-from functools import lru_cache - import torch import torch.nn as nn import torch.nn.functional as F @@ -22,6 +20,7 @@ from einops import rearrange from vllm.triton_utils import tl, triton from vllm.utils.math_utils import cdiv, next_power_of_2 +from vllm.utils.platform_utils import num_compute_units from .utils import input_guard @@ -162,15 +161,8 @@ def layer_norm_fwd_kernel( tl.store(Y_base, y, mask=mask) -@lru_cache -def _get_sm_count(device: torch.device) -> int: - """Get and cache the SM count for a given device.""" - props = torch.cuda.get_device_properties(device) - return props.multi_processor_count - - def calc_rows_per_block(M: int, device: torch.device) -> int: - sm_count = _get_sm_count(device) + sm_count = num_compute_units(device.index) rows_per_block = next_power_of_2(cdiv(M, 2 * sm_count)) rows_per_block = min(rows_per_block, 4) return rows_per_block diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 7fa850c85..c1147725c 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types +from vllm.utils.platform_utils import num_compute_units from .quant_utils import pack_cols, unpack_cols @@ -271,7 +272,7 @@ def marlin_make_workspace_new( ) -> torch.Tensor: # In the new marlin kernel, we use the num of threadblocks as workspace # size. The num of threadblocks is sms_count * max_blocks_per_sm. 
- sms = torch.cuda.get_device_properties(device).multi_processor_count + sms = num_compute_units(device.index) return torch.zeros( sms * max_blocks_per_sm, dtype=torch.int, device=device, requires_grad=False ) diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index a6a5ef106..bc51b0e5e 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -11,7 +11,7 @@ from vllm import envs from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger from vllm.platforms import CpuArchEnum, current_platform -from vllm.utils.platform_utils import get_cu_count +from vllm.utils.platform_utils import num_compute_units from vllm.utils.torch_utils import direct_register_custom_op logger = init_logger(__name__) @@ -149,7 +149,7 @@ def rocm_unquantized_gemm_impl( m = weight.shape[0] k = weight.shape[1] - cu_count = get_cu_count() + cu_count = num_compute_units() if use_aiter_triton_gemm(n, m, k, x.dtype): from aiter.ops.triton.gemm_a16w16 import gemm_a16w16 @@ -199,7 +199,7 @@ def rocm_unquantized_gemm_impl( x_view = x.reshape(-1, x.size(-1)) if m > 8 and 0 < n <= 4: - cu_count = get_cu_count() + cu_count = num_compute_units() out = ops.wvSplitK(weight, x_view, cu_count, bias) return out.reshape(*x.shape[:-1], weight.shape[0]) elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None: diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py index a445c0aaf..f7df8f813 100644 --- a/vllm/model_executor/warmup/deep_gemm_warmup.py +++ b/vllm/model_executor/warmup/deep_gemm_warmup.py @@ -26,6 +26,7 @@ from vllm.utils.deep_gemm import ( m_grouped_fp8_gemm_nt_contiguous, ) from vllm.utils.math_utils import cdiv +from vllm.utils.platform_utils import num_compute_units def _generate_optimal_warmup_m_values( @@ -44,7 +45,7 @@ def _generate_optimal_warmup_m_values( # DeepGEMM's possible block sizes block_ms = [64, 128, 256] block_ns = list(range(16, 
min(257, n + 1), 16)) - num_sms = torch.cuda.get_device_properties(device).multi_processor_count + num_sms = num_compute_units(device.index) m_values = set() diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index c2fcde4ab..ddd4df418 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -538,6 +538,10 @@ class CudaPlatformBase(Platform): def support_static_graph_mode(cls) -> bool: return True + @classmethod + def num_compute_units(cls, device_id: int = 0) -> int: + return torch.cuda.get_device_properties(device_id).multi_processor_count + # NVML utils # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 6794c05f5..75e716479 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -692,6 +692,16 @@ class Platform: """ return {} + @classmethod + def num_compute_units(cls, device_id: int = 0) -> int: + """ + Get the number of compute units for the current platform. + (NVIDIA SM / AMD CU / Intel EU) + """ + raise NotImplementedError( + "num_compute_units is not implemented for the current platform." 
+ ) + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index c20c5717f..e1e2ffb1d 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -682,3 +682,7 @@ class RocmPlatform(Platform): @classmethod def support_static_graph_mode(cls) -> bool: return True + + @classmethod + def num_compute_units(cls, device_id: int = 0) -> int: + return torch.cuda.get_device_properties(device_id).multi_processor_count diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 5ce3cfba8..caa4305a5 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -277,3 +277,7 @@ class XPUPlatform(Platform): """Copy blocks from XPU to host (CPU).""" _src_cache = src_cache[:, src_block_indices] dst_cache[:, dst_block_indices] = _src_cache.cpu() + + @classmethod + def num_compute_units(cls, device_id: int = 0) -> int: + return torch.xpu.get_device_properties(device_id).max_compute_units diff --git a/vllm/utils/platform_utils.py b/vllm/utils/platform_utils.py index 433c6734e..6dd9ca422 100644 --- a/vllm/utils/platform_utils.py +++ b/vllm/utils/platform_utils.py @@ -24,11 +24,6 @@ def xpu_is_initialized() -> bool: return torch.xpu.is_initialized() -def get_cu_count(device_id: int = 0) -> int: - """Returns the total number of compute units (CU) on single GPU.""" - return torch.cuda.get_device_properties(device_id).multi_processor_count - - def cuda_get_device_properties( device, names: Sequence[str], init_cuda=False ) -> tuple[Any, ...]: @@ -57,3 +52,11 @@ def is_uva_available() -> bool: # UVA requires pinned memory. # TODO: Add more requirements for UVA if needed. 
return is_pin_memory_available() + + +@cache +def num_compute_units(device_id: int = 0) -> int: + """Get the number of compute units of the current device.""" + from vllm.platforms import current_platform + + return current_platform.num_compute_units(device_id) diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index 6d10a9d66..0751b5f0f 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -16,6 +16,7 @@ from vllm.model_executor.layers.attention.mla_attention import ( MLACommonMetadataBuilder, ) from vllm.platforms.interface import DeviceCapability +from vllm.utils.platform_utils import num_compute_units from vllm.v1.attention.backend import ( AttentionCGSupport, AttentionLayer, @@ -74,8 +75,7 @@ class SM100Workspace: # Pre-compute sm_count to avoid recomputing it. Use device 0 as a proxy # (assumes all devices are similar) - properties = torch.cuda.get_device_properties(torch.device("cuda:0")) - self._sm_count = properties.multi_processor_count + self._sm_count = num_compute_units(0) def get_buf(self): return self._workspace_buf diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 37ab14809..163b23b04 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -21,6 +21,7 @@ from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) from vllm.platforms.interface import DeviceCapability +from vllm.utils.platform_utils import num_compute_units from vllm.v1.attention.backend import ( AttentionCGSupport, AttentionLayer, @@ -130,8 +131,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): self.cg_buf_num_splits = None self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8") - device_properties = torch.cuda.get_device_properties(self.device) - num_sms = 
device_properties.multi_processor_count + num_sms = num_compute_units(self.device.index) if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): self.cg_buf_tile_scheduler_metadata = torch.zeros( diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 799c77d73..e04a7688f 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -15,6 +15,7 @@ from vllm.model_executor.layers.attention.mla_attention import ( ) from vllm.platforms import current_platform from vllm.platforms.interface import DeviceCapability +from vllm.utils.platform_utils import num_compute_units from vllm.v1.attention.backend import ( AttentionBackend, AttentionCGSupport, @@ -237,8 +238,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad # DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports next_n <= 2) self._init_reorder_batch_threshold(1, supports_spec_as_decode=True) - props = torch.cuda.get_device_properties(device) - sm_count = props.multi_processor_count + sm_count = num_compute_units(device.index) self.num_heads = self.model_config.get_num_attention_heads(parallel_config) self.mla_dims = get_mla_dims(self.model_config) diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 41805e99b..3c56f9fd0 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -9,6 +9,7 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, has_deep_gemm +from vllm.utils.platform_utils import num_compute_units from vllm.v1.attention.backend import ( AttentionBackend, AttentionCGSupport, @@ -219,8 +220,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): ) self.reorder_batch_threshold += 
self.num_speculative_tokens - props = torch.cuda.get_device_properties(self.device) - sm_count = props.multi_processor_count + sm_count = num_compute_units(self.device.index) self.num_sms = sm_count self.decode_lens_buffer = torch.empty( diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index b9ca39d8e..bc547585b 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -13,7 +13,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.attention import Attention from vllm.platforms import current_platform from vllm.utils.math_utils import cdiv -from vllm.utils.platform_utils import get_cu_count +from vllm.utils.platform_utils import num_compute_units from vllm.v1.attention.backend import ( AttentionBackend, AttentionCGSupport, @@ -38,7 +38,7 @@ if current_platform.is_rocm(): return min(65536 // x.element_size(), triton.next_power_of_2(head_dim)) def num_programs(total_tokens): - return min(total_tokens, get_cu_count()) + return min(total_tokens, num_compute_units()) @triton.jit def cp_mha_gather_cache_kernel( diff --git a/vllm/v1/sample/ops/topk_topp_triton.py b/vllm/v1/sample/ops/topk_topp_triton.py index f0291978d..114936129 100644 --- a/vllm/v1/sample/ops/topk_topp_triton.py +++ b/vllm/v1/sample/ops/topk_topp_triton.py @@ -13,6 +13,7 @@ import torch from vllm.triton_utils import tl, triton from vllm.utils.math_utils import next_power_of_2 +from vllm.utils.platform_utils import num_compute_units _TRITON_TABLE_CACHE: dict[tuple[torch.device], tuple[torch.Tensor, torch.Tensor]] = {} _TRITON_BUFFER_CACHE: dict[tuple[torch.device, torch.dtype, int], torch.Tensor] = {} @@ -988,7 +989,7 @@ def apply_top_k_top_p_triton( else: p_ptr = logits # Dummy pointer (won't be read) - num_sm = torch.cuda.get_device_properties(logits.device).multi_processor_count + num_sm = num_compute_units(logits.device.index) NUM_PROGRAMS = min(num_sm, batch_size) # Cache 
per-Triton Program buffer on each device. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 99b799ea4..f711d1d79 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -98,7 +98,7 @@ from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.utils.math_utils import cdiv, round_up from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib from vllm.utils.nvtx_pytorch_hooks import PytHooks -from vllm.utils.platform_utils import is_pin_memory_available +from vllm.utils.platform_utils import is_pin_memory_available, num_compute_units from vllm.utils.torch_utils import ( get_dtype_size, kv_cache_dtype_str_to_dtype, @@ -909,8 +909,8 @@ class GPUModelRunner( # Note: used for model runner override. def _init_device_properties(self) -> None: """Initialize attributes from torch.cuda.get_device_properties""" - self.device_properties = torch.cuda.get_device_properties(self.device) - self.num_sms = self.device_properties.multi_processor_count + + self.num_sms = num_compute_units(self.device.index) # Note: used for model runner override. 
def _sync_device(self) -> None: diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py index 765427683..edbf797b1 100644 --- a/vllm/v1/worker/gpu_ubatch_wrapper.py +++ b/vllm/v1/worker/gpu_ubatch_wrapper.py @@ -23,6 +23,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils.import_utils import has_deep_gemm +from vllm.utils.platform_utils import num_compute_units from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts logger = init_logger(__name__) @@ -72,8 +73,7 @@ class SMControlContextManager: "SM control is currently only supported on CUDA" ) - props = torch.cuda.get_device_properties(torch.cuda.current_device()) - total_sms = props.multi_processor_count + total_sms = num_compute_units(torch.cuda.current_device()) assert comm_sms < total_sms self.total_sms = total_sms diff --git a/vllm/v1/worker/xpu_model_runner.py b/vllm/v1/worker/xpu_model_runner.py index 305633058..e2cd49990 100644 --- a/vllm/v1/worker/xpu_model_runner.py +++ b/vllm/v1/worker/xpu_model_runner.py @@ -28,9 +28,6 @@ class XPUModelRunner(GPUModelRunner): # FIXME: To be verified. self.cascade_attn_enabled = False - def _init_device_properties(self) -> None: - self.num_sms = None - def _sync_device(self) -> None: torch.xpu.synchronize()