[XPU] Quick fix for TritonMLA to remove CUDA hardcode (#39088)
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
@@ -222,7 +222,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
             else:
                 self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
-        elif current_platform.is_xpu():
+        elif self.unquantized_backend == UnquantizedMoeBackend.XPU:
             w13 = layer.w13_weight
             w2 = layer.w2_weight
 
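The changed branch keys off a backend enum resolved once at initialization, rather than probing current_platform at every call site. Below is a minimal sketch of that pattern; the UnquantizedMoeBackend members and the selection helper are assumptions for illustration, since their definitions are not shown in this diff.

# Illustrative sketch only; the real enum and selection logic live in vLLM
# and are not part of this diff.
import enum


class UnquantizedMoeBackend(enum.Enum):
    CPU = "cpu"
    XPU = "xpu"
    TRITON = "triton"


def select_unquantized_backend(is_cpu: bool, is_xpu: bool) -> UnquantizedMoeBackend:
    # Resolve platform facts into a backend exactly once, at init time.
    if is_cpu:
        return UnquantizedMoeBackend.CPU
    if is_xpu:
        return UnquantizedMoeBackend.XPU
    return UnquantizedMoeBackend.TRITON


# Later code then branches on the precomputed enum, not on the platform:
#     elif self.unquantized_backend == UnquantizedMoeBackend.XPU: ...

Branching on the enum keeps platform probing in one place, so the XPU path can coexist with the CPU and CUDA paths without scattering is_xpu() checks through the layer code.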
@@ -13,6 +13,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
     MLACommonImpl,
     MLACommonMetadata,
 )
+from vllm.platforms import current_platform
 from vllm.platforms.interface import DeviceCapability
 from vllm.triton_utils import triton
 from vllm.utils.torch_utils import is_quantized_kv_cache
@@ -116,7 +117,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
         if is_quantized_kv_cache(self.kv_cache_dtype):
             self.supports_quant_query_input = False
 
-        self._sm_count = torch.cuda.get_device_properties(0).multi_processor_count
+        self._sm_count = current_platform.num_compute_units()
 
     def _flash_attn_varlen_diff_headdims(
         self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
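The removed line hardcoded torch.cuda, which breaks TritonMLA on XPU; the replacement routes the query through the platform abstraction. Below is a minimal sketch of what such a platform-dispatching compute-unit query could look like. This is illustrative, not vLLM's actual current_platform implementation, and the XPU attribute name gpu_eu_count is an assumption based on PyTorch's torch.xpu device properties.

# Illustrative sketch only (not vLLM's implementation). gpu_eu_count is an
# assumed attribute name for Intel XPU device properties.
import torch


def num_compute_units(device_index: int = 0) -> int:
    if torch.cuda.is_available():
        # SMs on NVIDIA, CUs on AMD; both surface as multi_processor_count.
        return torch.cuda.get_device_properties(device_index).multi_processor_count
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        # Intel GPUs expose execution-unit counts rather than SM counts.
        return torch.xpu.get_device_properties(device_index).gpu_eu_count
    raise RuntimeError("no supported accelerator available")

With the query behind the platform interface, TritonMLAImpl can size its kernels from self._sm_count on any supported device instead of only on CUDA.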