[XPU] support MLA model on Intel GPU (#37143)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Kunshang Ji
2026-03-25 17:43:42 +08:00
committed by GitHub
parent 189ddefbfd
commit 14771f7150
4 changed files with 15 additions and 12 deletions

View File

@@ -170,6 +170,7 @@ class xpu_ops:
num_splits=0,
return_softmax_lse: bool | None = False,
s_aux: torch.Tensor | None = None,
return_attn_probs: bool | None = False,
):
assert cu_seqlens_k is not None or seqused_k is not None, (
"cu_seqlens_k or seqused_k must be provided"

View File

@@ -1059,6 +1059,10 @@ except ImportError:
"MLA models using TRITON_MLA will require flash_attn. "
"AITER_MLA backends use aiter kernels instead."
)
elif current_platform.is_xpu():
from vllm._xpu_ops import xpu_ops as ops
flash_attn_varlen_func = ops.flash_attn_varlen_func # type: ignore[no-redef]
def dynamic_per_batched_tensor_quant(

View File

@@ -165,6 +165,16 @@ class QuantFP8(CustomOp):
# Fallback to CUDA implementation
return self.forward_cuda(x, scale, scale_ub)
def forward_xpu(
self,
x: torch.Tensor,
scale: torch.Tensor | None = None,
scale_ub: torch.Tensor | None = None,
use_triton: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
# XPU can use same code path as CUDA.
return self.forward_cuda(x, scale, scale_ub, use_triton)
def forward_native(
self,
x: torch.Tensor,

View File

@@ -160,7 +160,6 @@ class XPUPlatform(Platform):
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
cache_config = vllm_config.cache_config
model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
# in V1(or with chunked prefill) block_size is 64
if cache_config and not cache_config.user_specified_block_size:
@@ -209,17 +208,6 @@ class XPUPlatform(Platform):
if vllm_config.kv_transfer_config is not None:
vllm_config.kv_transfer_config.enable_permute_local_kv = True
if model_config and model_config.use_mla:
logger.info(
"MLA is enabled on a non-GPU platform; forcing chunked "
"prefill and prefix caching to be disabled."
)
vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.max_num_batched_tokens = max(
vllm_config.model_config.max_model_len,
vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
)
# In some cases, the internal memory type cache can misdetect GPU
# memory as host memory, also leading to invalid memory access.
# This cache can be disabled by setting UCX_MEMTYPE_CACHE=n.