[XPU] support MLA model on Intel GPU (#37143)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
@@ -170,6 +170,7 @@ class xpu_ops:
|
|||||||
num_splits=0,
|
num_splits=0,
|
||||||
return_softmax_lse: bool | None = False,
|
return_softmax_lse: bool | None = False,
|
||||||
s_aux: torch.Tensor | None = None,
|
s_aux: torch.Tensor | None = None,
|
||||||
|
return_attn_probs: bool | None = False,
|
||||||
):
|
):
|
||||||
assert cu_seqlens_k is not None or seqused_k is not None, (
|
assert cu_seqlens_k is not None or seqused_k is not None, (
|
||||||
"cu_seqlens_k or seqused_k must be provided"
|
"cu_seqlens_k or seqused_k must be provided"
|
||||||
|
|||||||
@@ -1059,6 +1059,10 @@ except ImportError:
|
|||||||
"MLA models using TRITON_MLA will require flash_attn. "
|
"MLA models using TRITON_MLA will require flash_attn. "
|
||||||
"AITER_MLA backends use aiter kernels instead."
|
"AITER_MLA backends use aiter kernels instead."
|
||||||
)
|
)
|
||||||
|
elif current_platform.is_xpu():
|
||||||
|
from vllm._xpu_ops import xpu_ops as ops
|
||||||
|
|
||||||
|
flash_attn_varlen_func = ops.flash_attn_varlen_func # type: ignore[no-redef]
|
||||||
|
|
||||||
|
|
||||||
def dynamic_per_batched_tensor_quant(
|
def dynamic_per_batched_tensor_quant(
|
||||||
|
|||||||
@@ -165,6 +165,16 @@ class QuantFP8(CustomOp):
|
|||||||
# Fallback to CUDA implementation
|
# Fallback to CUDA implementation
|
||||||
return self.forward_cuda(x, scale, scale_ub)
|
return self.forward_cuda(x, scale, scale_ub)
|
||||||
|
|
||||||
|
def forward_xpu(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
scale: torch.Tensor | None = None,
|
||||||
|
scale_ub: torch.Tensor | None = None,
|
||||||
|
use_triton: bool = False,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# XPU can use same code path as CUDA.
|
||||||
|
return self.forward_cuda(x, scale, scale_ub, use_triton)
|
||||||
|
|
||||||
def forward_native(
|
def forward_native(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
|
|||||||
@@ -160,7 +160,6 @@ class XPUPlatform(Platform):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||||
cache_config = vllm_config.cache_config
|
cache_config = vllm_config.cache_config
|
||||||
model_config = vllm_config.model_config
|
|
||||||
parallel_config = vllm_config.parallel_config
|
parallel_config = vllm_config.parallel_config
|
||||||
# in V1(or with chunked prefill) block_size is 64
|
# in V1(or with chunked prefill) block_size is 64
|
||||||
if cache_config and not cache_config.user_specified_block_size:
|
if cache_config and not cache_config.user_specified_block_size:
|
||||||
@@ -209,17 +208,6 @@ class XPUPlatform(Platform):
|
|||||||
if vllm_config.kv_transfer_config is not None:
|
if vllm_config.kv_transfer_config is not None:
|
||||||
vllm_config.kv_transfer_config.enable_permute_local_kv = True
|
vllm_config.kv_transfer_config.enable_permute_local_kv = True
|
||||||
|
|
||||||
if model_config and model_config.use_mla:
|
|
||||||
logger.info(
|
|
||||||
"MLA is enabled on a non-GPU platform; forcing chunked "
|
|
||||||
"prefill and prefix caching to be disabled."
|
|
||||||
)
|
|
||||||
vllm_config.scheduler_config.enable_chunked_prefill = False
|
|
||||||
vllm_config.scheduler_config.max_num_batched_tokens = max(
|
|
||||||
vllm_config.model_config.max_model_len,
|
|
||||||
vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
|
|
||||||
)
|
|
||||||
|
|
||||||
# In some cases, the internal memory type cache can misdetect GPU
|
# In some cases, the internal memory type cache can misdetect GPU
|
||||||
# memory as host memory, also leading to invalid memory access.
|
# memory as host memory, also leading to invalid memory access.
|
||||||
# This cache can be disabled by setting UCX_MEMTYPE_CACHE=n.
|
# This cache can be disabled by setting UCX_MEMTYPE_CACHE=n.
|
||||||
|
|||||||
Reference in New Issue
Block a user