From 14771f715085f5026919d30048329c3e822d4b3c Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Wed, 25 Mar 2026 17:43:42 +0800 Subject: [PATCH] [XPU] support MLA model on Intel GPU (#37143) Signed-off-by: Kunshang Ji --- vllm/_xpu_ops.py | 1 + .../model_executor/layers/attention/mla_attention.py | 4 ++++ .../layers/quantization/input_quant_fp8.py | 10 ++++++++++ vllm/platforms/xpu.py | 12 ------------ 4 files changed, 15 insertions(+), 12 deletions(-) diff --git a/vllm/_xpu_ops.py b/vllm/_xpu_ops.py index 604f3412e..376375550 100644 --- a/vllm/_xpu_ops.py +++ b/vllm/_xpu_ops.py @@ -170,6 +170,7 @@ class xpu_ops: num_splits=0, return_softmax_lse: bool | None = False, s_aux: torch.Tensor | None = None, + return_attn_probs: bool | None = False, ): assert cu_seqlens_k is not None or seqused_k is not None, ( "cu_seqlens_k or seqused_k must be provided" diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index a51a8c2a7..0215ec1a0 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -1059,6 +1059,10 @@ except ImportError: "MLA models using TRITON_MLA will require flash_attn. " "AITER_MLA backends use aiter kernels instead." ) + elif current_platform.is_xpu(): + from vllm._xpu_ops import xpu_ops as ops + + flash_attn_varlen_func = ops.flash_attn_varlen_func # type: ignore[no-redef] def dynamic_per_batched_tensor_quant( diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index 6fa85436d..19c9414a2 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -165,6 +165,16 @@ class QuantFP8(CustomOp): # Fallback to CUDA implementation return self.forward_cuda(x, scale, scale_ub) + def forward_xpu( + self, + x: torch.Tensor, + scale: torch.Tensor | None = None, + scale_ub: torch.Tensor | None = None, + use_triton: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + # XPU can use same code path as CUDA. + return self.forward_cuda(x, scale, scale_ub, use_triton) + def forward_native( self, x: torch.Tensor, diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 5d39dfceb..3bb76425e 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -160,7 +160,6 @@ class XPUPlatform(Platform): @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: cache_config = vllm_config.cache_config - model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config # in V1(or with chunked prefill) block_size is 64 if cache_config and not cache_config.user_specified_block_size: @@ -209,17 +208,6 @@ class XPUPlatform(Platform): if vllm_config.kv_transfer_config is not None: vllm_config.kv_transfer_config.enable_permute_local_kv = True - if model_config and model_config.use_mla: - logger.info( - "MLA is enabled on a non-GPU platform; forcing chunked " - "prefill and prefix caching to be disabled." - ) - vllm_config.scheduler_config.enable_chunked_prefill = False - vllm_config.scheduler_config.max_num_batched_tokens = max( - vllm_config.model_config.max_model_len, - vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS, - ) - # In some cases, the internal memory type cache can misdetect GPU # memory as host memory, also leading to invalid memory access. # This cache can be disabled by setting UCX_MEMTYPE_CACHE=n.