diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh
index ab8dfb036..36775152f 100644
--- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh
@@ -38,10 +38,9 @@ docker run \
   python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -cc.cudagraph_mode=NONE
   python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
   python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
-  python3 examples/offline_inference/basic/generate.py --model Intel/Qwen2.5-0.5B-W4A16-G128-AutoRound-LLMC-TEST-ONLY --enforce-eager
   python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --attention-backend=TRITON_ATTN
   cd tests
-  pytest -v -s v1/core
+  pytest -v -s v1/core --ignore=v1/core/test_reset_prefix_cache_e2e.py
   pytest -v -s v1/engine
   pytest -v -s v1/sample --ignore=v1/sample/test_logprobs.py --ignore=v1/sample/test_logprobs_e2e.py
   pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py
diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu
index f63ce2c50..04051827b 100644
--- a/docker/Dockerfile.xpu
+++ b/docker/Dockerfile.xpu
@@ -1,8 +1,8 @@
-FROM intel/deep-learning-essentials:2025.2.2-0-devel-ubuntu24.04 AS vllm-base
+FROM intel/deep-learning-essentials:2025.3.2-0-devel-ubuntu24.04 AS vllm-base

 RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \
     echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list && \
-    add-apt-repository -y ppa:kobuk-team/intel-graphics-staging
+    add-apt-repository -y ppa:kobuk-team/intel-graphics

 RUN apt clean && apt-get update -y && \
     apt-get install -y --no-install-recommends --fix-missing \
@@ -25,10 +25,13 @@ RUN apt clean && apt-get update -y && \
 RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1
 RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.12 1

-RUN apt install -y libze1 libze-dev libze-intel-gpu1 intel-opencl-icd libze-intel-gpu-raytracing intel-ocloc
+RUN apt update && apt upgrade -y && \
+    apt install -y libze1 libze-dev libze-intel-gpu1 intel-opencl-icd libze-intel-gpu-raytracing intel-ocloc && \
+    apt install -y intel-oneapi-compiler-dpcpp-cpp-2025.3
+
 # This oneccl contains the BMG support which is not the case for default version of oneapi 2025.2.
-ARG ONECCL_INSTALLER="intel-oneccl-2021.15.7.6_offline.sh"
+ARG ONECCL_INSTALLER="intel-oneccl-2021.15.7.8_offline.sh"
 RUN wget "https://github.com/uxlfoundation/oneCCL/releases/download/2021.15.7/${ONECCL_INSTALLER}" && \
     bash "${ONECCL_INSTALLER}" -a --silent --eula accept && \
     rm "${ONECCL_INSTALLER}" && \
@@ -85,6 +88,9 @@ RUN python3 -m pip install -e tests/vllm_test_utils
 ENV NIXL_VERSION=0.7.0
 RUN python3 /workspace/vllm/tools/install_nixl_from_source_ubuntu.py

+# Fix Triton: replace any preinstalled triton with the XPU-specific triton-xpu build
+RUN --mount=type=cache,target=/root/.cache/pip pip uninstall triton triton-xpu -y && pip install triton-xpu==3.6.0 --extra-index-url=https://download.pytorch.org/whl/xpu
+
 # PyJWT-2.7.0 will influence some wheel behaviors, remove its dist-info to avoid conflicts
 RUN rm /usr/lib/python3/dist-packages/PyJWT-2.7.0.dist-info/ -rf
diff --git a/requirements/xpu.txt b/requirements/xpu.txt
index c1dc4195b..6fde5b8f9 100644
--- a/requirements/xpu.txt
+++ b/requirements/xpu.txt
@@ -11,8 +11,8 @@ jinja2>=3.1.6
 datasets # for benchmark scripts
 numba == 0.61.2 # Required for N-gram speculative decoding

--extra-index-url=https://download.pytorch.org/whl/xpu
-torch==2.9.0+xpu
+torch==2.10.0+xpu
 torchaudio
 torchvision
-intel-extension-for-pytorch @ https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.9.10.post0%2Bxpu-cp312-cp312-linux_x86_64.whl
+vllm_xpu_kernels @ https://github.com/vllm-project/vllm-xpu-kernels/releases/download/v0.1.0/vllm_xpu_kernels-0.1.0-cp312-cp312-linux_x86_64.whl
\ No newline at end of file
diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py
index 239f5376e..22133eaef 100644
--- a/vllm/_ipex_ops.py
+++ b/vllm/_ipex_ops.py
@@ -1,273 +1,59 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import TYPE_CHECKING

 import torch
+from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func

 from vllm.logger import init_logger
-from vllm.platforms import current_platform

 logger = init_logger(__name__)

-try:
-    import intel_extension_for_pytorch as ipex
-except ImportError as e:
-    logger.debug("Import error msg: %s", e.msg)
+if TYPE_CHECKING:
+
+    def register_fake(fn):
+        return lambda name: fn
+else:
+    try:
+        from torch.library import register_fake
+    except ImportError:
+        from torch.library import impl_abstract as register_fake
+
+if hasattr(torch.ops._xpu_C, "fp8_gemm_w8a16"):
+
+    @register_fake("_xpu_C::fp8_gemm_w8a16")
+    def _fp8_gemm_w8a16_fake(
+        input: torch.Tensor,
+        q_weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        bias: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        input_2d = input.view(-1, input.shape[-1])
+        M = input_2d.size(0)
+        N = q_weight.size(1)
+        return torch.empty((M, N), dtype=input.dtype, device=input.device)
+
+
+if hasattr(torch.ops._xpu_C, "int4_gemm_w4a16"):
+
+    @register_fake("_xpu_C::int4_gemm_w4a16")
+    def _int4_gemm_w4a16_fake(
+        input: torch.Tensor,
+        q_weight: torch.Tensor,
+        bias: torch.Tensor | None,
+        weight_scale: torch.Tensor,
+        qzeros: torch.Tensor,
+        group_size: int,
+        group_idx: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        input_2d = input.view(-1, input.shape[-1])
+        M = input_2d.size(0)
+        N = q_weight.size(1)
+        return torch.empty((M, N), dtype=input.dtype, device=input.device)


 class ipex_ops:
-    @staticmethod
-    def _reshape_activation_tensor(
-        x: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        num = x.size(0)
-        d = x.size(1) // 2
-        x = x.reshape(num, 2, d)
-        x1, x2 = torch.chunk(x, chunks=2, dim=1)
-        x1 = 
x1.reshape(num, d) - x2 = x2.reshape(num, d) - return x1, x2 - - @staticmethod - def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: - ipex.llm.functional.silu_and_mul(x, out) - - @staticmethod - def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: - ipex.llm.functional.gelu_and_mul(x, out) - - @staticmethod - def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: - ipex.llm.functional.gelu_and_mul(x, out) - - @staticmethod - def gelu_fast(x: torch.Tensor) -> torch.Tensor: - return torch.nn.functional.gelu(x) - - @staticmethod - def gelu_new(x: torch.Tensor) -> torch.Tensor: - return torch.nn.functional.gelu(x) - - @staticmethod - def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None: - ipex.llm.functional.gelu_quick(x, out) - - @staticmethod - def paged_attention_v1( - out: torch.Tensor, - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - num_kv_heads: int, - scale: float, - block_tables: torch.Tensor, - context_lens: torch.Tensor, - block_size: int, - max_context_len: int, - alibi_slopes: torch.Tensor | None, - kv_cache_dtype: str, - k_scale: float, - v_scale: float, - tp_rank: int = 0, - blocksparse_local_blocks: int = 0, - blocksparse_vert_stride: int = 0, - blocksparse_block_size: int = 64, - blocksparse_head_sliding_step: int = 0, - ) -> None: - assert kv_cache_dtype == "auto" - num_heads = out.size(1) - num_queries_per_tokens = num_heads // num_kv_heads - ipex.llm.modules.PagedAttention.single_query_kv_attention( - out, - query.contiguous(), - key_cache.view_as(value_cache), - value_cache, - num_queries_per_tokens, - scale, - block_tables, - context_lens, - block_size, - max_context_len, - alibi_slopes, - ) - - @staticmethod - def paged_attention_v2( - out: torch.Tensor, - exp_sum: torch.Tensor, - max_logits: torch.Tensor, - tmp_out: torch.Tensor, - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - num_kv_heads: int, - scale: float, - block_tables: torch.Tensor, - context_lens: torch.Tensor, - block_size: int, - max_context_len: int, - alibi_slopes: torch.Tensor | None, - kv_cache_dtype: str, - k_scale: float, - v_scale: float, - tp_rank: int = 0, - blocksparse_local_blocks: int = 0, - blocksparse_vert_stride: int = 0, - blocksparse_block_size: int = 64, - blocksparse_head_sliding_step: int = 0, - ) -> None: - assert kv_cache_dtype == "auto" - num_heads = out.size(1) - num_queries_per_tokens = num_heads // num_kv_heads - ipex.llm.modules.PagedAttention.single_query_kv_attention( - out, - query.contiguous(), - key_cache.view_as(value_cache), - value_cache, - num_queries_per_tokens, - scale, - block_tables, - context_lens, - block_size, - max_context_len, - alibi_slopes, - ) - - @staticmethod - def rotary_embedding( - positions: torch.Tensor, # [batch_size, seq_len] - query: torch.Tensor, # [batch_size, seq_len, num_heads*head_size] - key: torch.Tensor, # [batch_size, seq_len, num_kv_heads*head_size] - head_size: int, - cos_sin_cache: torch.Tensor, # [cos_sin_dim, rot_dim] - is_neox: bool, - ) -> None: - rot_dim = cos_sin_cache.size(1) - ipex.llm.functional.rotary_embedding_batched( - positions, query, key, head_size, cos_sin_cache, is_neox, rot_dim - ) - - @staticmethod - def rms_norm( - input: torch.Tensor, weight: torch.Tensor, epsilon: float - ) -> torch.Tensor: - out = torch.empty_like(input) - torch.ops.torch_ipex.rms_norm_vllm(out, input.contiguous(), weight, epsilon) - return out - - @staticmethod - def fused_add_rms_norm( - input: torch.Tensor, - residual: torch.Tensor, - weight: 
torch.Tensor, - epsilon: float, - ) -> None: - torch.ops.torch_ipex.fused_add_rms_norm_vllm(input, residual, weight, epsilon) - - @staticmethod - def varlen_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - out: torch.Tensor, - seqlen_q: torch.Tensor, - seqlen_k: torch.Tensor, - alibi_slopes: torch.Tensor | None, - max_seqlen_q: int, - max_seqlen_k: int, - pdropout: float, - softmax_scale: float, - zero_tensors: bool, - is_causal: bool, - return_softmax: bool, - gen_: torch.Generator, - window_size_left: float, - window_size_right: float, - logits_soft_cap: float, - ) -> None: - if ipex.__version__.endswith("cpu"): - if logits_soft_cap != 0.0: - raise ValueError("IPEX CPU does not support logits_soft_cap") - assert alibi_slopes is None - assert window_size_left < 0 and window_size_right < 0 - ipex.llm.functional.varlen_attention( - query.contiguous(), - key.contiguous(), - value.contiguous(), - out, - seqlen_q.int(), - seqlen_k.int(), - max_seqlen_q, - max_seqlen_k, - pdropout, - softmax_scale, - zero_tensors, - is_causal, - return_softmax, - gen_, - ) - else: # XPU build - ipex.llm.functional.varlen_attention( - query.contiguous(), - key.contiguous(), - value.contiguous(), - out, - seqlen_q.int(), - seqlen_k.int(), - alibi_slopes, - max_seqlen_q, - max_seqlen_k, - pdropout, - softmax_scale, - zero_tensors, - is_causal, - return_softmax, - gen_, - window_size_left, - window_size_right, - logits_soft_cap, - ) - - @staticmethod - def reshape_and_cache( - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache_dtype: str, - k_scale: float, - v_scale: float, - ) -> None: - assert kv_cache_dtype == "auto" - ipex.llm.modules.PagedAttention.reshape_and_cache( - key, value, key_cache, value_cache, slot_mapping - ) - - @staticmethod - def reshape_and_cache_flash( - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache_dtype: str, - k_scale: torch.Tensor | None = None, - v_scale: torch.Tensor | None = None, - k_scale_float: float = 1.0, - v_scale_float: float = 1.0, - ) -> None: - ipex.llm.modules.PagedAttention.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - slot_mapping, - kv_cache_dtype, - k_scale_float, - v_scale_float, - ) - @staticmethod def flash_attn_varlen_func( q: torch.Tensor, @@ -295,8 +81,21 @@ class ipex_ops: k_descale=None, v_descale=None, num_splits=0, + return_softmax_lse: bool | None = False, s_aux: torch.Tensor | None = None, ): + assert cu_seqlens_k is not None or seqused_k is not None, ( + "cu_seqlens_k or seqused_k must be provided" + ) + assert cu_seqlens_k is None or seqused_k is None, ( + "cu_seqlens_k and seqused_k cannot be provided at the same time" + ) + assert block_table is None or seqused_k is not None, ( + "when enable block_table, seqused_k is needed" + ) + assert block_table is not None or cu_seqlens_k is not None, ( + "when block_table is disabled, cu_seqlens_k is needed" + ) if out is None: out = torch.empty(q.shape, dtype=q.dtype, device=q.device) real_window_size: tuple[int, int] @@ -304,56 +103,31 @@ class ipex_ops: real_window_size = (-1, -1) else: assert len(window_size) == 2 - real_window_size = (window_size[0], window_size[1]) + real_window_size = (window_size[0], window_size[1]) # noqa: F841 + # In encode attention, v maybe not contiguous and current + # kernel can't handle it if block_table is None: - assert cu_seqlens_k is not None, ( - 
"cu_seqlens_k can't be None when calling varlen_attention." - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - ipex_ops.varlen_attention( - q.contiguous(), - k.contiguous(), - v.contiguous(), - out, - cu_seqlens_q, - cu_seqlens_k, - None, - max_seqlen_q, - max_seqlen_k, - 0.0, - softmax_scale, - False, - causal, - False, - None, - real_window_size[0], - real_window_size[1], - -1, - ) - return out - else: - return ipex.llm.modules.PagedAttention.flash_attn_varlen_func( - out, - q.contiguous(), - k, - v, - cu_seqlens_q, - seqused_k, - max_seqlen_q, - max_seqlen_k, - softmax_scale, - causal, - block_table, - alibi_slopes, - sink=s_aux, - softcap=softcap, - window_size_left=real_window_size[0], - window_size_right=real_window_size[1], - k_scale=1.0, - v_scale=1.0, - ) + v = v.contiguous() + return flash_attn_varlen_func( + out=out, + q=q.contiguous(), + k=k, + v=v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_k=seqused_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + softmax_scale=softmax_scale, + causal=causal, + block_table=block_table, + s_aux=s_aux, + window_size=real_window_size, + # alibi_slopes = alibi_slopes, + # softcap=softcap, + return_softmax_lse=return_softmax_lse, + ) @staticmethod def get_scheduler_metadata( @@ -382,64 +156,3 @@ class ipex_ops: "get_scheduler_metadata is not implemented for ipex_ops, returning None." ) return None - - @staticmethod - def swap_blocks( - src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor - ) -> None: - torch.xpu.swap_blocks(src, dst, block_mapping) # type: ignore - - @staticmethod - def scaled_fp8_quant( - input: torch.Tensor, - scale: torch.Tensor | None = None, - num_token_padding: int | None = None, - scale_ub: torch.Tensor | None = None, - use_per_token_if_dynamic: bool = False, - output: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Quantize input tensor to FP8 and return quantized tensor and scale. - - This function is designed for both static and dynamic quantization: - If you provide the scale, it will use static scaling and if you omit - it, the scale will be determined dynamically. Currently, XPU platform - only supports dynamic quantization. The function also allows optional - padding of the output tensors for downstream kernels that will benefit - from padding. - - Args: - input: The input tensor to be quantized to FP8 - scale: Optional scaling factor for the FP8 quantization - scale_ub: Optional upper bound for scaling factor in dynamic - per token case - num_token_padding: If specified, pad the first dimension - of the output to at least this value. - use_per_token_if_dynamic: Whether to do per_tensor or per_token - in the dynamic quantization case. - - Returns: - tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and - scaling factor. 
- """ - # This code assumes batch_dim and num_tokens are flattened - assert input.ndim == 2 - shape: tuple[int, int] | torch.Size = input.shape - out_dtype: torch.dtype = current_platform.fp8_dtype() - if num_token_padding: - shape = (max(num_token_padding, input.shape[0]), shape[1]) - if output is None: - output = torch.empty(shape, device=input.device, dtype=out_dtype) - else: - assert num_token_padding is None, ( - "padding not supported if output passed in" - ) - assert output.dtype == out_dtype - assert scale is None, "only dynamic fp8 quantization supported on XPU" - assert not use_per_token_if_dynamic, ( - "per token dynamic fp8 quantization not supported on XPU" - ) - scale = torch.zeros(1, device=input.device, dtype=torch.float32) - torch.ops.torch_ipex.dynamic_scaled_fp8_quant(output, input, scale) - - return output, scale diff --git a/vllm/config/model.py b/vllm/config/model.py index 563f8ac56..48ff44ac9 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -877,7 +877,6 @@ class ModelConfig: overrides = [ "gptq_marlin", "awq_marlin", - "ipex", "inc", "moe_wna16", "modelopt", diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index b53a37a31..3e00d21d5 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -129,12 +129,8 @@ class SiluAndMul(CustomOp): def __init__(self, *, compile_native: bool = True): super().__init__(compile_native=compile_native) - if current_platform.is_cuda_alike(): + if current_platform.is_cuda_alike() or current_platform.is_xpu(): self.op = torch.ops._C.silu_and_mul - elif current_platform.is_xpu(): - from vllm._ipex_ops import ipex_ops - - self.op = ipex_ops.silu_and_mul elif current_platform.is_cpu(): self._forward_method = self.forward_native @@ -152,11 +148,7 @@ class SiluAndMul(CustomOp): return out def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: - d = x.shape[-1] // 2 - output_shape = x.shape[:-1] + (d,) - out = torch.empty(output_shape, dtype=x.dtype, device=x.device) - self.op(out, x) - return out + return self.forward_cuda(x) # --8<-- [start:mul_and_silu] @@ -175,12 +167,8 @@ class MulAndSilu(CustomOp): def __init__(self): super().__init__() - if current_platform.is_cuda_alike(): + if current_platform.is_cuda_alike() or current_platform.is_xpu(): self.op = torch.ops._C.mul_and_silu - elif current_platform.is_xpu(): - from vllm._ipex_ops import ipex_ops - - self.op = ipex_ops.silu_and_mul elif current_platform.is_cpu(): self._forward_method = self.forward_native @@ -196,8 +184,8 @@ class MulAndSilu(CustomOp): self.op(out, x) return out - # TODO implement forward_xpu for MulAndSilu - # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + return self.forward_cuda(x) # --8<-- [start:gelu_and_mul_sparse] @@ -278,7 +266,11 @@ class GeluAndMul(CustomOp): self.approximate = approximate if approximate not in ("none", "tanh"): raise ValueError(f"Unknown approximate mode: {approximate}") - if current_platform.is_cuda_alike() or current_platform.is_cpu(): + if ( + current_platform.is_cuda_alike() + or current_platform.is_cpu() + or current_platform.is_xpu() + ): if approximate == "none": self.op = torch.ops._C.gelu_and_mul elif approximate == "tanh": @@ -289,13 +281,6 @@ class GeluAndMul(CustomOp): "with torch.compile. For native implementation, fallback to 'none' " "approximation. The custom kernel implementation is unaffected." 
) - elif current_platform.is_xpu(): - from vllm._ipex_ops import ipex_ops - - if approximate == "none": - self.op = ipex_ops.gelu_and_mul - else: - self.op = ipex_ops.gelu_tanh_and_mul def forward_native(self, x: torch.Tensor) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" @@ -314,11 +299,7 @@ class GeluAndMul(CustomOp): return out def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: - d = x.shape[-1] // 2 - output_shape = x.shape[:-1] + (d,) - out = torch.empty(output_shape, dtype=x.dtype, device=x.device) - self.op(out, x) - return out + return self.forward_cuda(x) def extra_repr(self) -> str: return f"approximate={repr(self.approximate)}" @@ -401,12 +382,12 @@ class NewGELU(CustomOp): def __init__(self): super().__init__() - if current_platform.is_cuda_alike() or current_platform.is_cpu(): + if ( + current_platform.is_cuda_alike() + or current_platform.is_cpu() + or current_platform.is_xpu() + ): self.op = torch.ops._C.gelu_new - elif current_platform.is_xpu(): - from vllm._ipex_ops import ipex_ops - - self.op = ipex_ops.gelu_new def forward_native(self, x: torch.Tensor) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" @@ -419,7 +400,7 @@ class NewGELU(CustomOp): return out def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: - return self.op(x) + return self.forward_cuda(x) # --8<-- [start:gelu_fast] @@ -429,12 +410,12 @@ class FastGELU(CustomOp): def __init__(self): super().__init__() - if current_platform.is_cuda_alike() or current_platform.is_cpu(): + if ( + current_platform.is_cuda_alike() + or current_platform.is_cpu() + or current_platform.is_xpu() + ): self.op = torch.ops._C.gelu_fast - elif current_platform.is_xpu(): - from vllm._ipex_ops import ipex_ops - - self.op = ipex_ops.gelu_fast def forward_native(self, x: torch.Tensor) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" @@ -446,7 +427,7 @@ class FastGELU(CustomOp): return out def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: - return self.op(x) + return self.forward_cuda(x) # --8<-- [start:quick_gelu] @@ -457,12 +438,12 @@ class QuickGELU(CustomOp): def __init__(self): super().__init__() - if current_platform.is_cuda_alike() or current_platform.is_cpu(): + if ( + current_platform.is_cuda_alike() + or current_platform.is_cpu() + or current_platform.is_xpu() + ): self.op = torch.ops._C.gelu_quick - elif current_platform.is_xpu(): - from vllm._ipex_ops import ipex_ops - - self.op = ipex_ops.gelu_quick def forward_native(self, x: torch.Tensor) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" @@ -474,12 +455,7 @@ class QuickGELU(CustomOp): return out def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: - out = torch.empty_like(x) - self.op(out, x) - return out - - # TODO implement forward_xpu for QuickGELU - # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + return self.forward_cuda(x) # --8<-- [start:relu2] diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 2db8ce2bd..3b669c559 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -231,24 +231,7 @@ class RMSNorm(CustomOp): x: torch.Tensor, residual: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - if self.variance_size_override is not None: - return self.forward_native(x, residual) - - from vllm._ipex_ops import ipex_ops as ops - - if residual is not None: - ops.fused_add_rms_norm( - x, - residual, - 
self.weight.data, - self.variance_epsilon, - ) - return x, residual - return ops.rms_norm( - x, - self.weight.data, - self.variance_epsilon, - ) + return self.forward_cuda(x, residual) def extra_repr(self) -> str: s = f"hidden_size={self.weight.data.size(0)}" diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 61d86cea4..bbd7267fd 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -60,8 +60,6 @@ WEIGHT_LOADER_V2_SUPPORTED = [ "ModelOptFp8LinearMethod", "ModelOptFp8PcPtLinearMethod", "ModelOptFp8PbWoLinearMethod", - "IPEXAWQLinearMethod", - "IPEXGPTQLinearMethod", "QuarkLinearMethod", "ModelOptNvFp4LinearMethod", "PetitNvFp4LinearMethod", diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index cc0fdfa8e..82de32af3 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -24,7 +24,6 @@ QuantizationMethods = Literal[ "compressed-tensors", "bitsandbytes", "experts_int8", - "ipex", "quark", "moe_wna16", "torchao", @@ -41,7 +40,6 @@ DEPRECATED_QUANTIZATION_METHODS = [ "fbgemm_fp8", "fp_quant", "experts_int8", - "ipex", "petit_nvfp4", ] @@ -121,7 +119,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: from .gptq import GPTQConfig from .gptq_marlin import GPTQMarlinConfig from .inc import INCConfig - from .ipex_quant import IPEXConfig from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config from .moe_wna16 import MoeWNA16Config from .mxfp4 import Mxfp4Config @@ -144,7 +141,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "bitsandbytes": BitsAndBytesConfig, "ptpc_fp8": PTPCFp8Config, "experts_int8": ExpertsInt8Config, - "ipex": IPEXConfig, "quark": QuarkConfig, "moe_wna16": MoeWNA16Config, "torchao": TorchAOConfig, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index a8467b5f0..9b7d65433 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -184,39 +184,10 @@ class Fp8Config(QuantizationConfig): def get_xpu_quant_method( self, layer: torch.nn.Module, prefix: str ) -> "QuantizeMethodBase | None": - from vllm.model_executor.layers.quantization.ipex_quant import ( - XPUFp8LinearMethod, - XPUFp8MoEMethod, + raise NotImplementedError( + "FP8 quantization is not supported during xpu kernel migration." 
) - fp8_config = Fp8Config( - is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized, - activation_scheme=self.activation_scheme, - ignored_layers=self.ignored_layers, - weight_block_size=self.weight_block_size, - ) - - if isinstance(layer, LinearBase): - if is_layer_skipped( - prefix=prefix, - ignored_layers=self.ignored_layers, - fused_mapping=self.packed_modules_mapping, - ): - return UnquantizedLinearMethod() - return XPUFp8LinearMethod(fp8_config) - elif isinstance(layer, FusedMoE): - if is_layer_skipped( - prefix=prefix, - ignored_layers=self.ignored_layers, - fused_mapping=self.packed_modules_mapping, - ): - return UnquantizedFusedMoEMethod(layer.moe_config) - - return XPUFp8MoEMethod(fp8_config, layer) - elif isinstance(layer, Attention): - return Fp8KVCacheMethod(self) - return None - def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> "QuantizeMethodBase | None": diff --git a/vllm/model_executor/layers/quantization/inc.py b/vllm/model_executor/layers/quantization/inc.py index f68fd9578..359f24688 100644 --- a/vllm/model_executor/layers/quantization/inc.py +++ b/vllm/model_executor/layers/quantization/inc.py @@ -38,7 +38,6 @@ class INCConfig(QuantizationConfig): "awq", "awq:marlin", "marlin", - "ipex", } def __init__( @@ -410,31 +409,10 @@ class INCConfig(QuantizationConfig): return UnquantizedLinearMethod() else: return None - from vllm.model_executor.layers.quantization.ipex_quant import ( - IPEXAWQLinearMethod, - IPEXConfig, - IPEXGPTQLinearMethod, + raise NotImplementedError( + "INC quantization is not supported during xpu kernel migration." ) - if isinstance(layer, (LinearBase, ParallelLMHead)): - if "awq" in self.packing_format: - config = IPEXConfig( - method="awq", weight_bits=weight_bits, group_size=group_size - ) - return IPEXAWQLinearMethod(config) - elif "gptq" in self.packing_format: - config = IPEXConfig( - method="gptq", weight_bits=weight_bits, group_size=group_size - ) - return IPEXGPTQLinearMethod(config) - else: - raise ValueError( - f"ipex backend only supports awq " - f"and gptq format,but got {self.packing_format}" - ) - else: - return None - def get_quant_method(self, layer: torch.nn.Module, prefix: str): if prefix and self.extra_config: for layer_name in self.extra_config: diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py deleted file mode 100644 index f957b3991..000000000 --- a/vllm/model_executor/layers/quantization/ipex_quant.py +++ /dev/null @@ -1,403 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Any - -import torch -from packaging import version -from torch.nn import Module - -from vllm._ipex_ops import ipex_ops as ops -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.linear import ( - LinearBase, - LinearMethodBase, - UnquantizedLinearMethod, -) -from vllm.model_executor.layers.quantization import ( - QuantizationConfig, - QuantizationMethods, -) -from vllm.model_executor.layers.quantization.awq import AWQLinearMethod -from vllm.model_executor.layers.quantization.fp8 import ( - Fp8Config, - Fp8LinearMethod, - Fp8OnlineMoEMethod, -) -from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod -from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped -from vllm.model_executor.utils import replace_parameter -from vllm.platforms import current_platform - 
-MIN_IPEX_VERSION = "2.6.0" - - -class IPEXConfig(QuantizationConfig): - """INT8 quantization config class using IPEX for the CPU/XPU backend, - including AWQ, GPTQ. - """ - - IPEX_QUANT_METHOD_MAP = { - "awq": 1, - "gptq": 0, - } - - def __init__( - self, - method: str, - weight_bits: int, - group_size: int, - modules_to_not_convert: list[str] | None = None, - desc_act: bool | None = None, - lm_head_quantized: bool | None = None, - is_sym: bool | None = None, - ) -> None: - super().__init__() - self.method = method - self.weight_bits = weight_bits - self.group_size = group_size - self.modules_to_not_convert = modules_to_not_convert or [] - self.desc_act = desc_act - self.lm_head_quantized = lm_head_quantized - self.is_sym = is_sym - self.pack_factor = 32 // self.weight_bits - - if self.weight_bits not in [4]: - raise ValueError( - f"IPEX quantization supports weight bits [4], " - f"but got {self.weight_bits}." - ) - - if self.method not in ["awq", "gptq"]: - raise ValueError( - f"IPEX quantization supports [awq, gptq], but got {self.method}." - ) - - def __repr__(self) -> str: - return ( - f"IPEXConfig(method={self.method}," - f"weight_bits={self.weight_bits}, " - f"group_size={self.group_size})" - ) - - @classmethod - def get_name(cls) -> QuantizationMethods: - return "ipex" - - @classmethod - def get_supported_act_dtypes(cls) -> list[torch.dtype]: - return [torch.bfloat16, torch.float16] - - @classmethod - def get_min_capability(cls) -> int: - return -1 - - @staticmethod - def get_config_filenames() -> list[str]: - return [ - "quant_config.json", - "quantize_config.json", - ] - - @classmethod - def from_config(cls, config: dict[str, Any]) -> "IPEXConfig": - method = cls.get_from_keys(config, ["quant_method"]).lower() - if method == "awq": - weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) - group_size = cls.get_from_keys(config, ["q_group_size", "group_size"]) - modules_to_not_convert = cls.get_from_keys_or( - config, ["modules_to_not_convert"], None - ) - is_sym = not cls.get_from_keys_or(config, ["zero_point"], default=False) - return cls( - method, - weight_bits, - group_size, - modules_to_not_convert, - False, - False, - is_sym, - ) - # otherwise for gptq - weight_bits = cls.get_from_keys(config, ["bits"]) - group_size = cls.get_from_keys(config, ["group_size"]) - lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) - desc_act = cls.get_from_keys_or(config, ["desc_act"], default=False) - is_sym = cls.get_from_keys_or(config, ["sym"], default=True) - return cls( - method, weight_bits, group_size, [], desc_act, lm_head_quantized, is_sym - ) - - @classmethod - def override_quantization_method( - cls, hf_quant_cfg, user_quant - ) -> QuantizationMethods | None: - if not current_platform.is_xpu(): - return None - - quant_method = hf_quant_cfg.get("quant_method", "").lower() - - if quant_method in ["awq", "gptq"]: - return cls.get_name() - - return None - - def get_quant_method( - self, layer: torch.nn.Module, prefix: str - ) -> "LinearMethodBase | None": - if isinstance(layer, LinearBase): - if self.method == "awq": - if is_layer_skipped( - prefix, - self.modules_to_not_convert, - self.packed_modules_mapping, - skip_with_substr=True, - ): - return UnquantizedLinearMethod() - return IPEXAWQLinearMethod(self) - if self.method == "gptq": - return IPEXGPTQLinearMethod(self) - return None - - -class IPEXGPTQLinearMethod(GPTQLinearMethod): - """GPTQ linear method using IPEX for the CPU/XPU backend.""" - - def __init__(self, quant_config: IPEXConfig): - 
self.quant_config = quant_config # type: ignore - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - bias = layer.bias if not layer.skip_bias_add else None - - try: - import intel_extension_for_pytorch as ipex - - if version.parse(ipex.__version__) < version.parse(MIN_IPEX_VERSION): - raise ImportError( - "intel_extension_for_pytorch version is " - "wrong. Please install " - f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}." - ) - except ImportError as err: - raise ImportError( - "Please install " - f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via " - f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`" - " to use IPEX-AWQ linear method." - ) from err - # Using the compute dtype (lowp_mode) as INT8 to leverage instructions - # with better performance. - lowp_mode = ipex.quantization.WoqLowpMode.INT8 - # The weight will be de-packed from INT4 to INT8. - weight_dtype = ipex.quantization.WoqWeightDtype.INT4 - # The float activation will be quantized (dynamic, per-token) to INT8. - act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK - - assert isinstance(self.quant_config, IPEXConfig) - qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping( - weight_dtype=weight_dtype, - lowp_mode=lowp_mode, - act_quant_mode=act_quant_mode, - group_size=self.quant_config.group_size, - ) - layer.ipex_output_size = layer.qweight.shape[-1] - g_idx = layer.g_idx if self.quant_config.desc_act else None - layer.ipex_qlinear = ( - ipex.llm.quantization.woq_linear.IPEXWeightOnlyQuantizedLinear.from_weight( - layer.qweight, - layer.scales, - layer.qzeros, - layer.qweight.size(0), - layer.ipex_output_size, - qconfig=qconfig, - g_idx=g_idx, - bias=bias, - group_size=self.quant_config.group_size, - quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["gptq"], - weight_qscheme="sym" if self.quant_config.is_sym else "asym", - ) - ) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: torch.Tensor | None = None, - ) -> torch.Tensor: - reshaped_x = x.reshape(-1, x.shape[-1]) - out = layer.ipex_qlinear(reshaped_x) - return out.reshape(x.shape[:-1] + (layer.ipex_output_size,)) - - -class IPEXAWQLinearMethod(AWQLinearMethod): - """AWQ linear method using IPEX for the CPU/XPU backend.""" - - def __init__(self, quant_config: IPEXConfig): - self.quant_config = quant_config # type: ignore - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - super().process_weights_after_loading(layer=layer) - - bias = layer.bias if not layer.skip_bias_add else None - - try: - import intel_extension_for_pytorch as ipex - - if version.parse(ipex.__version__) < version.parse(MIN_IPEX_VERSION): - raise ImportError( - "intel_extension_for_pytorch version is " - "wrong. Please install " - f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}." - ) - except ImportError as err: - raise ImportError( - "Please install " - f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via " - f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`" - " to use IPEX-AWQ linear method." - ) from err - - # Using the compute dtype (lowp_mode) as INT8 to leverage instructions - # with better performance. - lowp_mode = ipex.quantization.WoqLowpMode.INT8 - # The weight will be de-packed from INT4 to INT8. - weight_dtype = ipex.quantization.WoqWeightDtype.INT4 - # The float activation will be quantized (dynamic, per-token) to INT8. 
- act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH - - assert isinstance(self.quant_config, IPEXConfig) - qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping( - weight_dtype=weight_dtype, - lowp_mode=lowp_mode, - act_quant_mode=act_quant_mode, - group_size=self.quant_config.group_size, - ) - - layer.ipex_output_size = layer.qweight.size(1) * self.quant_config.pack_factor - layer.ipex_qlinear = ( - ipex.llm.quantization.woq_linear.IPEXWeightOnlyQuantizedLinear.from_weight( - layer.qweight, - layer.scales, - layer.qzeros, - layer.qweight.size(0), - layer.ipex_output_size, - qconfig=qconfig, - bias=bias, - group_size=self.quant_config.group_size, - quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["awq"], # type: ignore - weight_qscheme="sym" if self.quant_config.is_sym else "asym", - ) - ) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: torch.Tensor | None = None, - ) -> torch.Tensor: - reshaped_x = x.reshape(-1, x.shape[-1]) - out = layer.ipex_qlinear(reshaped_x) - return out.reshape(x.shape[:-1] + (layer.ipex_output_size,)) - - -class XPUFp8LinearMethod(Fp8LinearMethod): - def __init__(self, quant_config: Fp8Config): - super().__init__(quant_config) - - def process_weights_after_loading(self, layer: Module) -> None: - if getattr(layer, "_already_called_process_weights_after_loading", False): - return - # If checkpoint not serialized fp8, quantize the weights. - if not self.quant_config.is_checkpoint_fp8_serialized: - qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None) - # Update the layer with the new values. - replace_parameter(layer, "weight", qweight.data) - replace_parameter(layer, "weight_scale", weight_scale.data) - layer.input_scale = None - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: torch.Tensor | None = None, - ) -> torch.Tensor: - weight = layer.weight.data - weight_scale = layer.weight_scale.data - output = torch.ops.torch_ipex.fp8_gemm_w8a16( - x, weight, True, weight_scale, bias - ) - return output - - -class XPUFp8MoEMethod(Fp8OnlineMoEMethod): - def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): - super().__init__(quant_config, layer) - self.quant_config = quant_config - - def process_weights_after_loading(self, layer: Module) -> None: - if getattr(layer, "_already_called_process_weights_after_loading", False): - return - if not self.quant_config.is_checkpoint_fp8_serialized: - fp8_dtype = current_platform.fp8_dtype() - w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) - w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) - - # Re-initialize w13_scale because we directly quantize - # merged w13 weights and generate a single scaling factor. 
- layer.w13_weight_scale = torch.nn.Parameter( - torch.ones( - layer.local_num_experts, - dtype=torch.float32, - device=w13_weight.device, - ), - requires_grad=False, - ) - for expert in range(layer.local_num_experts): - w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( - ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) - ) - w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( - ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) - ) - replace_parameter(layer, "w13_weight", w13_weight) - replace_parameter(layer, "w2_weight", w2_weight) - - import intel_extension_for_pytorch as ipex - - ep_rank_start = self.moe.ep_rank * self.moe.num_local_experts - layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( - layer.w13_weight, - layer.w2_weight, - w1_scale_inv=layer.w13_weight_scale, - w2_scale_inv=layer.w2_weight_scale, - a1_scale_inv=layer.w13_input_scale, - a2_scale_inv=layer.w2_input_scale, - use_prepack=True, - experts_start_id=ep_rank_start, - ) - - def get_fused_moe_quant_config( - self, layer: torch.nn.Module - ) -> FusedMoEQuantConfig | None: - return None - - @property - def is_monolithic(self) -> bool: - return True - - def apply_monolithic( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - ) -> torch.Tensor: - return layer.ipex_fusion( - x, - layer.use_grouped_topk, - layer.top_k, - router_logits, - layer.renormalize, - layer.topk_group, - layer.num_expert_group, - custom_routing_function=layer.custom_routing_function, - ) diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index d63367af5..ffc6f67da 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -232,17 +232,14 @@ class RotaryEmbedding(RotaryEmbeddingBase): query: torch.Tensor, key: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: - from vllm._ipex_ops import ipex_ops as ops - self._match_cos_sin_cache_dtype(query) # ops.rotary_embedding() is an in-place operation # that updates the query and key tensors. if key is None: - # XPU kernel doesn't support key=None so fall back to native impl - # TODO(sarckk): add support for optional key in - # ipex.llm.functional.rotary_embedding_batched return self.forward_native(positions, query, key) else: + from vllm import _custom_ops as ops + ops.rotary_embedding( positions, query, diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index a0e5af1ab..758409ae1 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -132,8 +132,6 @@ def xpu_platform_plugin() -> str | None: is_xpu = False logger.debug("Checking if XPU platform is available.") try: - # installed IPEX if the machine has XPUs. 
-        import intel_extension_for_pytorch  # noqa: F401
         import torch

         if supports_xccl():
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index 439d21cb8..6e299f30e 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -7,6 +7,11 @@
 from typing import TYPE_CHECKING

 import torch

+# Import custom ops to trigger op registration.
+import vllm_xpu_kernels._C  # noqa
+import vllm_xpu_kernels._moe_C  # noqa
+import vllm_xpu_kernels._xpu_C  # noqa
+
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.registry import AttentionBackendEnum

@@ -55,6 +60,9 @@ class XPUPlatform(Platform):
         dtype = attn_selector_config.dtype
         if attn_selector_config.use_sparse:
             raise NotImplementedError("Sparse Attention is not supported on XPU.")
+        if attn_selector_config.use_mla:
+            logger.info_once("Using Triton MLA backend on V1 engine.")
+            return AttentionBackendEnum.TRITON_MLA.get_path()
         if selected_backend == AttentionBackendEnum.TRITON_ATTN:
             logger.info_once("Using Triton backend.")
             return AttentionBackendEnum.TRITON_ATTN.get_path()
@@ -78,9 +86,9 @@
     @classmethod
     def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
-        # XPU only supports FLASH_ATTN for vision attention.
         return [
             AttentionBackendEnum.FLASH_ATTN,
+            AttentionBackendEnum.TORCH_SDPA,
         ]

     @classmethod
@@ -145,7 +153,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         cache_config = vllm_config.cache_config
         model_config = vllm_config.model_config
-        # in V1(or with ipex chunked prefill) block_size is 64
+        # In V1 (or with chunked prefill), the block size is 64
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 64

@@ -206,7 +214,7 @@
     @classmethod
     def fp8_dtype(cls) -> torch.dtype:
-        return torch.float8_e5m2
+        return torch.float8_e4m3fn

     @classmethod
     def is_data_center_gpu(cls) -> bool:
diff --git a/vllm/v1/attention/backends/fa_utils.py b/vllm/v1/attention/backends/fa_utils.py
index 988cf7c27..281d18855 100644
--- a/vllm/v1/attention/backends/fa_utils.py
+++ b/vllm/v1/attention/backends/fa_utils.py
@@ -16,12 +16,13 @@ if current_platform.is_cuda():
     )

 elif current_platform.is_xpu():
+    from vllm import _custom_ops as ops
+
+    reshape_and_cache_flash = ops.reshape_and_cache_flash
     from vllm._ipex_ops import ipex_ops

-    reshape_and_cache_flash = ipex_ops.reshape_and_cache_flash
     flash_attn_varlen_func = ipex_ops.flash_attn_varlen_func  # type: ignore[assignment]
     get_scheduler_metadata = ipex_ops.get_scheduler_metadata  # type: ignore[assignment]
-
 elif current_platform.is_rocm():
     try:
         from flash_attn import flash_attn_varlen_func  # type: ignore[no-redef]
diff --git a/vllm/v1/attention/backends/registry.py b/vllm/v1/attention/backends/registry.py
index bd45702fa..2a80bbd94 100644
--- a/vllm/v1/attention/backends/registry.py
+++ b/vllm/v1/attention/backends/registry.py
@@ -69,7 +69,6 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
         "vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend"
     )
     FLASH_ATTN_MLA = "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
-    IPEX = "vllm.v1.attention.backends.ipex.IpexAttentionBackend"
     NO_ATTENTION = "vllm.v1.attention.backends.no_attention.NoAttentionBackend"
     FLEX_ATTENTION = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
     TREE_ATTN = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"
diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py
index f1bdd5da3..6e45a107c 100644
--- a/vllm/v1/worker/xpu_worker.py
+++ b/vllm/v1/worker/xpu_worker.py
@@ -1,10 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import gc
 import os
 from typing import Any

 import torch
-import torch.distributed

 from vllm.config import VllmConfig
 from vllm.logger import init_logger
@@ -85,7 +85,14 @@ class XPUWorker(Worker):
             current_platform.dist_backend,
         )

+        # Set random seed.
+        set_random_seed(self.model_config.seed)
+
+        # Take the memory snapshot only after the distributed backend is initialized.
+        gc.collect()
         torch.xpu.empty_cache()
+
+        # Take the current memory snapshot.
         self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device)
         self.requested_memory = request_memory(init_snapshot, self.cache_config)
         logger.debug("worker init memory snapshot: %r", self.init_snapshot)
@@ -93,9 +100,6 @@
             "worker requested memory: %sGiB", format_gib(self.requested_memory)
         )

-        # Set random seed.
-        set_random_seed(self.model_config.seed)
-
         # Initialize workspace manager
         num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1
         init_workspace_manager(self.device, num_ubatches)
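Reviewer note: the rewritten `ipex_ops.flash_attn_varlen_func` now forwards directly to `vllm_xpu_kernels.flash_attn_interface.flash_attn_varlen_func` and enforces a stricter argument contract via the new asserts: exactly one of `cu_seqlens_k` / `seqused_k` must be given, a paged-KV call (`block_table` set) must supply `seqused_k`, and a packed call (no `block_table`) must supply `cu_seqlens_k`. The sketch below restates that contract in standalone form; the helper name is hypothetical and not part of the patch.

```python
from __future__ import annotations

import torch


def check_varlen_kv_args(
    cu_seqlens_k: torch.Tensor | None,
    seqused_k: torch.Tensor | None,
    block_table: torch.Tensor | None,
) -> None:
    """Hypothetical helper mirroring the asserts added to ipex_ops.flash_attn_varlen_func."""
    # Exactly one of cu_seqlens_k / seqused_k must be provided.
    assert cu_seqlens_k is not None or seqused_k is not None, (
        "cu_seqlens_k or seqused_k must be provided"
    )
    assert cu_seqlens_k is None or seqused_k is None, (
        "cu_seqlens_k and seqused_k cannot be provided at the same time"
    )
    # Paged-KV calls describe K via seqused_k; packed calls via cu_seqlens_k.
    assert block_table is None or seqused_k is not None, (
        "seqused_k is required when block_table is enabled"
    )
    assert block_table is not None or cu_seqlens_k is not None, (
        "cu_seqlens_k is required when block_table is disabled"
    )


# Packed (prefill-style) call: cu_seqlens_k, no block_table.
check_varlen_kv_args(
    cu_seqlens_k=torch.tensor([0, 4, 9]), seqused_k=None, block_table=None
)
# Paged-KV call: seqused_k plus a block table.
check_varlen_kv_args(
    cu_seqlens_k=None,
    seqused_k=torch.tensor([4, 5]),
    block_table=torch.zeros(2, 8, dtype=torch.int32),
)
```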
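The new `vllm/_ipex_ops.py` also registers fake (meta) implementations for the `_xpu_C` GEMM ops so `torch.compile` can trace them without launching the XPU kernels; the fakes only compute the flattened `(M, N)` output shape. Below is a self-contained sketch of that pattern using a made-up `demo::w8a16_gemm` op (the real ops come from the `vllm_xpu_kernels` extension); it assumes PyTorch >= 2.4 for `torch.library.custom_op` / `register_fake`.

```python
import torch
from torch.library import custom_op, register_fake  # PyTorch >= 2.4


# Hypothetical stand-in op; the real kernels (_xpu_C::fp8_gemm_w8a16,
# _xpu_C::int4_gemm_w4a16) are registered by vllm_xpu_kernels at import time.
@custom_op("demo::w8a16_gemm", mutates_args=())
def w8a16_gemm(x: torch.Tensor, q_weight: torch.Tensor) -> torch.Tensor:
    # Eager path: a plain matmul stands in for the fused XPU kernel.
    return x.view(-1, x.shape[-1]) @ q_weight.to(x.dtype)


@register_fake("demo::w8a16_gemm")
def _w8a16_gemm_fake(x: torch.Tensor, q_weight: torch.Tensor) -> torch.Tensor:
    # Shape-only ("fake") implementation used during tracing: flatten the
    # leading dims to M and take N from the weight, as in _ipex_ops.py.
    m = x.view(-1, x.shape[-1]).size(0)
    n = q_weight.size(1)
    return torch.empty((m, n), dtype=x.dtype, device=x.device)


x = torch.randn(2, 3, 8)   # (batch, seq, K)
w = torch.randn(8, 16)     # (K, N)
out = torch.ops.demo.w8a16_gemm(x, w)
assert out.shape == (6, 16)
```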