Compare commits

...

7 Commits

Author SHA1 Message Date
Seiji Eicher
c44d0c6d66 Patch protobuf for CVE-2026-0994 (#34253)
Signed-off-by: Seiji Eicher <seiji@anyscale.com>
Co-authored-by: Kevin H. Luu <khluu000@gmail.com>
(cherry picked from commit 5045d5c983)
2026-02-11 02:33:40 -08:00
Kunshang Ji
83db96d8cd [XPU][9/N] clean up existing ipex code/doc (#34111)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
(cherry picked from commit cb9574eb85)
2026-02-11 02:33:27 -08:00
zofia
dbfb79fe45 [XPU][7/N] enable xpu fp8 moe (#34202)
Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>
(cherry picked from commit b482f71e9f)
2026-02-11 02:33:15 -08:00
Roger Wang
b2e1fc3589 [Bugfix][Core] Fix CPU memory leak from Request reference cycle in prefix caching (#34183)
Signed-off-by: Roger Wang <hey@rogerw.io>
(cherry picked from commit 8a5e0e2b2b)
2026-02-11 02:33:04 -08:00
Gregory Shtrasberg
55a1baebc5 [Bugfix][ROCm][GPT-OSS] Use old triton_kernels implementation on ROCm if the new API is not available (#34153)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
(cherry picked from commit c60f8e3b49)
2026-02-11 02:32:52 -08:00
Charlie Fu
e1e9841631 [torch.compile][Fusion] Fix attention fusion pass removing kv_udpate op. (#33945)
Signed-off-by: charlifu <charlifu@amd.com>
(cherry picked from commit bb9f97308d)
2026-02-11 02:32:41 -08:00
zofia
5bd63387c3 [XPU][6/N] add xpu scaled_mm kernel (#34117)
Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>
(cherry picked from commit 9bdb06b436)
2026-02-11 02:32:27 -08:00
26 changed files with 225 additions and 89 deletions

View File

@@ -39,6 +39,7 @@ docker run \
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --attention-backend=TRITON_ATTN
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --quantization fp8
python3 examples/offline_inference/basic/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --block-size 64 --enforce-eager
python3 examples/offline_inference/basic/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2
python3 examples/offline_inference/basic/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 --enable-expert-parallel

View File

@@ -134,7 +134,6 @@ WORKDIR /vllm-workspace
# Copy test requirements
COPY requirements/test.in requirements/cpu-test.in
# TODO: Update to 2.9.0 when there is a new build for intel_extension_for_pytorch for that version
RUN \
sed -i '/mamba_ssm/d' requirements/cpu-test.in && \
remove_packages_not_supported_on_aarch64() { \

View File

@@ -6,10 +6,11 @@ vLLM initially supports basic model inference and serving on Intel GPU platform.
# --8<-- [start:requirements]
- Supported Hardware: Intel Data Center GPU, Intel ARC GPU
- OneAPI requirements: oneAPI 2025.1
- OneAPI requirements: oneAPI 2025.3
- Dependency: [vllm-xpu-kernels](https://github.com/vllm-project/vllm-xpu-kernels): a package provide all necessary vllm custom kernel when running vLLM on Intel GPU platform,
- Python: 3.12
!!! warning
The provided IPEX whl is Python3.12 specific so this version is a MUST.
The provided vllm-xpu-kernels whl is Python3.12 specific so this version is a MUST.
# --8<-- [end:requirements]
# --8<-- [start:set-up-using-python]
@@ -24,7 +25,7 @@ Currently, there are no pre-built XPU wheels.
# --8<-- [end:pre-built-wheels]
# --8<-- [start:build-wheel-from-source]
- First, install required [driver](https://dgpu-docs.intel.com/driver/installation.html#installing-gpu-drivers) and [Intel OneAPI](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html) 2025.1 or later.
- First, install required [driver](https://dgpu-docs.intel.com/driver/installation.html#installing-gpu-drivers) and [Intel OneAPI](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html) 2025.3 or later.
- Second, install Python packages for vLLM XPU backend building:
```bash
@@ -37,7 +38,7 @@ pip install -v -r requirements/xpu.txt
- Then, build and install vLLM XPU backend:
```bash
VLLM_TARGET_DEVICE=xpu python setup.py install
VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -e . -v
```
# --8<-- [end:build-wheel-from-source]

View File

@@ -9,5 +9,5 @@ wheel
jinja2>=3.1.6
regex
build
protobuf
protobuf >= 5.29.6, !=6.30.*, !=6.31.*, !=6.32.*, !=6.33.0.*, !=6.33.1.*, !=6.33.2.*, !=6.33.3.*, !=6.33.4.*
grpcio-tools==1.78.0 # Required for grpc entrypoints

View File

@@ -9,7 +9,7 @@ blake3
py-cpuinfo
transformers >= 4.56.0, < 5
tokenizers >= 0.21.1 # Required for fast incremental detokenization.
protobuf # Required by LlamaTokenizer, gRPC.
protobuf >= 5.29.6, !=6.30.*, !=6.31.*, !=6.32.*, !=6.33.0.*, !=6.33.1.*, !=6.33.2.*, !=6.33.3.*, !=6.33.4.* # Required by LlamaTokenizer, gRPC. CVE-2026-0994
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
aiohttp >= 3.13.3
openai >= 1.99.1 # For Responses API with reasoning content

View File

@@ -15,4 +15,4 @@ torch==2.10.0+xpu
torchaudio
torchvision
vllm_xpu_kernels @ https://github.com/vllm-project/vllm-xpu-kernels/releases/download/v0.1.1/vllm_xpu_kernels-0.1.1-cp312-cp312-linux_x86_64.whl
vllm_xpu_kernels @ https://github.com/vllm-project/vllm-xpu-kernels/releases/download/v0.1.2/vllm_xpu_kernels-0.1.2-cp312-cp312-linux_x86_64.whl

View File

@@ -267,7 +267,7 @@ elif current_platform.is_rocm():
PATTERN_TEST_MODELS_FP8 = [
("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
]
BACKENDS = [
BACKENDS_FP8 = [
AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
AttentionBackendEnum.ROCM_ATTN,
AttentionBackendEnum.TRITON_ATTN,
@@ -474,6 +474,17 @@ def test_attention_quant_pattern(
assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, (
"Attention should not have output_block_scale before fusion"
)
kv_cache_dummy_dep_pre_is_none = (
attn_nodes_pre[0].kwargs.get("kv_cache_dummy_dep") is None
)
kv_cache_dummy_dep_post_is_none = (
attn_nodes_post[0].kwargs.get("kv_cache_dummy_dep") is None
)
assert not (kv_cache_dummy_dep_pre_is_none ^ kv_cache_dummy_dep_post_is_none), (
"The kv_cache_dummy_dep should be consistent before and after fusion"
)
if quant_key.dtype == FP8_DTYPE:
assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, (
"Attention should not have output_block_scale after FP8 fusion"

View File

@@ -17,7 +17,7 @@ DTYPE = ["bfloat16"]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", DTYPE)
def test_ipex_quant(vllm_runner, model, dtype):
def test_cpu_quant(vllm_runner, model, dtype):
with vllm_runner(model, dtype=dtype) as llm:
output = llm.generate_greedy(["The capital of France is"], max_tokens=32)
assert output

View File

@@ -1,32 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test model set-up and inference for quantized HF models supported
on the CPU/GPU backend using IPEX (including AWQ/GPTQ).
Validating the configuration and printing results for manual checking.
Run `pytest tests/quantization/test_ipex_quant.py`.
"""
import pytest
from vllm.platforms import current_platform
MODELS = [
"AMead10/Llama-3.2-1B-Instruct-AWQ",
"shuyuej/Llama-3.2-1B-Instruct-GPTQ", # with g_idx
]
DTYPE = ["bfloat16"]
@pytest.mark.skipif(
not current_platform.is_cpu() and not current_platform.is_xpu(),
reason="only supports Intel CPU/XPU backend.",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", DTYPE)
def test_ipex_quant(vllm_runner, model, dtype):
with vllm_runner(model, dtype=dtype, enforce_eager=True) as llm:
output = llm.generate_greedy(["The capital of France is"], max_tokens=4)
assert output
print(output)

View File

@@ -236,7 +236,7 @@ def test_prefix_caching_for_multi_turn():
req._all_token_ids = req.prompt_token_ids.copy()
req.all_token_ids = ConstantList(req._all_token_ids)
req.block_hashes = []
req.block_hashes = req.get_hash_new_full_blocks()
req.update_block_hashes()
# Schedule the next-turn requests.
for req in next_turn_requests:

View File

@@ -53,7 +53,7 @@ if hasattr(torch.ops._xpu_C, "int4_gemm_w4a16"):
return torch.empty((M, N), dtype=input.dtype, device=input.device)
class ipex_ops:
class xpu_ops:
@staticmethod
def flash_attn_varlen_func(
q: torch.Tensor,
@@ -73,7 +73,7 @@ class ipex_ops:
cu_seqlens_k: torch.Tensor | None = None,
# passed in qwen vl
dropout_p: float = 0.0,
# The following parameters are not used in ipex kernel currently,
# The following parameters are not used in xpu kernel currently,
# we keep API compatible to CUDA's.
scheduler_metadata=None,
fa_version: int = 2,
@@ -153,6 +153,6 @@ class ipex_ops:
sm_margin=0, # Can be tuned if some SMs are used for communication
) -> None:
logger.warning_once(
"get_scheduler_metadata is not implemented for ipex_ops, returning None."
"get_scheduler_metadata is not implemented for xpu_ops, returning None."
)
return None

View File

@@ -142,6 +142,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
v: torch.Tensor,
output_attn: torch.Tensor,
scale: torch.Tensor,
kv_cache_dummy_dep: torch.Tensor,
) -> torch.Tensor:
at1 = auto_functionalized(
ATTN_OP,
@@ -152,6 +153,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
layer_name=self.layer_name,
output_scale=None,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
attn_out_view = RESHAPE_OP(
at1[1], [q.shape[0], self.num_heads * self.head_size]
@@ -165,6 +167,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
v: torch.Tensor,
output_attn: torch.Tensor,
scale: torch.Tensor,
kv_cache_dummy_dep: torch.Tensor,
) -> torch.Tensor:
# attn output in quant_dtype
output_attn = torch.ops.aten.full.default(
@@ -182,6 +185,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
layer_name=self.layer_name,
output_scale=scale,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
@@ -191,6 +195,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
self.empty(5, self.num_heads, self.head_size), # v
self.empty(5, self.num_heads, self.head_size), # attn_output
empty_fp32(1, 1), # scale
self.empty(0), # kv_cache_dummy_dep
]
pm.register_replacement(
@@ -228,6 +233,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
output_quant: torch.Tensor,
output_scale: torch.Tensor,
input_scale: torch.Tensor,
kv_cache_dummy_dep: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
at1 = auto_functionalized(
ATTN_OP,
@@ -238,6 +244,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
layer_name=self.layer_name,
output_scale=None,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
attn_out_view = RESHAPE_OP(
at1[1], [q.shape[0], self.num_heads * self.head_size]
@@ -261,6 +268,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
output_quant: torch.Tensor,
output_scale: torch.Tensor,
input_scale: torch.Tensor,
kv_cache_dummy_dep: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
# attention output in quant_dtype
output_attn = torch.ops.aten.full.default(
@@ -280,6 +288,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
layer_name=self.layer_name,
output_scale=input_scale,
output_block_scale=output_scale_view,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
output = RESHAPE_OP(at2[1], [-1, self.num_heads * self.head_size // 2])
return output, at2[2]
@@ -294,6 +303,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
128, round_up(self.num_heads * self.head_size // 16, 4)
), # output_scale
empty_fp32(1, 1), # input_scale
self.empty(0), # kv_cache_dummy_dep
]
pm.register_replacement(

View File

@@ -102,6 +102,7 @@ if HAS_TRITON:
)
from vllm.model_executor.layers.fused_moe.xpu_fused_moe import (
XPUExperts,
XPUExpertsFp8,
)
__all__ += [
@@ -121,6 +122,7 @@ if HAS_TRITON:
"BatchedDeepGemmExperts",
"TritonOrDeepGemmExperts",
"XPUExperts",
"XPUExpertsFp8",
]
else:
# Some model classes directly use the custom ops. Add placeholders

View File

@@ -19,11 +19,14 @@ from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.import_utils import has_triton_kernels
logger = init_logger(__name__)
use_legacy_triton_kernels = False
if has_triton_kernels():
try:
import triton_kernels.swiglu
@@ -38,10 +41,20 @@ if has_triton_kernels():
from triton_kernels.tensor import (
BIT,
Bitmatrix,
SparseMatrix,
make_ragged_tensor_metadata,
)
from triton_kernels.topk import topk
try:
from triton_kernels.tensor import (
SparseMatrix,
make_ragged_tensor_metadata,
)
except ImportError:
if current_platform.is_rocm():
logger.warning_once("Using legacy triton_kernels on ROCm")
use_legacy_triton_kernels = True
else:
raise
except (AttributeError, ImportError) as e:
logger.error(
"Failed to import Triton kernels. Please make sure your triton "
@@ -101,6 +114,12 @@ def legacy_routing_from_bitmatrix(
Replacement for the removed triton_kernels.routing.routing_from_bitmatrix.
Creates routing data from a bitmatrix representation.
"""
if use_legacy_triton_kernels:
from triton_kernels.routing import routing_from_bitmatrix
return routing_from_bitmatrix(
bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act
)
sparse_logits = SparseMatrix(indx=expt_indx, vals=expt_scal, mask=bitmatrix)
dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx
combine_indx = sparse_logits.mask_metadata.col_sorted_indx
@@ -130,6 +149,10 @@ def legacy_routing(
Replacement for the removed triton_kernels.routing.routing function.
Computes routing data from gating logits.
"""
if use_legacy_triton_kernels:
from triton_kernels.routing import routing
return routing(logits, n_expts_act, sm_first=sm_first)
if sm_first:
logits = torch.softmax(logits, dim=-1)
sparse_logits = topk(logits, n_expts_act, apply_softmax=not sm_first)
@@ -231,11 +254,22 @@ def triton_kernel_fused_experts(
)
output_tensor = _resize_cache(output_tensor, (batch_dim, M, K))
act = FusedActivation(
FnSpecs(
"swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit"), reduction_n=2
),
(swiglu_alpha, swiglu_limit),
act = (
FusedActivation(
FnSpecs(
"swiglu",
triton_kernels.swiglu.swiglu_fn,
("alpha", "limit"),
reduction_n=2,
),
(swiglu_alpha, swiglu_limit),
)
if not use_legacy_triton_kernels
else FusedActivation(
FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
(swiglu_alpha, swiglu_limit),
2,
)
)
gammas = routing_data.gate_scal if routing_data else None
@@ -296,8 +330,17 @@ def make_routing_data(
bitmatrix_shape = [n_rows, bm_cols * 32]
bitmatrix_shape_max = [n_rows, None]
bitmatrix = Bitmatrix(
bitmatrix, dtype=BIT, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max
bitmatrix = (
Bitmatrix(
bitmatrix, dtype=BIT, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max
)
if not use_legacy_triton_kernels
else Bitmatrix(
bitmatrix,
shape=bitmatrix_shape,
shape_max=bitmatrix_shape_max,
scratchpad=None,
)
)
# matmul_ogs expects invalid topk_weights to be -1s

View File

@@ -52,6 +52,7 @@ class Fp8MoeBackend(Enum):
AITER = "AITER"
VLLM_CUTLASS = "VLLM_CUTLASS"
BATCHED_VLLM_CUTLASS = "BATCHED_VLLM_CUTLASS"
XPU = "XPU"
def backend_to_kernel_cls(
@@ -123,6 +124,13 @@ def backend_to_kernel_cls(
return CutlassBatchedExpertsFp8
elif backend == Fp8MoeBackend.XPU:
from vllm.model_executor.layers.fused_moe.xpu_fused_moe import (
XPUExpertsFp8,
)
return XPUExpertsFp8
else:
raise ValueError(f"Unknown FP8 MoE backend: {backend.value}")
@@ -154,6 +162,7 @@ def select_fp8_moe_backend(
Fp8MoeBackend.TRITON,
Fp8MoeBackend.BATCHED_TRITON,
Fp8MoeBackend.MARLIN,
Fp8MoeBackend.XPU,
]
# NOTE(rob): We need to peak into the P/F selection to determine
@@ -393,6 +402,7 @@ def convert_to_fp8_moe_kernel_format(
Fp8MoeBackend.BATCHED_TRITON,
Fp8MoeBackend.VLLM_CUTLASS,
Fp8MoeBackend.BATCHED_VLLM_CUTLASS,
Fp8MoeBackend.XPU,
]:
raise ValueError(f"Unsupported FP8 MoE backend: {fp8_backend.value}")

View File

@@ -4,13 +4,16 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8DynamicTensorSym,
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
@@ -20,6 +23,21 @@ if current_platform.is_xpu():
class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
max_num_tokens: int | None = None,
num_dispatchers: int | None = None,
):
super().__init__(
moe_config,
quant_config,
max_num_tokens,
num_dispatchers,
)
self.is_fp8 = False
@property
def expects_unquantized_inputs(self) -> bool:
return True
@@ -49,10 +67,10 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
# TODO: dispatch based on device.
SUPPORTED_W_A = [
(None, None),
(kFp8StaticTensorSym, None),
(kFp8StaticTensorSym, kFp8DynamicTensorSym),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@@ -103,10 +121,10 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
xpu_fused_moe(
hidden_states=hidden_states,
w13=w1,
w13_scales=a1q_scale,
w13_scales=self.w1_scale,
w13_bias=self.w1_bias,
w2=w2,
w2_scales=a2_scale,
w2_scales=self.w2_scale,
w2_bias=self.w2_bias,
topk_weights=topk_weights,
topk_ids=topk_ids,
@@ -116,5 +134,22 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
ep_rank=self.moe_config.ep_rank,
ep_size=self.moe_config.ep_size,
output=output,
is_fp8=self.is_fp8,
)
return
class XPUExpertsFp8(XPUExperts):
def __init__(
self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
max_num_tokens: int | None = None,
num_dispatchers: int | None = None,
):
super().__init__(
moe_config,
quant_config,
max_num_tokens,
num_dispatchers,
)
self.is_fp8 = True

View File

@@ -180,18 +180,9 @@ class Fp8Config(QuantizationConfig):
weight_block_size=weight_block_size,
)
def get_xpu_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> "QuantizeMethodBase | None":
raise NotImplementedError(
"FP8 quantization is not supported during xpu kernel migration."
)
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> "QuantizeMethodBase | None":
if current_platform.is_xpu():
return self.get_xpu_quant_method(layer, prefix)
if isinstance(layer, LinearBase):
if is_layer_skipped(
prefix=prefix,
@@ -300,7 +291,7 @@ class Fp8LinearMethod(LinearMethodBase):
or envs.VLLM_TEST_FORCE_FP8_MARLIN
)
# Disable marlin for rocm
if current_platform.is_rocm():
if current_platform.is_rocm() or current_platform.is_xpu():
self.use_marlin = False
if vllm_is_batch_invariant():
self.use_marlin = False

View File

@@ -39,6 +39,9 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
TritonInt8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.xpu import (
XPUFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.platforms import PlatformEnum, current_platform
@@ -72,6 +75,9 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
PerTensorTorchFP8ScaledMMLinearKernel,
ChannelWiseTorchFP8ScaledMMLinearKernel,
],
PlatformEnum.XPU: [
XPUFP8ScaledMMLinearKernel,
],
}
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel)

View File

@@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
import torch
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
)
from vllm.platforms import current_platform
class XPUFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not current_platform.is_xpu():
return False, "XPUFP8ScaledMM only support on XPU"
return True, None
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if c.weight_quant_key.dtype not in {torch.float8_e5m2, torch.float8_e4m3fn}:
return False, "XPUFP8ScaledMM only support FP8 weight dtype"
return True, None
def __init__(
self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
) -> None:
assert self.can_implement(c)[0]
assert self.is_supported()[0]
self.config = c
self.layer_param_names = layer_param_names
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
weight = layer.weight
weight_scale = layer.weight_scale
return torch.ops._xpu_C.fp8_gemm_w8a16(x, weight, weight_scale, bias)
def apply_scaled_mm(
self,
*,
A: torch.Tensor,
B: torch.Tensor,
out_dtype: torch.dtype,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None,
output_shape: list,
) -> torch.Tensor:
pass

View File

@@ -160,7 +160,7 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
logger.info_once("Using Triton backend")
return Mxfp4Backend.TRITON
elif current_platform.is_xpu():
logger.info_once("Using ipex marlin backend on XPU")
logger.info_once("Using xpu backend on XPU")
return Mxfp4Backend.MARLIN
elif current_platform.is_rocm() and has_triton_kernels():
logger.info_once("Using Triton backend")

View File

@@ -20,7 +20,7 @@ from vllm.v1.worker.workspace import current_workspace_manager
if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops
from vllm._xpu_ops import xpu_ops as ops
logger = init_logger(__name__)

View File

@@ -338,7 +338,6 @@ class CpuPlatform(Platform):
ld_preload_str += pytorch_libgomp_so
os.environ["LD_PRELOAD"] = ld_preload_str
# To hint IPEX uses shared memory based AllReduce
os.environ["LOCAL_WORLD_SIZE"] = str(
vllm_config.parallel_config.tensor_parallel_size
)

View File

@@ -23,12 +23,11 @@ if current_platform.is_cuda():
elif current_platform.is_xpu():
from vllm import _custom_ops as ops
from vllm._xpu_ops import xpu_ops
reshape_and_cache_flash = ops.reshape_and_cache_flash
from vllm._ipex_ops import ipex_ops
flash_attn_varlen_func = ipex_ops.flash_attn_varlen_func # type: ignore[assignment]
get_scheduler_metadata = ipex_ops.get_scheduler_metadata # type: ignore[assignment]
flash_attn_varlen_func = xpu_ops.flash_attn_varlen_func # type: ignore[assignment]
get_scheduler_metadata = xpu_ops.get_scheduler_metadata # type: ignore[assignment]
elif current_platform.is_rocm():
try:
from flash_attn import flash_attn_varlen_func # type: ignore[no-redef]
@@ -153,7 +152,7 @@ def is_flash_attn_varlen_func_available() -> bool:
Platform-specific sources:
- CUDA: vllm.vllm_flash_attn.flash_attn_varlen_func
- XPU: ipex_ops.flash_attn_varlen_func
- XPU: xpu_ops.flash_attn_varlen_func
- ROCm: upstream flash_attn.flash_attn_varlen_func (if available)
Note: This is separate from the AITER flash attention backend (rocm_aiter_fa.py)

View File

@@ -9,7 +9,7 @@ from vllm.platforms import current_platform
if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops # type: ignore[no-redef]
from vllm._xpu_ops import xpu_ops as ops # type: ignore[no-redef]
class PagedAttention:

View File

@@ -982,10 +982,8 @@ class Scheduler(SchedulerInterface):
session._all_token_ids.extend(update.prompt_token_ids or ())
session.prompt_token_ids.extend(update.prompt_token_ids or ())
# Update block hashes for the new tokens
# (mirrors Request.append_output_token_ids)
if session.get_hash_new_full_blocks is not None:
session.block_hashes.extend(session.get_hash_new_full_blocks())
# Update block hashes for the new tokens.
session.update_block_hashes()
session.num_prompt_tokens = len(session.prompt_token_ids)
session.arrival_time = update.arrival_time
session.sampling_params = update.sampling_params

View File

@@ -6,7 +6,6 @@ import time
from collections import deque
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Any
import torch
@@ -164,10 +163,11 @@ class Request:
self.num_external_computed_tokens = 0
self.block_hashes: list[BlockHash] = []
self.get_hash_new_full_blocks: Callable[[], list[BlockHash]] | None = None
if block_hasher is not None:
self.get_hash_new_full_blocks = partial(block_hasher, self)
self.block_hashes = self.get_hash_new_full_blocks()
# Store the block hasher without binding self to avoid creating a
# reference cycle (Request -> partial -> Request) that prevents
# immediate garbage collection via reference counting.
self._block_hasher: Callable[[Request], list[BlockHash]] | None = block_hasher
self.update_block_hashes()
self.skip_reading_prefix_cache = self.get_skip_reading_prefix_cache()
@@ -212,8 +212,12 @@ class Request:
self._output_token_ids.extend(token_ids)
self._all_token_ids.extend(token_ids)
if self.get_hash_new_full_blocks is not None:
self.block_hashes.extend(self.get_hash_new_full_blocks())
self.update_block_hashes()
def update_block_hashes(self) -> None:
"""Compute block hashes for any new full blocks and append them."""
if self._block_hasher is not None:
self.block_hashes.extend(self._block_hasher(self))
@property
def use_structured_output(self) -> bool: