[XPU] Quick fix for TritonMLA to remove cuda hardcode (#39088)

Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Chendi.Xue
2026-04-07 11:17:58 -05:00
committed by GitHub
parent 7310555482
commit 92b9afeecd
2 changed files with 3 additions and 2 deletions

View File

@@ -222,7 +222,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
else: else:
self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
elif current_platform.is_xpu(): elif self.unquantized_backend == UnquantizedMoeBackend.XPU:
w13 = layer.w13_weight w13 = layer.w13_weight
w2 = layer.w2_weight w2 = layer.w2_weight

View File

@@ -13,6 +13,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
MLACommonImpl, MLACommonImpl,
MLACommonMetadata, MLACommonMetadata,
) )
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.utils.torch_utils import is_quantized_kv_cache
@@ -116,7 +117,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
if is_quantized_kv_cache(self.kv_cache_dtype): if is_quantized_kv_cache(self.kv_cache_dtype):
self.supports_quant_query_input = False self.supports_quant_query_input = False
self._sm_count = torch.cuda.get_device_properties(0).multi_processor_count self._sm_count = current_platform.num_compute_units()
def _flash_attn_varlen_diff_headdims( def _flash_attn_varlen_diff_headdims(
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs