Compare commits
7 Commits
v0.17.2rc0
...
v0.16.0rc2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c44d0c6d66 | ||
|
|
83db96d8cd | ||
|
|
dbfb79fe45 | ||
|
|
b2e1fc3589 | ||
|
|
55a1baebc5 | ||
|
|
e1e9841631 | ||
|
|
5bd63387c3 |
@@ -39,6 +39,7 @@ docker run \
|
|||||||
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
|
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
|
||||||
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
|
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
|
||||||
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --attention-backend=TRITON_ATTN
|
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --attention-backend=TRITON_ATTN
|
||||||
|
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --quantization fp8
|
||||||
python3 examples/offline_inference/basic/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --block-size 64 --enforce-eager
|
python3 examples/offline_inference/basic/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --block-size 64 --enforce-eager
|
||||||
python3 examples/offline_inference/basic/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2
|
python3 examples/offline_inference/basic/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2
|
||||||
python3 examples/offline_inference/basic/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 --enable-expert-parallel
|
python3 examples/offline_inference/basic/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 --enable-expert-parallel
|
||||||
|
|||||||
@@ -134,7 +134,6 @@ WORKDIR /vllm-workspace
|
|||||||
# Copy test requirements
|
# Copy test requirements
|
||||||
COPY requirements/test.in requirements/cpu-test.in
|
COPY requirements/test.in requirements/cpu-test.in
|
||||||
|
|
||||||
# TODO: Update to 2.9.0 when there is a new build for intel_extension_for_pytorch for that version
|
|
||||||
RUN \
|
RUN \
|
||||||
sed -i '/mamba_ssm/d' requirements/cpu-test.in && \
|
sed -i '/mamba_ssm/d' requirements/cpu-test.in && \
|
||||||
remove_packages_not_supported_on_aarch64() { \
|
remove_packages_not_supported_on_aarch64() { \
|
||||||
|
|||||||
@@ -6,10 +6,11 @@ vLLM initially supports basic model inference and serving on Intel GPU platform.
|
|||||||
# --8<-- [start:requirements]
|
# --8<-- [start:requirements]
|
||||||
|
|
||||||
- Supported Hardware: Intel Data Center GPU, Intel ARC GPU
|
- Supported Hardware: Intel Data Center GPU, Intel ARC GPU
|
||||||
- OneAPI requirements: oneAPI 2025.1
|
- OneAPI requirements: oneAPI 2025.3
|
||||||
|
- Dependency: [vllm-xpu-kernels](https://github.com/vllm-project/vllm-xpu-kernels): a package that provides all the necessary vLLM custom kernels when running vLLM on Intel GPU platforms.
|
||||||
- Python: 3.12
|
- Python: 3.12
|
||||||
!!! warning
|
!!! warning
|
||||||
The provided IPEX whl is Python3.12 specific so this version is a MUST.
|
The provided vllm-xpu-kernels whl is Python3.12 specific so this version is a MUST.
|
||||||
|
|
||||||
# --8<-- [end:requirements]
|
# --8<-- [end:requirements]
|
||||||
# --8<-- [start:set-up-using-python]
|
# --8<-- [start:set-up-using-python]
|
||||||
@@ -24,7 +25,7 @@ Currently, there are no pre-built XPU wheels.
|
|||||||
# --8<-- [end:pre-built-wheels]
|
# --8<-- [end:pre-built-wheels]
|
||||||
# --8<-- [start:build-wheel-from-source]
|
# --8<-- [start:build-wheel-from-source]
|
||||||
|
|
||||||
- First, install required [driver](https://dgpu-docs.intel.com/driver/installation.html#installing-gpu-drivers) and [Intel OneAPI](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html) 2025.1 or later.
|
- First, install required [driver](https://dgpu-docs.intel.com/driver/installation.html#installing-gpu-drivers) and [Intel OneAPI](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html) 2025.3 or later.
|
||||||
- Second, install Python packages for vLLM XPU backend building:
|
- Second, install Python packages for vLLM XPU backend building:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -37,7 +38,7 @@ pip install -v -r requirements/xpu.txt
|
|||||||
- Then, build and install vLLM XPU backend:
|
- Then, build and install vLLM XPU backend:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
VLLM_TARGET_DEVICE=xpu python setup.py install
|
VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -e . -v
|
||||||
```
|
```
|
||||||
|
|
||||||
# --8<-- [end:build-wheel-from-source]
|
# --8<-- [end:build-wheel-from-source]
|
||||||
|
|||||||
@@ -9,5 +9,5 @@ wheel
|
|||||||
jinja2>=3.1.6
|
jinja2>=3.1.6
|
||||||
regex
|
regex
|
||||||
build
|
build
|
||||||
protobuf
|
protobuf >= 5.29.6, !=6.30.*, !=6.31.*, !=6.32.*, !=6.33.0.*, !=6.33.1.*, !=6.33.2.*, !=6.33.3.*, !=6.33.4.*
|
||||||
grpcio-tools==1.78.0 # Required for grpc entrypoints
|
grpcio-tools==1.78.0 # Required for grpc entrypoints
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ blake3
|
|||||||
py-cpuinfo
|
py-cpuinfo
|
||||||
transformers >= 4.56.0, < 5
|
transformers >= 4.56.0, < 5
|
||||||
tokenizers >= 0.21.1 # Required for fast incremental detokenization.
|
tokenizers >= 0.21.1 # Required for fast incremental detokenization.
|
||||||
protobuf # Required by LlamaTokenizer, gRPC.
|
protobuf >= 5.29.6, !=6.30.*, !=6.31.*, !=6.32.*, !=6.33.0.*, !=6.33.1.*, !=6.33.2.*, !=6.33.3.*, !=6.33.4.* # Required by LlamaTokenizer, gRPC. CVE-2026-0994
|
||||||
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
|
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
|
||||||
aiohttp >= 3.13.3
|
aiohttp >= 3.13.3
|
||||||
openai >= 1.99.1 # For Responses API with reasoning content
|
openai >= 1.99.1 # For Responses API with reasoning content
|
||||||
|
|||||||
@@ -15,4 +15,4 @@ torch==2.10.0+xpu
|
|||||||
torchaudio
|
torchaudio
|
||||||
torchvision
|
torchvision
|
||||||
|
|
||||||
vllm_xpu_kernels @ https://github.com/vllm-project/vllm-xpu-kernels/releases/download/v0.1.1/vllm_xpu_kernels-0.1.1-cp312-cp312-linux_x86_64.whl
|
vllm_xpu_kernels @ https://github.com/vllm-project/vllm-xpu-kernels/releases/download/v0.1.2/vllm_xpu_kernels-0.1.2-cp312-cp312-linux_x86_64.whl
|
||||||
|
|||||||
@@ -267,7 +267,7 @@ elif current_platform.is_rocm():
|
|||||||
PATTERN_TEST_MODELS_FP8 = [
|
PATTERN_TEST_MODELS_FP8 = [
|
||||||
("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
|
("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
|
||||||
]
|
]
|
||||||
BACKENDS = [
|
BACKENDS_FP8 = [
|
||||||
AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
|
AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
|
||||||
AttentionBackendEnum.ROCM_ATTN,
|
AttentionBackendEnum.ROCM_ATTN,
|
||||||
AttentionBackendEnum.TRITON_ATTN,
|
AttentionBackendEnum.TRITON_ATTN,
|
||||||
@@ -474,6 +474,17 @@ def test_attention_quant_pattern(
|
|||||||
assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, (
|
assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, (
|
||||||
"Attention should not have output_block_scale before fusion"
|
"Attention should not have output_block_scale before fusion"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
kv_cache_dummy_dep_pre_is_none = (
|
||||||
|
attn_nodes_pre[0].kwargs.get("kv_cache_dummy_dep") is None
|
||||||
|
)
|
||||||
|
kv_cache_dummy_dep_post_is_none = (
|
||||||
|
attn_nodes_post[0].kwargs.get("kv_cache_dummy_dep") is None
|
||||||
|
)
|
||||||
|
assert not (kv_cache_dummy_dep_pre_is_none ^ kv_cache_dummy_dep_post_is_none), (
|
||||||
|
"The kv_cache_dummy_dep should be consistent before and after fusion"
|
||||||
|
)
|
||||||
|
|
||||||
if quant_key.dtype == FP8_DTYPE:
|
if quant_key.dtype == FP8_DTYPE:
|
||||||
assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, (
|
assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, (
|
||||||
"Attention should not have output_block_scale after FP8 fusion"
|
"Attention should not have output_block_scale after FP8 fusion"
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ DTYPE = ["bfloat16"]
|
|||||||
|
|
||||||
@pytest.mark.parametrize("model", MODELS)
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
@pytest.mark.parametrize("dtype", DTYPE)
|
@pytest.mark.parametrize("dtype", DTYPE)
|
||||||
def test_ipex_quant(vllm_runner, model, dtype):
|
def test_cpu_quant(vllm_runner, model, dtype):
|
||||||
with vllm_runner(model, dtype=dtype) as llm:
|
with vllm_runner(model, dtype=dtype) as llm:
|
||||||
output = llm.generate_greedy(["The capital of France is"], max_tokens=32)
|
output = llm.generate_greedy(["The capital of France is"], max_tokens=32)
|
||||||
assert output
|
assert output
|
||||||
|
|||||||
@@ -1,32 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
"""Test model set-up and inference for quantized HF models supported
|
|
||||||
on the CPU/GPU backend using IPEX (including AWQ/GPTQ).
|
|
||||||
|
|
||||||
Validating the configuration and printing results for manual checking.
|
|
||||||
|
|
||||||
Run `pytest tests/quantization/test_ipex_quant.py`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from vllm.platforms import current_platform
|
|
||||||
|
|
||||||
MODELS = [
|
|
||||||
"AMead10/Llama-3.2-1B-Instruct-AWQ",
|
|
||||||
"shuyuej/Llama-3.2-1B-Instruct-GPTQ", # with g_idx
|
|
||||||
]
|
|
||||||
DTYPE = ["bfloat16"]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
not current_platform.is_cpu() and not current_platform.is_xpu(),
|
|
||||||
reason="only supports Intel CPU/XPU backend.",
|
|
||||||
)
|
|
||||||
@pytest.mark.parametrize("model", MODELS)
|
|
||||||
@pytest.mark.parametrize("dtype", DTYPE)
|
|
||||||
def test_ipex_quant(vllm_runner, model, dtype):
|
|
||||||
with vllm_runner(model, dtype=dtype, enforce_eager=True) as llm:
|
|
||||||
output = llm.generate_greedy(["The capital of France is"], max_tokens=4)
|
|
||||||
assert output
|
|
||||||
print(output)
|
|
||||||
@@ -236,7 +236,7 @@ def test_prefix_caching_for_multi_turn():
|
|||||||
req._all_token_ids = req.prompt_token_ids.copy()
|
req._all_token_ids = req.prompt_token_ids.copy()
|
||||||
req.all_token_ids = ConstantList(req._all_token_ids)
|
req.all_token_ids = ConstantList(req._all_token_ids)
|
||||||
req.block_hashes = []
|
req.block_hashes = []
|
||||||
req.block_hashes = req.get_hash_new_full_blocks()
|
req.update_block_hashes()
|
||||||
|
|
||||||
# Schedule the next-turn requests.
|
# Schedule the next-turn requests.
|
||||||
for req in next_turn_requests:
|
for req in next_turn_requests:
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ if hasattr(torch.ops._xpu_C, "int4_gemm_w4a16"):
|
|||||||
return torch.empty((M, N), dtype=input.dtype, device=input.device)
|
return torch.empty((M, N), dtype=input.dtype, device=input.device)
|
||||||
|
|
||||||
|
|
||||||
class ipex_ops:
|
class xpu_ops:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def flash_attn_varlen_func(
|
def flash_attn_varlen_func(
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
@@ -73,7 +73,7 @@ class ipex_ops:
|
|||||||
cu_seqlens_k: torch.Tensor | None = None,
|
cu_seqlens_k: torch.Tensor | None = None,
|
||||||
# passed in qwen vl
|
# passed in qwen vl
|
||||||
dropout_p: float = 0.0,
|
dropout_p: float = 0.0,
|
||||||
# The following parameters are not used in ipex kernel currently,
|
# The following parameters are not used in xpu kernel currently,
|
||||||
# we keep API compatible to CUDA's.
|
# we keep API compatible to CUDA's.
|
||||||
scheduler_metadata=None,
|
scheduler_metadata=None,
|
||||||
fa_version: int = 2,
|
fa_version: int = 2,
|
||||||
@@ -153,6 +153,6 @@ class ipex_ops:
|
|||||||
sm_margin=0, # Can be tuned if some SMs are used for communication
|
sm_margin=0, # Can be tuned if some SMs are used for communication
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
"get_scheduler_metadata is not implemented for ipex_ops, returning None."
|
"get_scheduler_metadata is not implemented for xpu_ops, returning None."
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
@@ -142,6 +142,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
|||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
output_attn: torch.Tensor,
|
output_attn: torch.Tensor,
|
||||||
scale: torch.Tensor,
|
scale: torch.Tensor,
|
||||||
|
kv_cache_dummy_dep: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
at1 = auto_functionalized(
|
at1 = auto_functionalized(
|
||||||
ATTN_OP,
|
ATTN_OP,
|
||||||
@@ -152,6 +153,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
|||||||
layer_name=self.layer_name,
|
layer_name=self.layer_name,
|
||||||
output_scale=None,
|
output_scale=None,
|
||||||
output_block_scale=None,
|
output_block_scale=None,
|
||||||
|
kv_cache_dummy_dep=kv_cache_dummy_dep,
|
||||||
)
|
)
|
||||||
attn_out_view = RESHAPE_OP(
|
attn_out_view = RESHAPE_OP(
|
||||||
at1[1], [q.shape[0], self.num_heads * self.head_size]
|
at1[1], [q.shape[0], self.num_heads * self.head_size]
|
||||||
@@ -165,6 +167,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
|||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
output_attn: torch.Tensor,
|
output_attn: torch.Tensor,
|
||||||
scale: torch.Tensor,
|
scale: torch.Tensor,
|
||||||
|
kv_cache_dummy_dep: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# attn output in quant_dtype
|
# attn output in quant_dtype
|
||||||
output_attn = torch.ops.aten.full.default(
|
output_attn = torch.ops.aten.full.default(
|
||||||
@@ -182,6 +185,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
|||||||
layer_name=self.layer_name,
|
layer_name=self.layer_name,
|
||||||
output_scale=scale,
|
output_scale=scale,
|
||||||
output_block_scale=None,
|
output_block_scale=None,
|
||||||
|
kv_cache_dummy_dep=kv_cache_dummy_dep,
|
||||||
)
|
)
|
||||||
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
|
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
|
||||||
|
|
||||||
@@ -191,6 +195,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
|||||||
self.empty(5, self.num_heads, self.head_size), # v
|
self.empty(5, self.num_heads, self.head_size), # v
|
||||||
self.empty(5, self.num_heads, self.head_size), # attn_output
|
self.empty(5, self.num_heads, self.head_size), # attn_output
|
||||||
empty_fp32(1, 1), # scale
|
empty_fp32(1, 1), # scale
|
||||||
|
self.empty(0), # kv_cache_dummy_dep
|
||||||
]
|
]
|
||||||
|
|
||||||
pm.register_replacement(
|
pm.register_replacement(
|
||||||
@@ -228,6 +233,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
|||||||
output_quant: torch.Tensor,
|
output_quant: torch.Tensor,
|
||||||
output_scale: torch.Tensor,
|
output_scale: torch.Tensor,
|
||||||
input_scale: torch.Tensor,
|
input_scale: torch.Tensor,
|
||||||
|
kv_cache_dummy_dep: torch.Tensor,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
at1 = auto_functionalized(
|
at1 = auto_functionalized(
|
||||||
ATTN_OP,
|
ATTN_OP,
|
||||||
@@ -238,6 +244,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
|||||||
layer_name=self.layer_name,
|
layer_name=self.layer_name,
|
||||||
output_scale=None,
|
output_scale=None,
|
||||||
output_block_scale=None,
|
output_block_scale=None,
|
||||||
|
kv_cache_dummy_dep=kv_cache_dummy_dep,
|
||||||
)
|
)
|
||||||
attn_out_view = RESHAPE_OP(
|
attn_out_view = RESHAPE_OP(
|
||||||
at1[1], [q.shape[0], self.num_heads * self.head_size]
|
at1[1], [q.shape[0], self.num_heads * self.head_size]
|
||||||
@@ -261,6 +268,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
|||||||
output_quant: torch.Tensor,
|
output_quant: torch.Tensor,
|
||||||
output_scale: torch.Tensor,
|
output_scale: torch.Tensor,
|
||||||
input_scale: torch.Tensor,
|
input_scale: torch.Tensor,
|
||||||
|
kv_cache_dummy_dep: torch.Tensor,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
# attention output in quant_dtype
|
# attention output in quant_dtype
|
||||||
output_attn = torch.ops.aten.full.default(
|
output_attn = torch.ops.aten.full.default(
|
||||||
@@ -280,6 +288,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
|||||||
layer_name=self.layer_name,
|
layer_name=self.layer_name,
|
||||||
output_scale=input_scale,
|
output_scale=input_scale,
|
||||||
output_block_scale=output_scale_view,
|
output_block_scale=output_scale_view,
|
||||||
|
kv_cache_dummy_dep=kv_cache_dummy_dep,
|
||||||
)
|
)
|
||||||
output = RESHAPE_OP(at2[1], [-1, self.num_heads * self.head_size // 2])
|
output = RESHAPE_OP(at2[1], [-1, self.num_heads * self.head_size // 2])
|
||||||
return output, at2[2]
|
return output, at2[2]
|
||||||
@@ -294,6 +303,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
|||||||
128, round_up(self.num_heads * self.head_size // 16, 4)
|
128, round_up(self.num_heads * self.head_size // 16, 4)
|
||||||
), # output_scale
|
), # output_scale
|
||||||
empty_fp32(1, 1), # input_scale
|
empty_fp32(1, 1), # input_scale
|
||||||
|
self.empty(0), # kv_cache_dummy_dep
|
||||||
]
|
]
|
||||||
|
|
||||||
pm.register_replacement(
|
pm.register_replacement(
|
||||||
|
|||||||
@@ -102,6 +102,7 @@ if HAS_TRITON:
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.xpu_fused_moe import (
|
from vllm.model_executor.layers.fused_moe.xpu_fused_moe import (
|
||||||
XPUExperts,
|
XPUExperts,
|
||||||
|
XPUExpertsFp8,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ += [
|
__all__ += [
|
||||||
@@ -121,6 +122,7 @@ if HAS_TRITON:
|
|||||||
"BatchedDeepGemmExperts",
|
"BatchedDeepGemmExperts",
|
||||||
"TritonOrDeepGemmExperts",
|
"TritonOrDeepGemmExperts",
|
||||||
"XPUExperts",
|
"XPUExperts",
|
||||||
|
"XPUExpertsFp8",
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
# Some model classes directly use the custom ops. Add placeholders
|
# Some model classes directly use the custom ops. Add placeholders
|
||||||
|
|||||||
@@ -19,11 +19,14 @@ from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
|||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
QuantKey,
|
QuantKey,
|
||||||
)
|
)
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
from vllm.utils.import_utils import has_triton_kernels
|
from vllm.utils.import_utils import has_triton_kernels
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
use_legacy_triton_kernels = False
|
||||||
|
|
||||||
if has_triton_kernels():
|
if has_triton_kernels():
|
||||||
try:
|
try:
|
||||||
import triton_kernels.swiglu
|
import triton_kernels.swiglu
|
||||||
@@ -38,10 +41,20 @@ if has_triton_kernels():
|
|||||||
from triton_kernels.tensor import (
|
from triton_kernels.tensor import (
|
||||||
BIT,
|
BIT,
|
||||||
Bitmatrix,
|
Bitmatrix,
|
||||||
SparseMatrix,
|
|
||||||
make_ragged_tensor_metadata,
|
|
||||||
)
|
)
|
||||||
from triton_kernels.topk import topk
|
from triton_kernels.topk import topk
|
||||||
|
|
||||||
|
try:
|
||||||
|
from triton_kernels.tensor import (
|
||||||
|
SparseMatrix,
|
||||||
|
make_ragged_tensor_metadata,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
logger.warning_once("Using legacy triton_kernels on ROCm")
|
||||||
|
use_legacy_triton_kernels = True
|
||||||
|
else:
|
||||||
|
raise
|
||||||
except (AttributeError, ImportError) as e:
|
except (AttributeError, ImportError) as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Failed to import Triton kernels. Please make sure your triton "
|
"Failed to import Triton kernels. Please make sure your triton "
|
||||||
@@ -101,6 +114,12 @@ def legacy_routing_from_bitmatrix(
|
|||||||
Replacement for the removed triton_kernels.routing.routing_from_bitmatrix.
|
Replacement for the removed triton_kernels.routing.routing_from_bitmatrix.
|
||||||
Creates routing data from a bitmatrix representation.
|
Creates routing data from a bitmatrix representation.
|
||||||
"""
|
"""
|
||||||
|
if use_legacy_triton_kernels:
|
||||||
|
from triton_kernels.routing import routing_from_bitmatrix
|
||||||
|
|
||||||
|
return routing_from_bitmatrix(
|
||||||
|
bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act
|
||||||
|
)
|
||||||
sparse_logits = SparseMatrix(indx=expt_indx, vals=expt_scal, mask=bitmatrix)
|
sparse_logits = SparseMatrix(indx=expt_indx, vals=expt_scal, mask=bitmatrix)
|
||||||
dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx
|
dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx
|
||||||
combine_indx = sparse_logits.mask_metadata.col_sorted_indx
|
combine_indx = sparse_logits.mask_metadata.col_sorted_indx
|
||||||
@@ -130,6 +149,10 @@ def legacy_routing(
|
|||||||
Replacement for the removed triton_kernels.routing.routing function.
|
Replacement for the removed triton_kernels.routing.routing function.
|
||||||
Computes routing data from gating logits.
|
Computes routing data from gating logits.
|
||||||
"""
|
"""
|
||||||
|
if use_legacy_triton_kernels:
|
||||||
|
from triton_kernels.routing import routing
|
||||||
|
|
||||||
|
return routing(logits, n_expts_act, sm_first=sm_first)
|
||||||
if sm_first:
|
if sm_first:
|
||||||
logits = torch.softmax(logits, dim=-1)
|
logits = torch.softmax(logits, dim=-1)
|
||||||
sparse_logits = topk(logits, n_expts_act, apply_softmax=not sm_first)
|
sparse_logits = topk(logits, n_expts_act, apply_softmax=not sm_first)
|
||||||
@@ -231,11 +254,22 @@ def triton_kernel_fused_experts(
|
|||||||
)
|
)
|
||||||
output_tensor = _resize_cache(output_tensor, (batch_dim, M, K))
|
output_tensor = _resize_cache(output_tensor, (batch_dim, M, K))
|
||||||
|
|
||||||
act = FusedActivation(
|
act = (
|
||||||
FnSpecs(
|
FusedActivation(
|
||||||
"swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit"), reduction_n=2
|
FnSpecs(
|
||||||
),
|
"swiglu",
|
||||||
(swiglu_alpha, swiglu_limit),
|
triton_kernels.swiglu.swiglu_fn,
|
||||||
|
("alpha", "limit"),
|
||||||
|
reduction_n=2,
|
||||||
|
),
|
||||||
|
(swiglu_alpha, swiglu_limit),
|
||||||
|
)
|
||||||
|
if not use_legacy_triton_kernels
|
||||||
|
else FusedActivation(
|
||||||
|
FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
|
||||||
|
(swiglu_alpha, swiglu_limit),
|
||||||
|
2,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
gammas = routing_data.gate_scal if routing_data else None
|
gammas = routing_data.gate_scal if routing_data else None
|
||||||
|
|
||||||
@@ -296,8 +330,17 @@ def make_routing_data(
|
|||||||
|
|
||||||
bitmatrix_shape = [n_rows, bm_cols * 32]
|
bitmatrix_shape = [n_rows, bm_cols * 32]
|
||||||
bitmatrix_shape_max = [n_rows, None]
|
bitmatrix_shape_max = [n_rows, None]
|
||||||
bitmatrix = Bitmatrix(
|
bitmatrix = (
|
||||||
bitmatrix, dtype=BIT, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max
|
Bitmatrix(
|
||||||
|
bitmatrix, dtype=BIT, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max
|
||||||
|
)
|
||||||
|
if not use_legacy_triton_kernels
|
||||||
|
else Bitmatrix(
|
||||||
|
bitmatrix,
|
||||||
|
shape=bitmatrix_shape,
|
||||||
|
shape_max=bitmatrix_shape_max,
|
||||||
|
scratchpad=None,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# matmul_ogs expects invalid topk_weights to be -1s
|
# matmul_ogs expects invalid topk_weights to be -1s
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ class Fp8MoeBackend(Enum):
|
|||||||
AITER = "AITER"
|
AITER = "AITER"
|
||||||
VLLM_CUTLASS = "VLLM_CUTLASS"
|
VLLM_CUTLASS = "VLLM_CUTLASS"
|
||||||
BATCHED_VLLM_CUTLASS = "BATCHED_VLLM_CUTLASS"
|
BATCHED_VLLM_CUTLASS = "BATCHED_VLLM_CUTLASS"
|
||||||
|
XPU = "XPU"
|
||||||
|
|
||||||
|
|
||||||
def backend_to_kernel_cls(
|
def backend_to_kernel_cls(
|
||||||
@@ -123,6 +124,13 @@ def backend_to_kernel_cls(
|
|||||||
|
|
||||||
return CutlassBatchedExpertsFp8
|
return CutlassBatchedExpertsFp8
|
||||||
|
|
||||||
|
elif backend == Fp8MoeBackend.XPU:
|
||||||
|
from vllm.model_executor.layers.fused_moe.xpu_fused_moe import (
|
||||||
|
XPUExpertsFp8,
|
||||||
|
)
|
||||||
|
|
||||||
|
return XPUExpertsFp8
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown FP8 MoE backend: {backend.value}")
|
raise ValueError(f"Unknown FP8 MoE backend: {backend.value}")
|
||||||
|
|
||||||
@@ -154,6 +162,7 @@ def select_fp8_moe_backend(
|
|||||||
Fp8MoeBackend.TRITON,
|
Fp8MoeBackend.TRITON,
|
||||||
Fp8MoeBackend.BATCHED_TRITON,
|
Fp8MoeBackend.BATCHED_TRITON,
|
||||||
Fp8MoeBackend.MARLIN,
|
Fp8MoeBackend.MARLIN,
|
||||||
|
Fp8MoeBackend.XPU,
|
||||||
]
|
]
|
||||||
|
|
||||||
# NOTE(rob): We need to peek into the P/F selection to determine
|
# NOTE(rob): We need to peek into the P/F selection to determine
|
||||||
@@ -393,6 +402,7 @@ def convert_to_fp8_moe_kernel_format(
|
|||||||
Fp8MoeBackend.BATCHED_TRITON,
|
Fp8MoeBackend.BATCHED_TRITON,
|
||||||
Fp8MoeBackend.VLLM_CUTLASS,
|
Fp8MoeBackend.VLLM_CUTLASS,
|
||||||
Fp8MoeBackend.BATCHED_VLLM_CUTLASS,
|
Fp8MoeBackend.BATCHED_VLLM_CUTLASS,
|
||||||
|
Fp8MoeBackend.XPU,
|
||||||
]:
|
]:
|
||||||
raise ValueError(f"Unsupported FP8 MoE backend: {fp8_backend.value}")
|
raise ValueError(f"Unsupported FP8 MoE backend: {fp8_backend.value}")
|
||||||
|
|
||||||
|
|||||||
@@ -4,13 +4,16 @@ import torch
|
|||||||
|
|
||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm.model_executor.layers.fused_moe.config import (
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
FusedMoEConfig,
|
||||||
FusedMoEParallelConfig,
|
FusedMoEParallelConfig,
|
||||||
|
FusedMoEQuantConfig,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||||
TopKWeightAndReduceNoOP,
|
TopKWeightAndReduceNoOP,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
QuantKey,
|
QuantKey,
|
||||||
|
kFp8DynamicTensorSym,
|
||||||
kFp8StaticTensorSym,
|
kFp8StaticTensorSym,
|
||||||
)
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@@ -20,6 +23,21 @@ if current_platform.is_xpu():
|
|||||||
|
|
||||||
|
|
||||||
class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
moe_config: FusedMoEConfig,
|
||||||
|
quant_config: FusedMoEQuantConfig,
|
||||||
|
max_num_tokens: int | None = None,
|
||||||
|
num_dispatchers: int | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
moe_config,
|
||||||
|
quant_config,
|
||||||
|
max_num_tokens,
|
||||||
|
num_dispatchers,
|
||||||
|
)
|
||||||
|
self.is_fp8 = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def expects_unquantized_inputs(self) -> bool:
|
def expects_unquantized_inputs(self) -> bool:
|
||||||
return True
|
return True
|
||||||
@@ -49,10 +67,10 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
weight_key: QuantKey | None,
|
weight_key: QuantKey | None,
|
||||||
activation_key: QuantKey | None,
|
activation_key: QuantKey | None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
# TODO: dispatch based on device.
|
|
||||||
SUPPORTED_W_A = [
|
SUPPORTED_W_A = [
|
||||||
(None, None),
|
(None, None),
|
||||||
(kFp8StaticTensorSym, None),
|
(kFp8StaticTensorSym, None),
|
||||||
|
(kFp8StaticTensorSym, kFp8DynamicTensorSym),
|
||||||
]
|
]
|
||||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||||
|
|
||||||
@@ -103,10 +121,10 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
xpu_fused_moe(
|
xpu_fused_moe(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
w13=w1,
|
w13=w1,
|
||||||
w13_scales=a1q_scale,
|
w13_scales=self.w1_scale,
|
||||||
w13_bias=self.w1_bias,
|
w13_bias=self.w1_bias,
|
||||||
w2=w2,
|
w2=w2,
|
||||||
w2_scales=a2_scale,
|
w2_scales=self.w2_scale,
|
||||||
w2_bias=self.w2_bias,
|
w2_bias=self.w2_bias,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
@@ -116,5 +134,22 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
ep_rank=self.moe_config.ep_rank,
|
ep_rank=self.moe_config.ep_rank,
|
||||||
ep_size=self.moe_config.ep_size,
|
ep_size=self.moe_config.ep_size,
|
||||||
output=output,
|
output=output,
|
||||||
|
is_fp8=self.is_fp8,
|
||||||
)
|
)
|
||||||
return
|
|
||||||
|
|
||||||
|
class XPUExpertsFp8(XPUExperts):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
moe_config: FusedMoEConfig,
|
||||||
|
quant_config: FusedMoEQuantConfig,
|
||||||
|
max_num_tokens: int | None = None,
|
||||||
|
num_dispatchers: int | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
moe_config,
|
||||||
|
quant_config,
|
||||||
|
max_num_tokens,
|
||||||
|
num_dispatchers,
|
||||||
|
)
|
||||||
|
self.is_fp8 = True
|
||||||
|
|||||||
@@ -180,18 +180,9 @@ class Fp8Config(QuantizationConfig):
|
|||||||
weight_block_size=weight_block_size,
|
weight_block_size=weight_block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_xpu_quant_method(
|
|
||||||
self, layer: torch.nn.Module, prefix: str
|
|
||||||
) -> "QuantizeMethodBase | None":
|
|
||||||
raise NotImplementedError(
|
|
||||||
"FP8 quantization is not supported during xpu kernel migration."
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_quant_method(
|
def get_quant_method(
|
||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> "QuantizeMethodBase | None":
|
) -> "QuantizeMethodBase | None":
|
||||||
if current_platform.is_xpu():
|
|
||||||
return self.get_xpu_quant_method(layer, prefix)
|
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
if is_layer_skipped(
|
if is_layer_skipped(
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
@@ -300,7 +291,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
or envs.VLLM_TEST_FORCE_FP8_MARLIN
|
or envs.VLLM_TEST_FORCE_FP8_MARLIN
|
||||||
)
|
)
|
||||||
# Disable marlin for rocm
|
# Disable marlin for rocm
|
||||||
if current_platform.is_rocm():
|
if current_platform.is_rocm() or current_platform.is_xpu():
|
||||||
self.use_marlin = False
|
self.use_marlin = False
|
||||||
if vllm_is_batch_invariant():
|
if vllm_is_batch_invariant():
|
||||||
self.use_marlin = False
|
self.use_marlin = False
|
||||||
|
|||||||
@@ -39,6 +39,9 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer
|
|||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
|
||||||
TritonInt8ScaledMMLinearKernel,
|
TritonInt8ScaledMMLinearKernel,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.xpu import (
|
||||||
|
XPUFP8ScaledMMLinearKernel,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
||||||
from vllm.platforms import PlatformEnum, current_platform
|
from vllm.platforms import PlatformEnum, current_platform
|
||||||
|
|
||||||
@@ -72,6 +75,9 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
|
|||||||
PerTensorTorchFP8ScaledMMLinearKernel,
|
PerTensorTorchFP8ScaledMMLinearKernel,
|
||||||
ChannelWiseTorchFP8ScaledMMLinearKernel,
|
ChannelWiseTorchFP8ScaledMMLinearKernel,
|
||||||
],
|
],
|
||||||
|
PlatformEnum.XPU: [
|
||||||
|
XPUFP8ScaledMMLinearKernel,
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel)
|
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel)
|
||||||
|
|||||||
@@ -0,0 +1,59 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||||
|
FP8ScaledMMLinearKernel,
|
||||||
|
FP8ScaledMMLinearLayerConfig,
|
||||||
|
)
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
|
||||||
|
class XPUFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
    """FP8 scaled-matmul linear kernel for the XPU platform.

    The forward pass is implemented in :meth:`apply_weights`, which hands the
    input together with the layer's ``weight`` and ``weight_scale`` to
    ``torch.ops._xpu_C.fp8_gemm_w8a16``. :meth:`apply_scaled_mm` is not
    implemented for this kernel.
    """

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Report whether this kernel can run on the current platform.

        Args:
            compute_capability: Accepted for interface compatibility; it is
                not consulted by this kernel.

        Returns:
            ``(True, None)`` on XPU, otherwise ``(False, reason)``.
        """
        if not current_platform.is_xpu():
            return False, "XPUFP8ScaledMM only support on XPU"
        return True, None

    @classmethod
    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        """Report whether this kernel can handle the layer config ``c``.

        Only the weight dtype is checked: it must be one of
        ``torch.float8_e5m2`` or ``torch.float8_e4m3fn``.

        Returns:
            ``(True, None)`` when the config is acceptable, otherwise
            ``(False, reason)``.
        """
        if c.weight_quant_key.dtype not in {torch.float8_e5m2, torch.float8_e4m3fn}:
            return False, "XPUFP8ScaledMM only support FP8 weight dtype"
        return True, None

    def __init__(
        self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
    ) -> None:
        """Validate the config/platform and store the constructor arguments.

        Args:
            c: Layer configuration; must pass :meth:`can_implement`.
            layer_param_names: Stored as-is on the instance.

        Raises:
            AssertionError: If the config or the platform is not supported.
        """
        # Same checks and order as before; the reason string is now attached
        # so a failure explains itself.
        implementable, reason = self.can_implement(c)
        assert implementable, reason
        supported, reason = self.is_supported()
        assert supported, reason
        # NOTE(review): super().__init__() is not called here; confirm the
        # base class needs no further setup.
        self.config = c
        self.layer_param_names = layer_param_names

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the linear layer through the XPU FP8 GEMM op.

        Args:
            layer: Module providing ``weight`` and ``weight_scale`` attributes.
            x: Input tensor, passed to the op unchanged (no quantization of
                ``x`` happens in this method).
            bias: Optional bias forwarded to the op.

        Returns:
            The result of ``torch.ops._xpu_C.fp8_gemm_w8a16``.
        """
        weight = layer.weight
        weight_scale = layer.weight_scale
        return torch.ops._xpu_C.fp8_gemm_w8a16(x, weight, weight_scale, bias)

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        """Not implemented for XPU; :meth:`apply_weights` does the GEMM itself.

        Raises:
            NotImplementedError: Always. Previously this method fell through
                and returned ``None`` despite being annotated as returning a
                tensor, which would only surface later as an unrelated error.
        """
        raise NotImplementedError(
            "XPUFP8ScaledMMLinearKernel does not implement apply_scaled_mm; "
            "use apply_weights instead."
        )
|
||||||
@@ -160,7 +160,7 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
|
|||||||
logger.info_once("Using Triton backend")
|
logger.info_once("Using Triton backend")
|
||||||
return Mxfp4Backend.TRITON
|
return Mxfp4Backend.TRITON
|
||||||
elif current_platform.is_xpu():
|
elif current_platform.is_xpu():
|
||||||
logger.info_once("Using ipex marlin backend on XPU")
|
logger.info_once("Using xpu backend on XPU")
|
||||||
return Mxfp4Backend.MARLIN
|
return Mxfp4Backend.MARLIN
|
||||||
elif current_platform.is_rocm() and has_triton_kernels():
|
elif current_platform.is_rocm() and has_triton_kernels():
|
||||||
logger.info_once("Using Triton backend")
|
logger.info_once("Using Triton backend")
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from vllm.v1.worker.workspace import current_workspace_manager
|
|||||||
if current_platform.is_cuda_alike():
|
if current_platform.is_cuda_alike():
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
elif current_platform.is_xpu():
|
elif current_platform.is_xpu():
|
||||||
from vllm._ipex_ops import ipex_ops as ops
|
from vllm._xpu_ops import xpu_ops as ops
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -338,7 +338,6 @@ class CpuPlatform(Platform):
|
|||||||
ld_preload_str += pytorch_libgomp_so
|
ld_preload_str += pytorch_libgomp_so
|
||||||
os.environ["LD_PRELOAD"] = ld_preload_str
|
os.environ["LD_PRELOAD"] = ld_preload_str
|
||||||
|
|
||||||
# To hint IPEX uses shared memory based AllReduce
|
|
||||||
os.environ["LOCAL_WORLD_SIZE"] = str(
|
os.environ["LOCAL_WORLD_SIZE"] = str(
|
||||||
vllm_config.parallel_config.tensor_parallel_size
|
vllm_config.parallel_config.tensor_parallel_size
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -23,12 +23,11 @@ if current_platform.is_cuda():
|
|||||||
|
|
||||||
elif current_platform.is_xpu():
|
elif current_platform.is_xpu():
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm._xpu_ops import xpu_ops
|
||||||
|
|
||||||
reshape_and_cache_flash = ops.reshape_and_cache_flash
|
reshape_and_cache_flash = ops.reshape_and_cache_flash
|
||||||
from vllm._ipex_ops import ipex_ops
|
flash_attn_varlen_func = xpu_ops.flash_attn_varlen_func # type: ignore[assignment]
|
||||||
|
get_scheduler_metadata = xpu_ops.get_scheduler_metadata # type: ignore[assignment]
|
||||||
flash_attn_varlen_func = ipex_ops.flash_attn_varlen_func # type: ignore[assignment]
|
|
||||||
get_scheduler_metadata = ipex_ops.get_scheduler_metadata # type: ignore[assignment]
|
|
||||||
elif current_platform.is_rocm():
|
elif current_platform.is_rocm():
|
||||||
try:
|
try:
|
||||||
from flash_attn import flash_attn_varlen_func # type: ignore[no-redef]
|
from flash_attn import flash_attn_varlen_func # type: ignore[no-redef]
|
||||||
@@ -153,7 +152,7 @@ def is_flash_attn_varlen_func_available() -> bool:
|
|||||||
|
|
||||||
Platform-specific sources:
|
Platform-specific sources:
|
||||||
- CUDA: vllm.vllm_flash_attn.flash_attn_varlen_func
|
- CUDA: vllm.vllm_flash_attn.flash_attn_varlen_func
|
||||||
- XPU: ipex_ops.flash_attn_varlen_func
|
- XPU: xpu_ops.flash_attn_varlen_func
|
||||||
- ROCm: upstream flash_attn.flash_attn_varlen_func (if available)
|
- ROCm: upstream flash_attn.flash_attn_varlen_func (if available)
|
||||||
|
|
||||||
Note: This is separate from the AITER flash attention backend (rocm_aiter_fa.py)
|
Note: This is separate from the AITER flash attention backend (rocm_aiter_fa.py)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from vllm.platforms import current_platform
|
|||||||
if current_platform.is_cuda_alike():
|
if current_platform.is_cuda_alike():
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
elif current_platform.is_xpu():
|
elif current_platform.is_xpu():
|
||||||
from vllm._ipex_ops import ipex_ops as ops # type: ignore[no-redef]
|
from vllm._xpu_ops import xpu_ops as ops # type: ignore[no-redef]
|
||||||
|
|
||||||
|
|
||||||
class PagedAttention:
|
class PagedAttention:
|
||||||
|
|||||||
@@ -982,10 +982,8 @@ class Scheduler(SchedulerInterface):
|
|||||||
|
|
||||||
session._all_token_ids.extend(update.prompt_token_ids or ())
|
session._all_token_ids.extend(update.prompt_token_ids or ())
|
||||||
session.prompt_token_ids.extend(update.prompt_token_ids or ())
|
session.prompt_token_ids.extend(update.prompt_token_ids or ())
|
||||||
# Update block hashes for the new tokens
|
# Update block hashes for the new tokens.
|
||||||
# (mirrors Request.append_output_token_ids)
|
session.update_block_hashes()
|
||||||
if session.get_hash_new_full_blocks is not None:
|
|
||||||
session.block_hashes.extend(session.get_hash_new_full_blocks())
|
|
||||||
session.num_prompt_tokens = len(session.prompt_token_ids)
|
session.num_prompt_tokens = len(session.prompt_token_ids)
|
||||||
session.arrival_time = update.arrival_time
|
session.arrival_time = update.arrival_time
|
||||||
session.sampling_params = update.sampling_params
|
session.sampling_params = update.sampling_params
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import time
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
from collections.abc import Callable, Mapping
|
from collections.abc import Callable, Mapping
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import partial
|
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -164,10 +163,11 @@ class Request:
|
|||||||
self.num_external_computed_tokens = 0
|
self.num_external_computed_tokens = 0
|
||||||
|
|
||||||
self.block_hashes: list[BlockHash] = []
|
self.block_hashes: list[BlockHash] = []
|
||||||
self.get_hash_new_full_blocks: Callable[[], list[BlockHash]] | None = None
|
# Store the block hasher without binding self to avoid creating a
|
||||||
if block_hasher is not None:
|
# reference cycle (Request -> partial -> Request) that prevents
|
||||||
self.get_hash_new_full_blocks = partial(block_hasher, self)
|
# immediate garbage collection via reference counting.
|
||||||
self.block_hashes = self.get_hash_new_full_blocks()
|
self._block_hasher: Callable[[Request], list[BlockHash]] | None = block_hasher
|
||||||
|
self.update_block_hashes()
|
||||||
|
|
||||||
self.skip_reading_prefix_cache = self.get_skip_reading_prefix_cache()
|
self.skip_reading_prefix_cache = self.get_skip_reading_prefix_cache()
|
||||||
|
|
||||||
@@ -212,8 +212,12 @@ class Request:
|
|||||||
self._output_token_ids.extend(token_ids)
|
self._output_token_ids.extend(token_ids)
|
||||||
self._all_token_ids.extend(token_ids)
|
self._all_token_ids.extend(token_ids)
|
||||||
|
|
||||||
if self.get_hash_new_full_blocks is not None:
|
self.update_block_hashes()
|
||||||
self.block_hashes.extend(self.get_hash_new_full_blocks())
|
|
||||||
|
def update_block_hashes(self) -> None:
    """Extend ``block_hashes`` with hashes for any newly full blocks.

    Does nothing when no block hasher is set on the request.
    """
    hasher = self._block_hasher
    if hasher is None:
        return
    # The hasher receives the request itself and returns the hashes to append.
    self.block_hashes.extend(hasher(self))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def use_structured_output(self) -> bool:
|
def use_structured_output(self) -> bool:
|
||||||
|
|||||||
Reference in New Issue
Block a user