From 92b9afeecde83bb0e0fef8ec8e30e7dae6b43bdc Mon Sep 17 00:00:00 2001 From: "Chendi.Xue" Date: Tue, 7 Apr 2026 11:17:58 -0500 Subject: [PATCH] [XPU] Quick fix for TritonMLA to remove cuda hardcode (#39088) Signed-off-by: Chendi Xue Co-authored-by: Kunshang Ji --- .../layers/fused_moe/unquantized_fused_moe_method.py | 2 +- vllm/v1/attention/backends/mla/triton_mla.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index 759f77b36..190821562 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -222,7 +222,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) else: self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) - elif current_platform.is_xpu(): + elif self.unquantized_backend == UnquantizedMoeBackend.XPU: w13 = layer.w13_weight w2 = layer.w2_weight diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index e8b09a436..0f8eb1c49 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -13,6 +13,7 @@ from vllm.model_executor.layers.attention.mla_attention import ( MLACommonImpl, MLACommonMetadata, ) +from vllm.platforms import current_platform from vllm.platforms.interface import DeviceCapability from vllm.triton_utils import triton from vllm.utils.torch_utils import is_quantized_kv_cache @@ -116,7 +117,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): if is_quantized_kv_cache(self.kv_cache_dtype): self.supports_quant_query_input = False - self._sm_count = torch.cuda.get_device_properties(0).multi_processor_count + self._sm_count = current_platform.num_compute_units() def _flash_attn_varlen_diff_headdims( self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs