[XPU] support MLA model on Intel GPU (#37143)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
@@ -170,6 +170,7 @@ class xpu_ops:
|
||||
num_splits=0,
|
||||
return_softmax_lse: bool | None = False,
|
||||
s_aux: torch.Tensor | None = None,
|
||||
return_attn_probs: bool | None = False,
|
||||
):
|
||||
assert cu_seqlens_k is not None or seqused_k is not None, (
|
||||
"cu_seqlens_k or seqused_k must be provided"
|
||||
|
||||
@@ -1059,6 +1059,10 @@ except ImportError:
|
||||
"MLA models using TRITON_MLA will require flash_attn. "
|
||||
"AITER_MLA backends use aiter kernels instead."
|
||||
)
|
||||
elif current_platform.is_xpu():
|
||||
from vllm._xpu_ops import xpu_ops as ops
|
||||
|
||||
flash_attn_varlen_func = ops.flash_attn_varlen_func # type: ignore[no-redef]
|
||||
|
||||
|
||||
def dynamic_per_batched_tensor_quant(
|
||||
|
||||
@@ -165,6 +165,16 @@ class QuantFP8(CustomOp):
|
||||
# Fallback to CUDA implementation
|
||||
return self.forward_cuda(x, scale, scale_ub)
|
||||
|
||||
def forward_xpu(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
scale: torch.Tensor | None = None,
|
||||
scale_ub: torch.Tensor | None = None,
|
||||
use_triton: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# XPU can use same code path as CUDA.
|
||||
return self.forward_cuda(x, scale, scale_ub, use_triton)
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
|
||||
@@ -160,7 +160,6 @@ class XPUPlatform(Platform):
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||
cache_config = vllm_config.cache_config
|
||||
model_config = vllm_config.model_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
# in V1(or with chunked prefill) block_size is 64
|
||||
if cache_config and not cache_config.user_specified_block_size:
|
||||
@@ -209,17 +208,6 @@ class XPUPlatform(Platform):
|
||||
if vllm_config.kv_transfer_config is not None:
|
||||
vllm_config.kv_transfer_config.enable_permute_local_kv = True
|
||||
|
||||
if model_config and model_config.use_mla:
|
||||
logger.info(
|
||||
"MLA is enabled on a non-GPU platform; forcing chunked "
|
||||
"prefill and prefix caching to be disabled."
|
||||
)
|
||||
vllm_config.scheduler_config.enable_chunked_prefill = False
|
||||
vllm_config.scheduler_config.max_num_batched_tokens = max(
|
||||
vllm_config.model_config.max_model_len,
|
||||
vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
|
||||
)
|
||||
|
||||
# In some cases, the internal memory type cache can misdetect GPU
|
||||
# memory as host memory, also leading to invalid memory access.
|
||||
# This cache can be disabled by setting UCX_MEMTYPE_CACHE=n.
|
||||
|
||||
Reference in New Issue
Block a user