[XPU] Quick fix for TritonMLA to remove cuda hardcode (#39088)
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
@@ -222,7 +222,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
 self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
 else:
 self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
-elif current_platform.is_xpu():
+elif self.unquantized_backend == UnquantizedMoeBackend.XPU:
 w13 = layer.w13_weight
 w2 = layer.w2_weight
 
@@ -13,6 +13,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
 MLACommonImpl,
 MLACommonMetadata,
 )
+from vllm.platforms import current_platform
 from vllm.platforms.interface import DeviceCapability
 from vllm.triton_utils import triton
 from vllm.utils.torch_utils import is_quantized_kv_cache
@@ -116,7 +117,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
 if is_quantized_kv_cache(self.kv_cache_dtype):
 self.supports_quant_query_input = False
 
-self._sm_count = torch.cuda.get_device_properties(0).multi_processor_count
+self._sm_count = current_platform.num_compute_units()
 
 def _flash_attn_varlen_diff_headdims(
 self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
Reference in New Issue
Block a user