Compare commits
169 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
468d761b32 | ||
|
|
e4bf860a54 | ||
|
|
91f50a6fe2 | ||
|
|
79a268c4ab | ||
|
|
eace8bf0b9 | ||
|
|
1e8f4252aa | ||
|
|
2b7949c1c2 | ||
|
|
62b5166bd4 | ||
|
|
d86285a4a4 | ||
|
|
d87f39e9a9 | ||
|
|
d3c8180ac4 | ||
|
|
62b8aebc6f | ||
|
|
050f285ff6 | ||
|
|
8f2ea22bde | ||
|
|
0ae11f78ab | ||
|
|
34128a697e | ||
|
|
c1b4e4157c | ||
|
|
ceaf4ed003 | ||
|
|
ad8d696a99 | ||
|
|
3d925165f2 | ||
|
|
1543680691 | ||
|
|
077f0a2e8a | ||
|
|
e73ed0f1c6 | ||
|
|
296cdf8ac7 | ||
|
|
747b1a7147 | ||
|
|
95e5b087cf | ||
|
|
a37d815b83 | ||
|
|
7f2593b164 | ||
|
|
fe7d648fe5 | ||
|
|
cc74b2b232 | ||
|
|
91528575ec | ||
|
|
a22cdea371 | ||
|
|
682789d402 | ||
|
|
138485a82d | ||
|
|
bc9df1571b | ||
|
|
15b86408a8 | ||
|
|
7be4f5628f | ||
|
|
8f20fc04bf | ||
|
|
221d93ecbf | ||
|
|
d17c8477f1 | ||
|
|
a134ef6f5e | ||
|
|
8a7a3e4436 | ||
|
|
8f9c28fd40 | ||
|
|
cd2f63fb36 | ||
|
|
87fa80c91f | ||
|
|
e1bb2fd52d | ||
|
|
705578ae14 | ||
|
|
e8cc7967ff | ||
|
|
53b018edcb | ||
|
|
66ded03067 | ||
|
|
6dc1fc9cfe | ||
|
|
533d2a1f39 | ||
|
|
a53222544c | ||
|
|
fe3b5bbc23 | ||
|
|
8438e0569e | ||
|
|
11d652bd4f | ||
|
|
d150e4f89f | ||
|
|
e95cd87959 | ||
|
|
69e1d2fb69 | ||
|
|
05434764cd | ||
|
|
4e7ee664e2 | ||
|
|
37e84a403d | ||
|
|
4695397dcf | ||
|
|
d619ae2d19 | ||
|
|
eb46fbfda2 | ||
|
|
0003e9154b | ||
|
|
e11e200736 | ||
|
|
8db1bf32f8 | ||
|
|
aceb17cf2d | ||
|
|
563c54f760 | ||
|
|
2cd6b4f362 | ||
|
|
711a000255 | ||
|
|
989ae2538d | ||
|
|
0a430b4ae2 | ||
|
|
ec8e3c695f | ||
|
|
98afde19fc | ||
|
|
5c2e66e487 | ||
|
|
546e721168 | ||
|
|
b8aacac31a | ||
|
|
d04973ad54 | ||
|
|
fbb9d9eef4 | ||
|
|
09473ee41c | ||
|
|
d4ec9ffb95 | ||
|
|
96b6a6d790 | ||
|
|
36729bac13 | ||
|
|
7fd3949a0b | ||
|
|
1096717ae9 | ||
|
|
c2b4a1bce9 | ||
|
|
e46a60aa4c | ||
|
|
1e96c3341a | ||
|
|
95e7d4a97c | ||
|
|
559eb852f8 | ||
|
|
a10d3056da | ||
|
|
8afca50889 | ||
|
|
08ccee1e83 | ||
|
|
c1dc547129 | ||
|
|
f3d0bf7589 | ||
|
|
e9da5a40c6 | ||
|
|
e42df7227d | ||
|
|
caada5e50a | ||
|
|
67b4221a61 | ||
|
|
63e7176f26 | ||
|
|
934d3662f7 | ||
|
|
92cd2e2f21 | ||
|
|
e4c4072c94 | ||
|
|
e35397468f | ||
|
|
8b317c6dd0 | ||
|
|
bd3c144e0b | ||
|
|
0258b7a94b | ||
|
|
b3104b2a10 | ||
|
|
c2e00af523 | ||
|
|
c013d32c75 | ||
|
|
11dd6ebb89 | ||
|
|
6c0b04515f | ||
|
|
e23a43aef8 | ||
|
|
e7c7067b45 | ||
|
|
6d592eb430 | ||
|
|
d036198e23 | ||
|
|
59a6abf3c9 | ||
|
|
bc0c0192d1 | ||
|
|
f46864d68d | ||
|
|
b4543c8f6b | ||
|
|
0ce0539d47 | ||
|
|
2f19283549 | ||
|
|
95baec828f | ||
|
|
e4be7d70bb | ||
|
|
54951ac4bf | ||
|
|
18de883489 | ||
|
|
1d7c940d74 | ||
|
|
cfaf49a167 | ||
|
|
9edec652e2 | ||
|
|
e0dd4d3589 | ||
|
|
e5043a3e75 | ||
|
|
d03d64fd2e | ||
|
|
78107fa091 | ||
|
|
c391e4b68e | ||
|
|
9117f892f0 | ||
|
|
db2a6a41e2 | ||
|
|
ca81ff5196 | ||
|
|
b7782002e1 | ||
|
|
819a309c0f | ||
|
|
aabe8f40f2 | ||
|
|
498eb5cfa3 | ||
|
|
537ee25f43 | ||
|
|
294f8f6665 | ||
|
|
b95047f2da | ||
|
|
2ff767b513 | ||
|
|
3dcb3e8b98 | ||
|
|
c64cf38673 | ||
|
|
76b889bf1d | ||
|
|
c9b506dad4 | ||
|
|
5757d90e26 | ||
|
|
a3c226e7eb | ||
|
|
b321d4881b | ||
|
|
ad6eca408b | ||
|
|
205b94942e | ||
|
|
3bec41f41a | ||
|
|
0739b1947f | ||
|
|
77a6572aa5 | ||
|
|
0e3f06fe9c | ||
|
|
eb69d68804 | ||
|
|
7d4e1b85e7 | ||
|
|
93deb0b38f | ||
|
|
ccb58b23e6 | ||
|
|
49782fcb76 | ||
|
|
f03cc667a0 | ||
|
|
563c1d7ec5 | ||
|
|
9c82a1bec3 | ||
|
|
b6d103542c |
14
.buildkite/run-cpu-test.sh
Normal file
14
.buildkite/run-cpu-test.sh
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
# This script build the CPU docker image and run the offline inference inside the container.
|
||||||
|
# It serves a sanity check for compilation and basic model usage.
|
||||||
|
set -ex
|
||||||
|
|
||||||
|
# Try building the docker image
|
||||||
|
docker build -t cpu-test -f Dockerfile.cpu .
|
||||||
|
|
||||||
|
# Setup cleanup
|
||||||
|
remove_docker_container() { docker rm -f cpu-test || true; }
|
||||||
|
trap remove_docker_container EXIT
|
||||||
|
remove_docker_container
|
||||||
|
|
||||||
|
# Run the image and launch offline inference
|
||||||
|
docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-test cpu-test python3 examples/offline_inference.py
|
||||||
37
.buildkite/run-neuron-test.sh
Normal file
37
.buildkite/run-neuron-test.sh
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
# This script build the Neuron docker image and run the API server inside the container.
|
||||||
|
# It serves a sanity check for compilation and basic model usage.
|
||||||
|
set -e
|
||||||
|
|
||||||
|
# Try building the docker image
|
||||||
|
aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-west-2.amazonaws.com
|
||||||
|
docker build -t neuron -f Dockerfile.neuron .
|
||||||
|
|
||||||
|
# Setup cleanup
|
||||||
|
remove_docker_container() { docker rm -f neuron || true; }
|
||||||
|
trap remove_docker_container EXIT
|
||||||
|
remove_docker_container
|
||||||
|
|
||||||
|
# Run the image
|
||||||
|
docker run --device=/dev/neuron0 --device=/dev/neuron1 --network host --name neuron neuron python3 -m vllm.entrypoints.api_server \
|
||||||
|
--model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --max-num-seqs 8 --max-model-len 128 --block-size 128 --device neuron --tensor-parallel-size 2 &
|
||||||
|
|
||||||
|
# Wait for the server to start
|
||||||
|
wait_for_server_to_start() {
|
||||||
|
timeout=300
|
||||||
|
counter=0
|
||||||
|
|
||||||
|
while [ "$(curl -s -o /dev/null -w ''%{http_code}'' localhost:8000/health)" != "200" ]; do
|
||||||
|
sleep 1
|
||||||
|
counter=$((counter + 1))
|
||||||
|
if [ $counter -ge $timeout ]; then
|
||||||
|
echo "Timeout after $timeout seconds"
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
}
|
||||||
|
wait_for_server_to_start
|
||||||
|
|
||||||
|
# Test a simple prompt
|
||||||
|
curl -X POST -H "Content-Type: application/json" \
|
||||||
|
localhost:8000/generate \
|
||||||
|
-d '{"prompt": "San Francisco is a"}'
|
||||||
@@ -12,7 +12,11 @@ steps:
|
|||||||
command: pytest -v -s async_engine
|
command: pytest -v -s async_engine
|
||||||
|
|
||||||
- label: Basic Correctness Test
|
- label: Basic Correctness Test
|
||||||
command: pytest -v -s basic_correctness
|
commands:
|
||||||
|
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py
|
||||||
|
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
|
||||||
|
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
|
||||||
|
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
|
||||||
|
|
||||||
- label: Core Test
|
- label: Core Test
|
||||||
command: pytest -v -s core
|
command: pytest -v -s core
|
||||||
@@ -27,14 +31,20 @@ steps:
|
|||||||
num_gpus: 2 # only support 1 or 2 for now.
|
num_gpus: 2 # only support 1 or 2 for now.
|
||||||
commands:
|
commands:
|
||||||
- pytest -v -s test_pynccl.py
|
- pytest -v -s test_pynccl.py
|
||||||
|
- pytest -v -s test_pynccl_library.py
|
||||||
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py
|
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py
|
||||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py
|
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py
|
||||||
|
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py
|
||||||
|
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py
|
||||||
|
|
||||||
- label: Engine Test
|
- label: Engine Test
|
||||||
command: pytest -v -s engine tokenization test_sequence.py test_config.py
|
command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py
|
||||||
|
|
||||||
- label: Entrypoints Test
|
- label: Entrypoints Test
|
||||||
command: pytest -v -s entrypoints
|
commands:
|
||||||
|
# these tests have to be separated, because each one will allocate all posible GPU memory
|
||||||
|
- pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py
|
||||||
|
- pytest -v -s entrypoints/test_server_oot_registration.py
|
||||||
|
|
||||||
- label: Examples Test
|
- label: Examples Test
|
||||||
working_dir: "/vllm-workspace/examples"
|
working_dir: "/vllm-workspace/examples"
|
||||||
@@ -80,9 +90,15 @@ steps:
|
|||||||
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
||||||
parallelism: 4
|
parallelism: 4
|
||||||
|
|
||||||
|
- label: Tensorizer Test
|
||||||
|
command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader
|
||||||
|
|
||||||
- label: Metrics Test
|
- label: Metrics Test
|
||||||
command: pytest -v -s metrics
|
command: pytest -v -s metrics
|
||||||
|
|
||||||
|
- label: Quantization Test
|
||||||
|
command: pytest -v -s quantization
|
||||||
|
|
||||||
- label: Benchmarks
|
- label: Benchmarks
|
||||||
working_dir: "/vllm-workspace/.buildkite"
|
working_dir: "/vllm-workspace/.buildkite"
|
||||||
commands:
|
commands:
|
||||||
@@ -90,7 +106,7 @@ steps:
|
|||||||
- bash run-benchmarks.sh
|
- bash run-benchmarks.sh
|
||||||
|
|
||||||
- label: Documentation Build
|
- label: Documentation Build
|
||||||
working_dir: "/vllm-workspace/docs"
|
working_dir: "/vllm-workspace/test_docs/docs"
|
||||||
no_gpu: True
|
no_gpu: True
|
||||||
commands:
|
commands:
|
||||||
- pip install -r requirements-docs.txt
|
- pip install -r requirements-docs.txt
|
||||||
|
|||||||
@@ -3,10 +3,6 @@
|
|||||||
{% set default_working_dir = "/vllm-workspace/tests" %}
|
{% set default_working_dir = "/vllm-workspace/tests" %}
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- label: "AMD Test"
|
|
||||||
agents:
|
|
||||||
queue: amd
|
|
||||||
command: bash .buildkite/run-amd-test.sh
|
|
||||||
|
|
||||||
- label: ":docker: build image"
|
- label: ":docker: build image"
|
||||||
commands:
|
commands:
|
||||||
@@ -20,6 +16,19 @@ steps:
|
|||||||
limit: 5
|
limit: 5
|
||||||
- wait
|
- wait
|
||||||
|
|
||||||
|
- label: "AMD Test"
|
||||||
|
agents:
|
||||||
|
queue: amd
|
||||||
|
command: bash .buildkite/run-amd-test.sh
|
||||||
|
|
||||||
|
- label: "Neuron Test"
|
||||||
|
agents:
|
||||||
|
queue: neuron
|
||||||
|
command: bash .buildkite/run-neuron-test.sh
|
||||||
|
|
||||||
|
- label: "CPU Test"
|
||||||
|
command: bash .buildkite/run-cpu-test.sh
|
||||||
|
|
||||||
{% for step in steps %}
|
{% for step in steps %}
|
||||||
- label: "{{ step.label }}"
|
- label: "{{ step.label }}"
|
||||||
agents:
|
agents:
|
||||||
|
|||||||
1
.github/ISSUE_TEMPLATE/200-installation.yml
vendored
1
.github/ISSUE_TEMPLATE/200-installation.yml
vendored
@@ -18,6 +18,7 @@ body:
|
|||||||
# For security purposes, please feel free to check the contents of collect_env.py before running it.
|
# For security purposes, please feel free to check the contents of collect_env.py before running it.
|
||||||
python collect_env.py
|
python collect_env.py
|
||||||
```
|
```
|
||||||
|
It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues.
|
||||||
value: |
|
value: |
|
||||||
```text
|
```text
|
||||||
The output of `python collect_env.py`
|
The output of `python collect_env.py`
|
||||||
|
|||||||
1
.github/ISSUE_TEMPLATE/300-usage.yml
vendored
1
.github/ISSUE_TEMPLATE/300-usage.yml
vendored
@@ -18,6 +18,7 @@ body:
|
|||||||
# For security purposes, please feel free to check the contents of collect_env.py before running it.
|
# For security purposes, please feel free to check the contents of collect_env.py before running it.
|
||||||
python collect_env.py
|
python collect_env.py
|
||||||
```
|
```
|
||||||
|
It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues.
|
||||||
value: |
|
value: |
|
||||||
```text
|
```text
|
||||||
The output of `python collect_env.py`
|
The output of `python collect_env.py`
|
||||||
|
|||||||
3
.github/ISSUE_TEMPLATE/400-bug report.yml
vendored
3
.github/ISSUE_TEMPLATE/400-bug report.yml
vendored
@@ -18,6 +18,7 @@ body:
|
|||||||
# For security purposes, please feel free to check the contents of collect_env.py before running it.
|
# For security purposes, please feel free to check the contents of collect_env.py before running it.
|
||||||
python collect_env.py
|
python collect_env.py
|
||||||
```
|
```
|
||||||
|
It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues.
|
||||||
value: |
|
value: |
|
||||||
```text
|
```text
|
||||||
The output of `python collect_env.py`
|
The output of `python collect_env.py`
|
||||||
@@ -57,6 +58,8 @@ body:
|
|||||||
If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com.
|
If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com.
|
||||||
|
|
||||||
Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````.
|
Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````.
|
||||||
|
|
||||||
|
If you experienced crashes or hangs, it would be helpful to run vllm with `export VLLM_TRACE_FUNCTION=1` . All the function calls in vllm will be recorded. Inspect these log files, and tell which function crashes or hangs.
|
||||||
placeholder: |
|
placeholder: |
|
||||||
A clear and concise description of what the bug is.
|
A clear and concise description of what the bug is.
|
||||||
|
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ body:
|
|||||||
# For security purposes, please feel free to check the contents of collect_env.py before running it.
|
# For security purposes, please feel free to check the contents of collect_env.py before running it.
|
||||||
python collect_env.py
|
python collect_env.py
|
||||||
```
|
```
|
||||||
|
It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues.
|
||||||
value: |
|
value: |
|
||||||
```text
|
```text
|
||||||
The output of `python collect_env.py`
|
The output of `python collect_env.py`
|
||||||
|
|||||||
51
.github/workflows/mypy.yaml
vendored
Normal file
51
.github/workflows/mypy.yaml
vendored
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
name: mypy
|
||||||
|
|
||||||
|
on:
|
||||||
|
# Trigger the workflow on push or pull request,
|
||||||
|
# but only for the main branch
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
ruff:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install mypy==1.9.0
|
||||||
|
pip install types-setuptools
|
||||||
|
pip install types-PyYAML
|
||||||
|
pip install types-requests
|
||||||
|
pip install types-setuptools
|
||||||
|
- name: Mypy
|
||||||
|
run: |
|
||||||
|
mypy vllm/attention --config-file pyproject.toml
|
||||||
|
# TODO(sang): Fix nested dir
|
||||||
|
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
|
||||||
|
mypy vllm/distributed --config-file pyproject.toml
|
||||||
|
mypy vllm/entrypoints --config-file pyproject.toml
|
||||||
|
mypy vllm/executor --config-file pyproject.toml
|
||||||
|
mypy vllm/usage --config-file pyproject.toml
|
||||||
|
mypy vllm/*.py --config-file pyproject.toml
|
||||||
|
mypy vllm/transformers_utils --config-file pyproject.toml
|
||||||
|
mypy vllm/engine --config-file pyproject.toml
|
||||||
|
mypy vllm/worker --config-file pyproject.toml
|
||||||
|
mypy vllm/spec_decode --config-file pyproject.toml
|
||||||
|
# TODO(sang): Fix nested dir
|
||||||
|
mypy vllm/model_executor/*.py --config-file pyproject.toml
|
||||||
|
# TODO(sang): Fix nested dir
|
||||||
|
# mypy vllm/lora/*.py --config-file pyproject.toml
|
||||||
|
|
||||||
5
.github/workflows/publish.yml
vendored
5
.github/workflows/publish.yml
vendored
@@ -49,13 +49,16 @@ jobs:
|
|||||||
matrix:
|
matrix:
|
||||||
os: ['ubuntu-20.04']
|
os: ['ubuntu-20.04']
|
||||||
python-version: ['3.8', '3.9', '3.10', '3.11']
|
python-version: ['3.8', '3.9', '3.10', '3.11']
|
||||||
pytorch-version: ['2.1.2'] # Must be the most recent version that meets requirements.txt.
|
pytorch-version: ['2.2.1'] # Must be the most recent version that meets requirements-cuda.txt.
|
||||||
cuda-version: ['11.8', '12.1']
|
cuda-version: ['11.8', '12.1']
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Setup ccache
|
||||||
|
uses: hendrikmuhs/ccache-action@v1.2
|
||||||
|
|
||||||
- name: Set up Linux Env
|
- name: Set up Linux Env
|
||||||
if: ${{ runner.os == 'Linux' }}
|
if: ${{ runner.os == 'Linux' }}
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
2
.github/workflows/ruff.yml
vendored
2
.github/workflows/ruff.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
python-version: ["3.10"]
|
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
|||||||
5
.github/workflows/scripts/build.sh
vendored
5
.github/workflows/scripts/build.sh
vendored
@@ -9,12 +9,13 @@ LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH
|
|||||||
|
|
||||||
# Install requirements
|
# Install requirements
|
||||||
$python_executable -m pip install wheel packaging
|
$python_executable -m pip install wheel packaging
|
||||||
$python_executable -m pip install -r requirements.txt
|
$python_executable -m pip install -r requirements-cuda.txt
|
||||||
|
|
||||||
# Limit the number of parallel jobs to avoid OOM
|
# Limit the number of parallel jobs to avoid OOM
|
||||||
export MAX_JOBS=1
|
export MAX_JOBS=1
|
||||||
# Make sure punica is built for the release (for LoRA)
|
# Make sure punica is built for the release (for LoRA)
|
||||||
export VLLM_INSTALL_PUNICA_KERNELS=1
|
export VLLM_INSTALL_PUNICA_KERNELS=1
|
||||||
|
# Make sure release wheels are built for the following architectures
|
||||||
|
export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX"
|
||||||
# Build
|
# Build
|
||||||
$python_executable setup.py bdist_wheel --dist-dir=dist
|
$python_executable setup.py bdist_wheel --dist-dir=dist
|
||||||
|
|||||||
2
.github/workflows/yapf.yml
vendored
2
.github/workflows/yapf.yml
vendored
@@ -14,7 +14,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
python-version: ["3.10"]
|
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
|||||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -70,6 +70,8 @@ instance/
|
|||||||
|
|
||||||
# Sphinx documentation
|
# Sphinx documentation
|
||||||
docs/_build/
|
docs/_build/
|
||||||
|
docs/source/getting_started/examples/*.rst
|
||||||
|
!**/*.template.rst
|
||||||
|
|
||||||
# PyBuilder
|
# PyBuilder
|
||||||
.pybuilder/
|
.pybuilder/
|
||||||
@@ -181,6 +183,7 @@ _build/
|
|||||||
# hip files generated by PyTorch
|
# hip files generated by PyTorch
|
||||||
*.hip
|
*.hip
|
||||||
*_hip*
|
*_hip*
|
||||||
|
hip_compat.h
|
||||||
|
|
||||||
# Benchmark dataset
|
# Benchmark dataset
|
||||||
*.json
|
*.json
|
||||||
|
|||||||
@@ -2,7 +2,10 @@ cmake_minimum_required(VERSION 3.21)
|
|||||||
|
|
||||||
project(vllm_extensions LANGUAGES CXX)
|
project(vllm_extensions LANGUAGES CXX)
|
||||||
|
|
||||||
|
option(VLLM_TARGET_DEVICE "Target device backend for vLLM" "cuda")
|
||||||
|
|
||||||
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
|
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
|
||||||
|
message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")
|
||||||
|
|
||||||
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
|
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
|
||||||
|
|
||||||
@@ -16,7 +19,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11")
|
|||||||
set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0")
|
set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0")
|
||||||
|
|
||||||
# Supported AMD GPU architectures.
|
# Supported AMD GPU architectures.
|
||||||
set(HIP_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx1100")
|
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100")
|
||||||
|
|
||||||
#
|
#
|
||||||
# Supported/expected torch versions for CUDA/ROCm.
|
# Supported/expected torch versions for CUDA/ROCm.
|
||||||
@@ -28,7 +31,7 @@ set(HIP_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx1100")
|
|||||||
# requirements.txt files and should be kept consistent. The ROCm torch
|
# requirements.txt files and should be kept consistent. The ROCm torch
|
||||||
# versions are derived from Dockerfile.rocm
|
# versions are derived from Dockerfile.rocm
|
||||||
#
|
#
|
||||||
set(TORCH_SUPPORTED_VERSION_CUDA "2.1.2")
|
set(TORCH_SUPPORTED_VERSION_CUDA "2.2.1")
|
||||||
set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1")
|
set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1")
|
||||||
set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1")
|
set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1")
|
||||||
|
|
||||||
@@ -76,6 +79,19 @@ find_package(Torch REQUIRED)
|
|||||||
find_library(torch_python_LIBRARY torch_python PATHS
|
find_library(torch_python_LIBRARY torch_python PATHS
|
||||||
"${TORCH_INSTALL_PREFIX}/lib")
|
"${TORCH_INSTALL_PREFIX}/lib")
|
||||||
|
|
||||||
|
#
|
||||||
|
# Forward the non-CUDA device extensions to external CMake scripts.
|
||||||
|
#
|
||||||
|
if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND
|
||||||
|
NOT VLLM_TARGET_DEVICE STREQUAL "rocm")
|
||||||
|
if (VLLM_TARGET_DEVICE STREQUAL "cpu")
|
||||||
|
include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake)
|
||||||
|
else()
|
||||||
|
message(FATAL_ERROR "Unsupported vLLM target device: ${VLLM_TARGET_DEVICE}")
|
||||||
|
endif()
|
||||||
|
return()
|
||||||
|
endif()
|
||||||
|
|
||||||
#
|
#
|
||||||
# Set up GPU language and check the torch version and warn if it isn't
|
# Set up GPU language and check the torch version and warn if it isn't
|
||||||
# what is expected.
|
# what is expected.
|
||||||
@@ -151,12 +167,14 @@ set(VLLM_EXT_SRC
|
|||||||
"csrc/layernorm_kernels.cu"
|
"csrc/layernorm_kernels.cu"
|
||||||
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
|
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
|
||||||
"csrc/quantization/gptq/q_gemm.cu"
|
"csrc/quantization/gptq/q_gemm.cu"
|
||||||
|
"csrc/quantization/fp8/fp8_cuda_kernels.cu"
|
||||||
"csrc/cuda_utils_kernels.cu"
|
"csrc/cuda_utils_kernels.cu"
|
||||||
"csrc/moe_align_block_size_kernels.cu"
|
"csrc/moe_align_block_size_kernels.cu"
|
||||||
"csrc/pybind.cpp")
|
"csrc/pybind.cpp")
|
||||||
|
|
||||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||||
list(APPEND VLLM_EXT_SRC
|
list(APPEND VLLM_EXT_SRC
|
||||||
|
"csrc/quantization/aqlm/gemm_kernels.cu"
|
||||||
"csrc/quantization/awq/gemm_kernels.cu"
|
"csrc/quantization/awq/gemm_kernels.cu"
|
||||||
"csrc/quantization/marlin/marlin_cuda_kernel.cu"
|
"csrc/quantization/marlin/marlin_cuda_kernel.cu"
|
||||||
"csrc/custom_all_reduce.cu")
|
"csrc/custom_all_reduce.cu")
|
||||||
@@ -194,23 +212,11 @@ define_gpu_extension_target(
|
|||||||
|
|
||||||
set(VLLM_PUNICA_EXT_SRC
|
set(VLLM_PUNICA_EXT_SRC
|
||||||
"csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu"
|
"csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu"
|
||||||
"csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu"
|
|
||||||
"csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu"
|
|
||||||
"csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu"
|
|
||||||
"csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu"
|
"csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu"
|
||||||
"csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu"
|
|
||||||
"csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu"
|
|
||||||
"csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu"
|
|
||||||
"csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu"
|
|
||||||
"csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu"
|
"csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu"
|
||||||
"csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu"
|
|
||||||
"csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu"
|
"csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu"
|
||||||
"csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
|
"csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
|
||||||
"csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu"
|
|
||||||
"csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu"
|
|
||||||
"csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
|
"csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
|
||||||
"csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu"
|
|
||||||
"csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu"
|
|
||||||
"csrc/punica/punica_ops.cc")
|
"csrc/punica/punica_ops.cc")
|
||||||
|
|
||||||
#
|
#
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ Express your support on Twitter if vLLM aids you, or simply offer your appreciat
|
|||||||
### Build from source
|
### Build from source
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install -r requirements.txt
|
|
||||||
pip install -e . # This may take several minutes.
|
pip install -e . # This may take several minutes.
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -30,6 +29,8 @@ pip install -e . # This may take several minutes.
|
|||||||
```bash
|
```bash
|
||||||
pip install -r requirements-dev.txt
|
pip install -r requirements-dev.txt
|
||||||
|
|
||||||
|
# linting and formatting
|
||||||
|
bash format.sh
|
||||||
# Static type checking
|
# Static type checking
|
||||||
mypy
|
mypy
|
||||||
# Unit tests
|
# Unit tests
|
||||||
|
|||||||
107
Dockerfile
107
Dockerfile
@@ -2,6 +2,7 @@
|
|||||||
# to run the OpenAI compatible server.
|
# to run the OpenAI compatible server.
|
||||||
|
|
||||||
#################### BASE BUILD IMAGE ####################
|
#################### BASE BUILD IMAGE ####################
|
||||||
|
# prepare basic build environment
|
||||||
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev
|
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev
|
||||||
|
|
||||||
RUN apt-get update -y \
|
RUN apt-get update -y \
|
||||||
@@ -16,18 +17,26 @@ RUN ldconfig /usr/local/cuda-12.1/compat/
|
|||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
# install build and runtime dependencies
|
# install build and runtime dependencies
|
||||||
COPY requirements.txt requirements.txt
|
COPY requirements-common.txt requirements-common.txt
|
||||||
|
COPY requirements-cuda.txt requirements-cuda.txt
|
||||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||||
pip install -r requirements.txt
|
pip install -r requirements-cuda.txt
|
||||||
|
|
||||||
# install development dependencies
|
# install development dependencies
|
||||||
COPY requirements-dev.txt requirements-dev.txt
|
COPY requirements-dev.txt requirements-dev.txt
|
||||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||||
pip install -r requirements-dev.txt
|
pip install -r requirements-dev.txt
|
||||||
|
|
||||||
|
# cuda arch list used by torch
|
||||||
|
# can be useful for both `dev` and `test`
|
||||||
|
# explicitly set the list to avoid issues with torch 2.2
|
||||||
|
# see https://github.com/pytorch/pytorch/pull/123243
|
||||||
|
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
|
||||||
|
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
|
||||||
#################### BASE BUILD IMAGE ####################
|
#################### BASE BUILD IMAGE ####################
|
||||||
|
|
||||||
|
|
||||||
#################### EXTENSION BUILD IMAGE ####################
|
#################### WHEEL BUILD IMAGE ####################
|
||||||
FROM dev AS build
|
FROM dev AS build
|
||||||
|
|
||||||
# install build dependencies
|
# install build dependencies
|
||||||
@@ -38,18 +47,16 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
|||||||
# install compiler cache to speed up compilation leveraging local or remote caching
|
# install compiler cache to speed up compilation leveraging local or remote caching
|
||||||
RUN apt-get update -y && apt-get install -y ccache
|
RUN apt-get update -y && apt-get install -y ccache
|
||||||
|
|
||||||
# copy input files
|
# files and directories related to build wheels
|
||||||
COPY csrc csrc
|
COPY csrc csrc
|
||||||
COPY setup.py setup.py
|
COPY setup.py setup.py
|
||||||
COPY cmake cmake
|
COPY cmake cmake
|
||||||
COPY CMakeLists.txt CMakeLists.txt
|
COPY CMakeLists.txt CMakeLists.txt
|
||||||
COPY requirements.txt requirements.txt
|
COPY requirements-common.txt requirements-common.txt
|
||||||
|
COPY requirements-cuda.txt requirements-cuda.txt
|
||||||
COPY pyproject.toml pyproject.toml
|
COPY pyproject.toml pyproject.toml
|
||||||
COPY vllm/__init__.py vllm/__init__.py
|
COPY vllm vllm
|
||||||
|
|
||||||
# cuda arch list used by torch
|
|
||||||
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
|
|
||||||
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
|
|
||||||
# max jobs used by Ninja to build extensions
|
# max jobs used by Ninja to build extensions
|
||||||
ARG max_jobs=2
|
ARG max_jobs=2
|
||||||
ENV MAX_JOBS=${max_jobs}
|
ENV MAX_JOBS=${max_jobs}
|
||||||
@@ -61,7 +68,15 @@ ENV VLLM_INSTALL_PUNICA_KERNELS=1
|
|||||||
|
|
||||||
ENV CCACHE_DIR=/root/.cache/ccache
|
ENV CCACHE_DIR=/root/.cache/ccache
|
||||||
RUN --mount=type=cache,target=/root/.cache/ccache \
|
RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||||
python3 setup.py build_ext --inplace
|
--mount=type=cache,target=/root/.cache/pip \
|
||||||
|
python3 setup.py bdist_wheel --dist-dir=dist
|
||||||
|
|
||||||
|
# the `vllm_nccl` package must be installed from source distribution
|
||||||
|
# pip is too smart to store a wheel in the cache, and other CI jobs
|
||||||
|
# will directly use the wheel from the cache, which is not what we want.
|
||||||
|
# we need to remove it manually
|
||||||
|
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||||
|
pip cache remove vllm_nccl*
|
||||||
#################### EXTENSION Build IMAGE ####################
|
#################### EXTENSION Build IMAGE ####################
|
||||||
|
|
||||||
#################### FLASH_ATTENTION Build IMAGE ####################
|
#################### FLASH_ATTENTION Build IMAGE ####################
|
||||||
@@ -81,57 +96,59 @@ RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \
|
|||||||
|
|
||||||
#################### FLASH_ATTENTION Build IMAGE ####################
|
#################### FLASH_ATTENTION Build IMAGE ####################
|
||||||
|
|
||||||
|
#################### vLLM installation IMAGE ####################
|
||||||
|
# image with vLLM installed
|
||||||
|
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS vllm-base
|
||||||
|
WORKDIR /vllm-workspace
|
||||||
|
|
||||||
|
RUN apt-get update -y \
|
||||||
|
&& apt-get install -y python3-pip git vim
|
||||||
|
|
||||||
|
# Workaround for https://github.com/openai/triton/issues/2507 and
|
||||||
|
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
|
||||||
|
# this won't be needed for future versions of this docker image
|
||||||
|
# or future versions of triton.
|
||||||
|
RUN ldconfig /usr/local/cuda-12.1/compat/
|
||||||
|
|
||||||
|
# install vllm wheel first, so that torch etc will be installed
|
||||||
|
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
|
||||||
|
--mount=type=cache,target=/root/.cache/pip \
|
||||||
|
pip install dist/*.whl --verbose
|
||||||
|
|
||||||
|
RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
|
||||||
|
--mount=type=cache,target=/root/.cache/pip \
|
||||||
|
pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
|
||||||
|
#################### vLLM installation IMAGE ####################
|
||||||
|
|
||||||
|
|
||||||
#################### TEST IMAGE ####################
|
#################### TEST IMAGE ####################
|
||||||
# image to run unit testing suite
|
# image to run unit testing suite
|
||||||
FROM dev AS test
|
# note that this uses vllm installed by `pip`
|
||||||
|
FROM vllm-base AS test
|
||||||
|
|
||||||
# copy pytorch extensions separately to avoid having to rebuild
|
|
||||||
# when python code changes
|
|
||||||
WORKDIR /vllm-workspace
|
|
||||||
# ADD is used to preserve directory structure
|
|
||||||
ADD . /vllm-workspace/
|
ADD . /vllm-workspace/
|
||||||
COPY --from=build /workspace/vllm/*.so /vllm-workspace/vllm/
|
|
||||||
# Install flash attention (from pre-built wheel)
|
|
||||||
RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
|
|
||||||
pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
|
|
||||||
# ignore build dependencies installation because we are using pre-complied extensions
|
|
||||||
RUN rm pyproject.toml
|
|
||||||
RUN --mount=type=cache,target=/root/.cache/pip VLLM_USE_PRECOMPILED=1 pip install . --verbose
|
|
||||||
#################### TEST IMAGE ####################
|
|
||||||
|
|
||||||
|
# install development dependencies (for testing)
|
||||||
#################### RUNTIME BASE IMAGE ####################
|
|
||||||
# We used base cuda image because pytorch installs its own cuda libraries.
|
|
||||||
# However pynccl depends on cuda libraries so we had to switch to the runtime image
|
|
||||||
# In the future it would be nice to get a container with pytorch and cuda without duplicating cuda
|
|
||||||
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04 AS vllm-base
|
|
||||||
|
|
||||||
# libnccl required for ray
|
|
||||||
RUN apt-get update -y \
|
|
||||||
&& apt-get install -y python3-pip
|
|
||||||
|
|
||||||
WORKDIR /workspace
|
|
||||||
COPY requirements.txt requirements.txt
|
|
||||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||||
pip install -r requirements.txt
|
pip install -r requirements-dev.txt
|
||||||
|
|
||||||
# Install flash attention (from pre-built wheel)
|
# doc requires source code
|
||||||
RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
|
# we hide them inside `test_docs/` , so that this source code
|
||||||
pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
|
# will not be imported by other tests
|
||||||
|
RUN mkdir test_docs
|
||||||
#################### RUNTIME BASE IMAGE ####################
|
RUN mv docs test_docs/
|
||||||
|
RUN mv vllm test_docs/
|
||||||
|
|
||||||
|
#################### TEST IMAGE ####################
|
||||||
|
|
||||||
#################### OPENAI API SERVER ####################
|
#################### OPENAI API SERVER ####################
|
||||||
# openai api server alternative
|
# openai api server alternative
|
||||||
FROM vllm-base AS vllm-openai
|
FROM vllm-base AS vllm-openai
|
||||||
|
|
||||||
# install additional dependencies for openai api server
|
# install additional dependencies for openai api server
|
||||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||||
pip install accelerate hf_transfer modelscope
|
pip install accelerate hf_transfer modelscope
|
||||||
|
|
||||||
COPY --from=build /workspace/vllm/*.so /workspace/vllm/
|
|
||||||
COPY vllm vllm
|
|
||||||
|
|
||||||
ENV VLLM_USAGE_SOURCE production-docker-image
|
ENV VLLM_USAGE_SOURCE production-docker-image
|
||||||
|
|
||||||
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
|
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
|
||||||
|
|||||||
20
Dockerfile.cpu
Normal file
20
Dockerfile.cpu
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
# This vLLM Dockerfile is used to construct image that can build and run vLLM on x86 CPU platform.
|
||||||
|
|
||||||
|
FROM ubuntu:22.04
|
||||||
|
|
||||||
|
RUN apt-get update -y \
|
||||||
|
&& apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \
|
||||||
|
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
|
||||||
|
|
||||||
|
RUN pip install --upgrade pip \
|
||||||
|
&& pip install wheel packaging ninja setuptools>=49.4.0 numpy
|
||||||
|
|
||||||
|
COPY ./ /workspace/vllm
|
||||||
|
|
||||||
|
WORKDIR /workspace/vllm
|
||||||
|
|
||||||
|
RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
|
||||||
|
|
||||||
|
RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install
|
||||||
|
|
||||||
|
CMD ["/bin/bash"]
|
||||||
36
Dockerfile.neuron
Normal file
36
Dockerfile.neuron
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# default base image
|
||||||
|
ARG BASE_IMAGE="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference-neuronx:2.1.1-neuronx-py310-sdk2.17.0-ubuntu20.04"
|
||||||
|
|
||||||
|
FROM $BASE_IMAGE
|
||||||
|
|
||||||
|
RUN echo "Base image is $BASE_IMAGE"
|
||||||
|
|
||||||
|
# Install some basic utilities
|
||||||
|
RUN apt-get update && apt-get install python3 python3-pip -y
|
||||||
|
|
||||||
|
### Mount Point ###
|
||||||
|
# When launching the container, mount the code directory to /app
|
||||||
|
ARG APP_MOUNT=/app
|
||||||
|
VOLUME [ ${APP_MOUNT} ]
|
||||||
|
WORKDIR ${APP_MOUNT}
|
||||||
|
|
||||||
|
RUN python3 -m pip install --upgrade pip
|
||||||
|
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
|
||||||
|
RUN python3 -m pip install sentencepiece transformers==4.36.2 -U
|
||||||
|
RUN python3 -m pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
|
||||||
|
RUN python3 -m pip install --pre neuronx-cc==2.12.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
|
||||||
|
|
||||||
|
COPY ./vllm /app/vllm/vllm
|
||||||
|
COPY ./setup.py /app/vllm/setup.py
|
||||||
|
COPY ./requirements-common.txt /app/vllm/requirements-common.txt
|
||||||
|
COPY ./requirements-neuron.txt /app/vllm/requirements-neuron.txt
|
||||||
|
|
||||||
|
RUN cd /app/vllm \
|
||||||
|
&& python3 -m pip install -U -r requirements-neuron.txt
|
||||||
|
|
||||||
|
ENV VLLM_BUILD_WITH_NEURON 1
|
||||||
|
RUN cd /app/vllm \
|
||||||
|
&& pip install -e . \
|
||||||
|
&& cd ..
|
||||||
|
|
||||||
|
CMD ["/bin/bash"]
|
||||||
@@ -14,7 +14,7 @@ RUN echo "Base image is $BASE_IMAGE"
|
|||||||
ARG FA_GFX_ARCHS="gfx90a;gfx942"
|
ARG FA_GFX_ARCHS="gfx90a;gfx942"
|
||||||
RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
|
RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
|
||||||
|
|
||||||
ARG FA_BRANCH="3d2b6f5"
|
ARG FA_BRANCH="ae7928c"
|
||||||
RUN echo "FA_BRANCH is $FA_BRANCH"
|
RUN echo "FA_BRANCH is $FA_BRANCH"
|
||||||
|
|
||||||
# whether to build flash-attention
|
# whether to build flash-attention
|
||||||
@@ -23,6 +23,9 @@ RUN echo "FA_BRANCH is $FA_BRANCH"
|
|||||||
# In that case, we need to use the python reference attention implementation in vllm
|
# In that case, we need to use the python reference attention implementation in vllm
|
||||||
ARG BUILD_FA="1"
|
ARG BUILD_FA="1"
|
||||||
|
|
||||||
|
# whether to build triton on rocm
|
||||||
|
ARG BUILD_TRITON="1"
|
||||||
|
|
||||||
# Install some basic utilities
|
# Install some basic utilities
|
||||||
RUN apt-get update && apt-get install python3 python3-pip -y
|
RUN apt-get update && apt-get install python3 python3-pip -y
|
||||||
|
|
||||||
@@ -75,16 +78,24 @@ RUN if [ "$BUILD_FA" = "1" ]; then \
|
|||||||
RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
|
RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
|
||||||
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi
|
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi
|
||||||
|
|
||||||
|
# build triton
|
||||||
|
RUN if [ "$BUILD_TRITON" = "1" ]; then \
|
||||||
|
mkdir -p libs \
|
||||||
|
&& cd libs \
|
||||||
|
&& pip uninstall -y triton \
|
||||||
|
&& git clone https://github.com/ROCm/triton.git \
|
||||||
|
&& cd triton/python \
|
||||||
|
&& pip3 install . \
|
||||||
|
&& cd ../..; \
|
||||||
|
fi
|
||||||
|
|
||||||
COPY ./ /app/vllm
|
COPY ./ /app/vllm
|
||||||
|
|
||||||
RUN python3 -m pip install --upgrade pip
|
RUN python3 -m pip install --upgrade pip numba
|
||||||
RUN python3 -m pip install xformers==0.0.23 --no-deps
|
|
||||||
|
|
||||||
RUN cd /app \
|
RUN cd /app \
|
||||||
&& cd vllm \
|
&& cd vllm \
|
||||||
&& pip install -U -r requirements-rocm.txt \
|
&& pip install -U -r requirements-rocm.txt \
|
||||||
&& if [ "$BUILD_FA" = "1" ]; then \
|
|
||||||
bash patch_xformers.rocm.sh; fi \
|
|
||||||
&& patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch \
|
&& patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch \
|
||||||
&& python3 setup.py install \
|
&& python3 setup.py install \
|
||||||
&& cd ..
|
&& cd ..
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
include LICENSE
|
include LICENSE
|
||||||
include requirements.txt
|
include requirements-common.txt
|
||||||
|
include requirements-cuda.txt
|
||||||
include CMakeLists.txt
|
include CMakeLists.txt
|
||||||
|
|
||||||
recursive-include cmake *
|
recursive-include cmake *
|
||||||
|
|||||||
19
README.md
19
README.md
@@ -14,18 +14,8 @@ Easy, fast, and cheap LLM serving for everyone
|
|||||||
|
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
**The Third vLLM Bay Area Meetup (April 2nd 6pm-8:30pm PT)**
|
|
||||||
|
|
||||||
We are thrilled to announce our third vLLM Meetup!
|
|
||||||
The vLLM team will share recent updates and roadmap.
|
|
||||||
We will also have vLLM collaborators from Roblox coming up to the stage to discuss their experience in deploying LLMs with vLLM.
|
|
||||||
Please register [here](https://robloxandvllmmeetup2024.splashthat.com/) and join us!
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
*Latest News* 🔥
|
*Latest News* 🔥
|
||||||
|
- [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing).
|
||||||
- [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing).
|
- [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing).
|
||||||
- [2024/01] Added ROCm 6.0 support to vLLM.
|
- [2024/01] Added ROCm 6.0 support to vLLM.
|
||||||
- [2023/12] Added ROCm 5.7 support to vLLM.
|
- [2023/12] Added ROCm 5.7 support to vLLM.
|
||||||
@@ -79,16 +69,17 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
|
|||||||
- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
|
- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
|
||||||
- InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.)
|
- InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.)
|
||||||
- Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.)
|
- Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.)
|
||||||
- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
|
- LLaMA, Llama 2, and Meta Llama 3 (`meta-llama/Meta-Llama-3-8B-Instruct`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
|
||||||
|
- MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.)
|
||||||
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
|
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
|
||||||
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.)
|
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc.)
|
||||||
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
|
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
|
||||||
- OLMo (`allenai/OLMo-1B`, `allenai/OLMo-7B`, etc.)
|
- OLMo (`allenai/OLMo-1B`, `allenai/OLMo-7B`, etc.)
|
||||||
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
|
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
|
||||||
- Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.)
|
- Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.)
|
||||||
- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
|
- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
|
||||||
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
|
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
|
||||||
- Qwen2 (`Qwen/Qwen2-7B-beta`, `Qwen/Qwen-7B-Chat-beta`, etc.)
|
- Qwen2 (`Qwen/Qwen1.5-7B`, `Qwen/Qwen1.5-7B-Chat`, etc.)
|
||||||
- Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.)
|
- Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.)
|
||||||
- StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.)
|
- StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.)
|
||||||
- Starcoder2(`bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc.)
|
- Starcoder2(`bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc.)
|
||||||
|
|||||||
@@ -27,8 +27,8 @@ class RequestFuncInput:
|
|||||||
class RequestFuncOutput:
|
class RequestFuncOutput:
|
||||||
generated_text: str = ""
|
generated_text: str = ""
|
||||||
success: bool = False
|
success: bool = False
|
||||||
latency: float = 0
|
latency: float = 0.0
|
||||||
ttft: float = 0 # Time to first token
|
ttft: float = 0.0 # Time to first token
|
||||||
itl: List[float] = field(
|
itl: List[float] = field(
|
||||||
default_factory=list) # List of inter-token latencies
|
default_factory=list) # List of inter-token latencies
|
||||||
prompt_len: int = 0
|
prompt_len: int = 0
|
||||||
@@ -58,23 +58,24 @@ async def async_request_tgi(
|
|||||||
output = RequestFuncOutput()
|
output = RequestFuncOutput()
|
||||||
output.prompt_len = request_func_input.prompt_len
|
output.prompt_len = request_func_input.prompt_len
|
||||||
|
|
||||||
ttft = 0
|
ttft = 0.0
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
most_recent_timestamp = st
|
most_recent_timestamp = st
|
||||||
try:
|
try:
|
||||||
async with session.post(url=api_url, json=payload) as response:
|
async with session.post(url=api_url, json=payload) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
async for chunk in response.content:
|
async for chunk_bytes in response.content:
|
||||||
chunk = chunk.strip()
|
chunk_bytes = chunk_bytes.strip()
|
||||||
if not chunk:
|
if not chunk_bytes:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
chunk = remove_prefix(chunk.decode("utf-8"), "data:")
|
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
|
||||||
|
"data:")
|
||||||
|
|
||||||
data = json.loads(chunk)
|
data = json.loads(chunk)
|
||||||
timestamp = time.perf_counter()
|
timestamp = time.perf_counter()
|
||||||
# First token
|
# First token
|
||||||
if ttft == 0:
|
if ttft == 0.0:
|
||||||
ttft = time.perf_counter() - st
|
ttft = time.perf_counter() - st
|
||||||
output.ttft = ttft
|
output.ttft = ttft
|
||||||
|
|
||||||
@@ -119,23 +120,25 @@ async def async_request_trt_llm(
|
|||||||
output = RequestFuncOutput()
|
output = RequestFuncOutput()
|
||||||
output.prompt_len = request_func_input.prompt_len
|
output.prompt_len = request_func_input.prompt_len
|
||||||
|
|
||||||
ttft = 0
|
ttft = 0.0
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
most_recent_timestamp = st
|
most_recent_timestamp = st
|
||||||
try:
|
try:
|
||||||
async with session.post(url=api_url, json=payload) as response:
|
async with session.post(url=api_url, json=payload) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
async for chunk in response.content:
|
async for chunk_bytes in response.content:
|
||||||
chunk = chunk.strip()
|
chunk_bytes = chunk_bytes.strip()
|
||||||
if not chunk:
|
if not chunk_bytes:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
chunk = remove_prefix(chunk.decode("utf-8"), "data:")
|
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
|
||||||
|
"data:")
|
||||||
|
|
||||||
data = json.loads(chunk)
|
data = json.loads(chunk)
|
||||||
|
output.generated_text += data["text_output"]
|
||||||
timestamp = time.perf_counter()
|
timestamp = time.perf_counter()
|
||||||
# First token
|
# First token
|
||||||
if ttft == 0:
|
if ttft == 0.0:
|
||||||
ttft = time.perf_counter() - st
|
ttft = time.perf_counter() - st
|
||||||
output.ttft = ttft
|
output.ttft = ttft
|
||||||
|
|
||||||
@@ -147,11 +150,10 @@ async def async_request_trt_llm(
|
|||||||
most_recent_timestamp = timestamp
|
most_recent_timestamp = timestamp
|
||||||
|
|
||||||
output.latency = most_recent_timestamp - st
|
output.latency = most_recent_timestamp - st
|
||||||
output.generated_text = json.loads(data)["text_output"]
|
|
||||||
output.success = True
|
output.success = True
|
||||||
|
|
||||||
else:
|
else:
|
||||||
output.error = response.reason
|
output.error = response.reason or ""
|
||||||
output.success = False
|
output.success = False
|
||||||
except Exception:
|
except Exception:
|
||||||
output.success = False
|
output.success = False
|
||||||
@@ -195,7 +197,7 @@ async def async_request_deepspeed_mii(
|
|||||||
output.generated_text = parsed_resp["text"][0]
|
output.generated_text = parsed_resp["text"][0]
|
||||||
output.success = True
|
output.success = True
|
||||||
else:
|
else:
|
||||||
output.error = response.reason
|
output.error = response.reason or ""
|
||||||
output.success = False
|
output.success = False
|
||||||
except Exception:
|
except Exception:
|
||||||
output.success = False
|
output.success = False
|
||||||
@@ -234,19 +236,20 @@ async def async_request_openai_completions(
|
|||||||
output.prompt_len = request_func_input.prompt_len
|
output.prompt_len = request_func_input.prompt_len
|
||||||
|
|
||||||
generated_text = ""
|
generated_text = ""
|
||||||
ttft = 0
|
ttft = 0.0
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
most_recent_timestamp = st
|
most_recent_timestamp = st
|
||||||
try:
|
try:
|
||||||
async with session.post(url=api_url, json=payload,
|
async with session.post(url=api_url, json=payload,
|
||||||
headers=headers) as response:
|
headers=headers) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
async for chunk in response.content:
|
async for chunk_bytes in response.content:
|
||||||
chunk = chunk.strip()
|
chunk_bytes = chunk_bytes.strip()
|
||||||
if not chunk:
|
if not chunk_bytes:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
|
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
|
||||||
|
"data: ")
|
||||||
if chunk == "[DONE]":
|
if chunk == "[DONE]":
|
||||||
latency = time.perf_counter() - st
|
latency = time.perf_counter() - st
|
||||||
else:
|
else:
|
||||||
@@ -255,7 +258,7 @@ async def async_request_openai_completions(
|
|||||||
if data["choices"][0]["text"]:
|
if data["choices"][0]["text"]:
|
||||||
timestamp = time.perf_counter()
|
timestamp = time.perf_counter()
|
||||||
# First token
|
# First token
|
||||||
if ttft == 0:
|
if ttft == 0.0:
|
||||||
ttft = time.perf_counter() - st
|
ttft = time.perf_counter() - st
|
||||||
output.ttft = ttft
|
output.ttft = ttft
|
||||||
|
|
||||||
@@ -315,28 +318,30 @@ async def async_request_openai_chat_completions(
|
|||||||
output.prompt_len = request_func_input.prompt_len
|
output.prompt_len = request_func_input.prompt_len
|
||||||
|
|
||||||
generated_text = ""
|
generated_text = ""
|
||||||
ttft = 0
|
ttft = 0.0
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
most_recent_timestamp = st
|
most_recent_timestamp = st
|
||||||
try:
|
try:
|
||||||
async with session.post(url=api_url, json=payload,
|
async with session.post(url=api_url, json=payload,
|
||||||
headers=headers) as response:
|
headers=headers) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
async for chunk in response.content:
|
async for chunk_bytes in response.content:
|
||||||
chunk = chunk.strip()
|
chunk_bytes = chunk_bytes.strip()
|
||||||
if not chunk:
|
if not chunk_bytes:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
|
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
|
||||||
|
"data: ")
|
||||||
if chunk == "[DONE]":
|
if chunk == "[DONE]":
|
||||||
latency = time.perf_counter() - st
|
latency = time.perf_counter() - st
|
||||||
else:
|
else:
|
||||||
timestamp = time.perf_counter()
|
timestamp = time.perf_counter()
|
||||||
data = json.loads(chunk)
|
data = json.loads(chunk)
|
||||||
|
|
||||||
if "content" in data["choices"][0]["delta"]:
|
delta = data["choices"][0]["delta"]
|
||||||
|
if delta.get("content", None):
|
||||||
# First token
|
# First token
|
||||||
if ttft == 0:
|
if ttft == 0.0:
|
||||||
ttft = time.perf_counter() - st
|
ttft = time.perf_counter() - st
|
||||||
output.ttft = ttft
|
output.ttft = ttft
|
||||||
|
|
||||||
@@ -345,8 +350,7 @@ async def async_request_openai_chat_completions(
|
|||||||
output.itl.append(timestamp -
|
output.itl.append(timestamp -
|
||||||
most_recent_timestamp)
|
most_recent_timestamp)
|
||||||
|
|
||||||
generated_text += data["choices"][0]["delta"][
|
generated_text += delta["content"]
|
||||||
"content"]
|
|
||||||
|
|
||||||
most_recent_timestamp = timestamp
|
most_recent_timestamp = timestamp
|
||||||
|
|
||||||
@@ -354,7 +358,7 @@ async def async_request_openai_chat_completions(
|
|||||||
output.success = True
|
output.success = True
|
||||||
output.latency = latency
|
output.latency = latency
|
||||||
else:
|
else:
|
||||||
output.error = response.reason
|
output.error = response.reason or ""
|
||||||
output.success = False
|
output.success = False
|
||||||
except Exception:
|
except Exception:
|
||||||
output.success = False
|
output.success = False
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import torch
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||||
|
|
||||||
|
|
||||||
def main(args: argparse.Namespace):
|
def main(args: argparse.Namespace):
|
||||||
@@ -24,6 +25,7 @@ def main(args: argparse.Namespace):
|
|||||||
dtype=args.dtype,
|
dtype=args.dtype,
|
||||||
enforce_eager=args.enforce_eager,
|
enforce_eager=args.enforce_eager,
|
||||||
kv_cache_dtype=args.kv_cache_dtype,
|
kv_cache_dtype=args.kv_cache_dtype,
|
||||||
|
quantization_param_path=args.quantization_param_path,
|
||||||
device=args.device,
|
device=args.device,
|
||||||
ray_workers_use_nsight=args.ray_workers_use_nsight,
|
ray_workers_use_nsight=args.ray_workers_use_nsight,
|
||||||
enable_chunked_prefill=args.enable_chunked_prefill,
|
enable_chunked_prefill=args.enable_chunked_prefill,
|
||||||
@@ -67,7 +69,8 @@ def main(args: argparse.Namespace):
|
|||||||
return latency
|
return latency
|
||||||
|
|
||||||
print("Warming up...")
|
print("Warming up...")
|
||||||
run_to_completion(profile_dir=None)
|
for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
|
||||||
|
run_to_completion(profile_dir=None)
|
||||||
|
|
||||||
if args.profile:
|
if args.profile:
|
||||||
profile_dir = args.profile_result_dir
|
profile_dir = args.profile_result_dir
|
||||||
@@ -83,7 +86,12 @@ def main(args: argparse.Namespace):
|
|||||||
latencies = []
|
latencies = []
|
||||||
for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
|
for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
|
||||||
latencies.append(run_to_completion(profile_dir=None))
|
latencies.append(run_to_completion(profile_dir=None))
|
||||||
|
latencies = np.array(latencies)
|
||||||
|
percentages = [10, 25, 50, 75, 90]
|
||||||
|
percentiles = np.percentile(latencies, percentages)
|
||||||
print(f'Avg latency: {np.mean(latencies)} seconds')
|
print(f'Avg latency: {np.mean(latencies)} seconds')
|
||||||
|
for percentage, percentile in zip(percentages, percentiles):
|
||||||
|
print(f'{percentage}% percentile latency: {percentile} seconds')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@@ -94,7 +102,7 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('--tokenizer', type=str, default=None)
|
parser.add_argument('--tokenizer', type=str, default=None)
|
||||||
parser.add_argument('--quantization',
|
parser.add_argument('--quantization',
|
||||||
'-q',
|
'-q',
|
||||||
choices=['awq', 'gptq', 'squeezellm', None],
|
choices=[*QUANTIZATION_METHODS, None],
|
||||||
default=None)
|
default=None)
|
||||||
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
|
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
|
||||||
parser.add_argument('--input-len', type=int, default=32)
|
parser.add_argument('--input-len', type=int, default=32)
|
||||||
@@ -105,9 +113,13 @@ if __name__ == '__main__':
|
|||||||
default=1,
|
default=1,
|
||||||
help='Number of generated sequences per prompt.')
|
help='Number of generated sequences per prompt.')
|
||||||
parser.add_argument('--use-beam-search', action='store_true')
|
parser.add_argument('--use-beam-search', action='store_true')
|
||||||
|
parser.add_argument('--num-iters-warmup',
|
||||||
|
type=int,
|
||||||
|
default=10,
|
||||||
|
help='Number of iterations to run for warmup.')
|
||||||
parser.add_argument('--num-iters',
|
parser.add_argument('--num-iters',
|
||||||
type=int,
|
type=int,
|
||||||
default=3,
|
default=30,
|
||||||
help='Number of iterations to run.')
|
help='Number of iterations to run.')
|
||||||
parser.add_argument('--trust-remote-code',
|
parser.add_argument('--trust-remote-code',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
@@ -127,10 +139,23 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--kv-cache-dtype",
|
"--kv-cache-dtype",
|
||||||
type=str,
|
type=str,
|
||||||
choices=['auto', 'fp8_e5m2'],
|
choices=['auto', 'fp8'],
|
||||||
default='auto',
|
default='auto',
|
||||||
help=
|
help=
|
||||||
'Data type for kv cache storage. If "auto", will use model data type.')
|
'Data type for kv cache storage. If "auto", will use model data type. '
|
||||||
|
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
|
||||||
|
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
|
||||||
|
'common inference criteria.')
|
||||||
|
parser.add_argument(
|
||||||
|
'--quantization-param-path',
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help='Path to the JSON file containing the KV cache scaling factors. '
|
||||||
|
'This should generally be supplied, when KV cache dtype is FP8. '
|
||||||
|
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
|
||||||
|
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
|
||||||
|
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
|
||||||
|
'instead supported for common inference criteria.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--profile',
|
'--profile',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
@@ -145,16 +170,15 @@ if __name__ == '__main__':
|
|||||||
"--device",
|
"--device",
|
||||||
type=str,
|
type=str,
|
||||||
default="cuda",
|
default="cuda",
|
||||||
choices=["cuda"],
|
choices=["cuda", "cpu"],
|
||||||
help='device type for vLLM execution, supporting CUDA only currently.')
|
help='device type for vLLM execution, supporting CUDA and CPU.')
|
||||||
parser.add_argument('--block-size',
|
parser.add_argument('--block-size',
|
||||||
type=int,
|
type=int,
|
||||||
default=16,
|
default=16,
|
||||||
help='block size of key/value cache')
|
help='block size of key/value cache')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--enable-chunked-prefill',
|
'--enable-chunked-prefill',
|
||||||
type=bool,
|
action='store_true',
|
||||||
default=False,
|
|
||||||
help='If True, the prefill requests can be chunked based on the '
|
help='If True, the prefill requests can be chunked based on the '
|
||||||
'max_num_batched_tokens')
|
'max_num_batched_tokens')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
@@ -110,7 +110,9 @@ def sample_sonnet_requests(
|
|||||||
prefix_len: int,
|
prefix_len: int,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
) -> List[Tuple[str, str, int, int]]:
|
) -> List[Tuple[str, str, int, int]]:
|
||||||
assert input_len > prefix_len, "input_len must be greater than prefix_len."
|
assert (
|
||||||
|
input_len > prefix_len
|
||||||
|
), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'."
|
||||||
|
|
||||||
# Load the dataset.
|
# Load the dataset.
|
||||||
with open(dataset_path) as f:
|
with open(dataset_path) as f:
|
||||||
@@ -131,8 +133,9 @@ def sample_sonnet_requests(
|
|||||||
base_message, add_generation_prompt=True, tokenize=False)
|
base_message, add_generation_prompt=True, tokenize=False)
|
||||||
base_prompt_offset = len(tokenizer(base_prompt_formatted).input_ids)
|
base_prompt_offset = len(tokenizer(base_prompt_formatted).input_ids)
|
||||||
|
|
||||||
assert (input_len > base_prompt_offset
|
assert (
|
||||||
), f"Please set 'args.input-len' higher than {base_prompt_offset}."
|
input_len > base_prompt_offset
|
||||||
|
), f"Please set 'args.sonnet-input-len' higher than {base_prompt_offset}."
|
||||||
num_input_lines = round(
|
num_input_lines = round(
|
||||||
(input_len - base_prompt_offset) / average_poem_len)
|
(input_len - base_prompt_offset) / average_poem_len)
|
||||||
|
|
||||||
@@ -140,7 +143,7 @@ def sample_sonnet_requests(
|
|||||||
# prompt are fixed poem lines.
|
# prompt are fixed poem lines.
|
||||||
assert (
|
assert (
|
||||||
prefix_len > base_prompt_offset
|
prefix_len > base_prompt_offset
|
||||||
), f"Please set 'args.prefix-len' higher than {base_prompt_offset}."
|
), f"Please set 'args.sonnet-prefix-len' higher than {base_prompt_offset}."
|
||||||
|
|
||||||
num_prefix_lines = round(
|
num_prefix_lines = round(
|
||||||
(prefix_len - base_prompt_offset) / average_poem_len)
|
(prefix_len - base_prompt_offset) / average_poem_len)
|
||||||
@@ -373,9 +376,9 @@ def main(args: argparse.Namespace):
|
|||||||
input_requests = sample_sonnet_requests(
|
input_requests = sample_sonnet_requests(
|
||||||
dataset_path=args.dataset_path,
|
dataset_path=args.dataset_path,
|
||||||
num_requests=args.num_prompts,
|
num_requests=args.num_prompts,
|
||||||
input_len=args.input_len,
|
input_len=args.sonnet_input_len,
|
||||||
output_len=args.output_len,
|
output_len=args.sonnet_output_len,
|
||||||
prefix_len=args.prefix_len,
|
prefix_len=args.sonnet_prefix_len,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
)
|
)
|
||||||
input_requests = [(prompt, prompt_len, output_len)
|
input_requests = [(prompt, prompt_len, output_len)
|
||||||
@@ -388,9 +391,9 @@ def main(args: argparse.Namespace):
|
|||||||
input_requests = sample_sonnet_requests(
|
input_requests = sample_sonnet_requests(
|
||||||
dataset_path=args.dataset_path,
|
dataset_path=args.dataset_path,
|
||||||
num_requests=args.num_prompts,
|
num_requests=args.num_prompts,
|
||||||
input_len=args.input_len,
|
input_len=args.sonnet_input_len,
|
||||||
output_len=args.output_len,
|
output_len=args.sonnet_output_len,
|
||||||
prefix_len=args.prefix_len,
|
prefix_len=args.sonnet_prefix_len,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
)
|
)
|
||||||
input_requests = [(prompt_formatted, prompt_len, output_len)
|
input_requests = [(prompt_formatted, prompt_len, output_len)
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ from tqdm import tqdm
|
|||||||
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||||
PreTrainedTokenizerBase)
|
PreTrainedTokenizerBase)
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||||
|
|
||||||
|
|
||||||
def sample_requests(
|
def sample_requests(
|
||||||
dataset_path: str,
|
dataset_path: str,
|
||||||
@@ -29,22 +31,23 @@ def sample_requests(
|
|||||||
dataset = [(data["conversations"][0]["value"],
|
dataset = [(data["conversations"][0]["value"],
|
||||||
data["conversations"][1]["value"]) for data in dataset]
|
data["conversations"][1]["value"]) for data in dataset]
|
||||||
|
|
||||||
# Tokenize the prompts and completions.
|
# Shuffle the dataset.
|
||||||
prompts = [prompt for prompt, _ in dataset]
|
random.shuffle(dataset)
|
||||||
prompt_token_ids = tokenizer(prompts).input_ids
|
|
||||||
completions = [completion for _, completion in dataset]
|
|
||||||
completion_token_ids = tokenizer(completions).input_ids
|
|
||||||
tokenized_dataset = []
|
|
||||||
for i in range(len(dataset)):
|
|
||||||
output_len = len(completion_token_ids[i])
|
|
||||||
if fixed_output_len is not None:
|
|
||||||
output_len = fixed_output_len
|
|
||||||
tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
|
|
||||||
|
|
||||||
# Filter out too long sequences.
|
# Filter out sequences that are too long or too short
|
||||||
filtered_dataset: List[Tuple[str, int, int]] = []
|
filtered_dataset: List[Tuple[str, int, int]] = []
|
||||||
for prompt, prompt_token_ids, output_len in tokenized_dataset:
|
for i in range(len(dataset)):
|
||||||
|
if len(filtered_dataset) == num_requests:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Tokenize the prompts and completions.
|
||||||
|
prompt = dataset[i][0]
|
||||||
|
prompt_token_ids = tokenizer(prompt).input_ids
|
||||||
|
completion = dataset[i][1]
|
||||||
|
completion_token_ids = tokenizer(completion).input_ids
|
||||||
prompt_len = len(prompt_token_ids)
|
prompt_len = len(prompt_token_ids)
|
||||||
|
output_len = len(completion_token_ids
|
||||||
|
) if fixed_output_len is None else fixed_output_len
|
||||||
if prompt_len < 4 or output_len < 4:
|
if prompt_len < 4 or output_len < 4:
|
||||||
# Prune too short sequences.
|
# Prune too short sequences.
|
||||||
continue
|
continue
|
||||||
@@ -53,9 +56,7 @@ def sample_requests(
|
|||||||
continue
|
continue
|
||||||
filtered_dataset.append((prompt, prompt_len, output_len))
|
filtered_dataset.append((prompt, prompt_len, output_len))
|
||||||
|
|
||||||
# Sample the requests.
|
return filtered_dataset
|
||||||
sampled_requests = random.sample(filtered_dataset, num_requests)
|
|
||||||
return sampled_requests
|
|
||||||
|
|
||||||
|
|
||||||
def run_vllm(
|
def run_vllm(
|
||||||
@@ -72,26 +73,34 @@ def run_vllm(
|
|||||||
max_model_len: Optional[int],
|
max_model_len: Optional[int],
|
||||||
enforce_eager: bool,
|
enforce_eager: bool,
|
||||||
kv_cache_dtype: str,
|
kv_cache_dtype: str,
|
||||||
|
quantization_param_path: Optional[str],
|
||||||
device: str,
|
device: str,
|
||||||
enable_prefix_caching: bool,
|
enable_prefix_caching: bool,
|
||||||
|
enable_chunked_prefill: bool,
|
||||||
|
max_num_batched_tokens: int,
|
||||||
gpu_memory_utilization: float = 0.9,
|
gpu_memory_utilization: float = 0.9,
|
||||||
download_dir: Optional[str] = None,
|
download_dir: Optional[str] = None,
|
||||||
) -> float:
|
) -> float:
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
llm = LLM(model=model,
|
llm = LLM(
|
||||||
tokenizer=tokenizer,
|
model=model,
|
||||||
quantization=quantization,
|
tokenizer=tokenizer,
|
||||||
tensor_parallel_size=tensor_parallel_size,
|
quantization=quantization,
|
||||||
seed=seed,
|
tensor_parallel_size=tensor_parallel_size,
|
||||||
trust_remote_code=trust_remote_code,
|
seed=seed,
|
||||||
dtype=dtype,
|
trust_remote_code=trust_remote_code,
|
||||||
max_model_len=max_model_len,
|
dtype=dtype,
|
||||||
gpu_memory_utilization=gpu_memory_utilization,
|
max_model_len=max_model_len,
|
||||||
enforce_eager=enforce_eager,
|
gpu_memory_utilization=gpu_memory_utilization,
|
||||||
kv_cache_dtype=kv_cache_dtype,
|
enforce_eager=enforce_eager,
|
||||||
device=device,
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
enable_prefix_caching=enable_prefix_caching,
|
quantization_param_path=quantization_param_path,
|
||||||
download_dir=download_dir)
|
device=device,
|
||||||
|
enable_prefix_caching=enable_prefix_caching,
|
||||||
|
download_dir=download_dir,
|
||||||
|
enable_chunked_prefill=enable_chunked_prefill,
|
||||||
|
max_num_batched_tokens=max_num_batched_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
# Add the requests to the engine.
|
# Add the requests to the engine.
|
||||||
for prompt, _, output_len in requests:
|
for prompt, _, output_len in requests:
|
||||||
@@ -212,14 +221,15 @@ def main(args: argparse.Namespace):
|
|||||||
args.output_len)
|
args.output_len)
|
||||||
|
|
||||||
if args.backend == "vllm":
|
if args.backend == "vllm":
|
||||||
elapsed_time = run_vllm(requests, args.model, args.tokenizer,
|
elapsed_time = run_vllm(
|
||||||
args.quantization, args.tensor_parallel_size,
|
requests, args.model, args.tokenizer, args.quantization,
|
||||||
args.seed, args.n, args.use_beam_search,
|
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
|
||||||
args.trust_remote_code, args.dtype,
|
args.trust_remote_code, args.dtype, args.max_model_len,
|
||||||
args.max_model_len, args.enforce_eager,
|
args.enforce_eager, args.kv_cache_dtype,
|
||||||
args.kv_cache_dtype, args.device,
|
args.quantization_param_path, args.device,
|
||||||
args.enable_prefix_caching,
|
args.enable_prefix_caching, args.enable_chunked_prefill,
|
||||||
args.gpu_memory_utilization, args.download_dir)
|
args.max_num_batched_tokens, args.gpu_memory_utilization,
|
||||||
|
args.download_dir)
|
||||||
elif args.backend == "hf":
|
elif args.backend == "hf":
|
||||||
assert args.tensor_parallel_size == 1
|
assert args.tensor_parallel_size == 1
|
||||||
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
||||||
@@ -259,7 +269,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--tokenizer", type=str, default=None)
|
parser.add_argument("--tokenizer", type=str, default=None)
|
||||||
parser.add_argument('--quantization',
|
parser.add_argument('--quantization',
|
||||||
'-q',
|
'-q',
|
||||||
choices=['awq', 'gptq', 'squeezellm', None],
|
choices=[*QUANTIZATION_METHODS, None],
|
||||||
default=None)
|
default=None)
|
||||||
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
|
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
|
||||||
parser.add_argument("--n",
|
parser.add_argument("--n",
|
||||||
@@ -306,20 +316,41 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--kv-cache-dtype",
|
"--kv-cache-dtype",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["auto", "fp8_e5m2"],
|
choices=["auto", "fp8"],
|
||||||
default="auto",
|
default="auto",
|
||||||
help=
|
help=
|
||||||
'Data type for kv cache storage. If "auto", will use model data type.')
|
'Data type for kv cache storage. If "auto", will use model data type. '
|
||||||
|
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
|
||||||
|
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
|
||||||
|
'common inference criteria.')
|
||||||
|
parser.add_argument(
|
||||||
|
'--quantization-param-path',
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help='Path to the JSON file containing the KV cache scaling factors. '
|
||||||
|
'This should generally be supplied, when KV cache dtype is FP8. '
|
||||||
|
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
|
||||||
|
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
|
||||||
|
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
|
||||||
|
'instead supported for common inference criteria.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--device",
|
"--device",
|
||||||
type=str,
|
type=str,
|
||||||
default="cuda",
|
default="cuda",
|
||||||
choices=["cuda"],
|
choices=["cuda", "cpu"],
|
||||||
help='device type for vLLM execution, supporting CUDA only currently.')
|
help='device type for vLLM execution, supporting CUDA and CPU.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--enable-prefix-caching",
|
"--enable-prefix-caching",
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help="enable automatic prefix caching for vLLM backend.")
|
help="enable automatic prefix caching for vLLM backend.")
|
||||||
|
parser.add_argument("--enable-chunked-prefill",
|
||||||
|
action='store_true',
|
||||||
|
help="enable chunked prefill for vLLM backend.")
|
||||||
|
parser.add_argument('--max-num-batched-tokens',
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help='maximum number of batched tokens per '
|
||||||
|
'iteration')
|
||||||
parser.add_argument('--download-dir',
|
parser.add_argument('--download-dir',
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
|
|||||||
302
benchmarks/kernels/benchmark_aqlm.py
Normal file
302
benchmarks/kernels/benchmark_aqlm.py
Normal file
@@ -0,0 +1,302 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from vllm._C import ops
|
||||||
|
from vllm.model_executor.layers.quantization.aqlm import (
|
||||||
|
dequantize_weight, generic_dequantize_gemm, get_int_dtype,
|
||||||
|
optimized_dequantize_gemm)
|
||||||
|
|
||||||
|
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
||||||
|
|
||||||
|
|
||||||
|
def torch_mult(
|
||||||
|
input: torch.Tensor, # [..., in_features]
|
||||||
|
weights: torch.Tensor,
|
||||||
|
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
|
||||||
|
) -> torch.Tensor:
|
||||||
|
output = F.linear(input, weights)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def dequant_out_scale(
|
||||||
|
input: torch.Tensor, # [..., in_features]
|
||||||
|
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
|
||||||
|
codebooks: torch.
|
||||||
|
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
|
||||||
|
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
|
||||||
|
output_partition_sizes: torch.IntTensor,
|
||||||
|
bias: Optional[torch.Tensor],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
|
||||||
|
|
||||||
|
if bias is None:
|
||||||
|
output = F.linear(input, weights, bias)
|
||||||
|
orig_shape = output.shape
|
||||||
|
flattened_output = output.view(-1, output.size(-1))
|
||||||
|
f_scales = scales.view(-1, scales.shape[0])
|
||||||
|
b_scales = f_scales.expand(flattened_output.shape[0], -1)
|
||||||
|
flattened_output *= b_scales
|
||||||
|
return flattened_output.view(orig_shape)
|
||||||
|
else:
|
||||||
|
b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
|
||||||
|
-1, weights.shape[1])
|
||||||
|
weights *= b_scales
|
||||||
|
return F.linear(input, weights, bias)
|
||||||
|
|
||||||
|
|
||||||
|
def dequant_weight_scale(
|
||||||
|
input: torch.Tensor, # [..., in_features]
|
||||||
|
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
|
||||||
|
codebooks: torch.
|
||||||
|
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
|
||||||
|
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
|
||||||
|
output_partition_sizes: torch.IntTensor,
|
||||||
|
bias: Optional[torch.Tensor],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
|
||||||
|
|
||||||
|
b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
|
||||||
|
-1, weights.shape[1])
|
||||||
|
weights *= b_scales
|
||||||
|
return F.linear(input, weights, bias)
|
||||||
|
|
||||||
|
|
||||||
|
def dequant_no_scale(
|
||||||
|
input: torch.Tensor, # [..., in_features]
|
||||||
|
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
|
||||||
|
codebooks: torch.
|
||||||
|
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
|
||||||
|
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
|
||||||
|
output_partition_sizes: torch.IntTensor,
|
||||||
|
bias: Optional[torch.Tensor],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
|
||||||
|
|
||||||
|
return F.linear(input, weights, bias)
|
||||||
|
|
||||||
|
|
||||||
|
# Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against
|
||||||
|
# the generic pytorch version.
|
||||||
|
# Just visual comparison.
|
||||||
|
def dequant_test(k: int, parts: torch.tensor, nbooks: int, bits: int) -> None:
|
||||||
|
|
||||||
|
n = parts.sum().item()
|
||||||
|
|
||||||
|
device = torch.device('cuda:0')
|
||||||
|
|
||||||
|
code_range = (1 << bits) // 2
|
||||||
|
ingroups = 8
|
||||||
|
|
||||||
|
codes = torch.randint(-code_range,
|
||||||
|
code_range,
|
||||||
|
size=(n, k // ingroups, nbooks),
|
||||||
|
dtype=get_int_dtype(bits),
|
||||||
|
device=device)
|
||||||
|
|
||||||
|
codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
|
||||||
|
dtype=torch.float16,
|
||||||
|
device=device)
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
for index in range(16):
|
||||||
|
for i in range(8):
|
||||||
|
for book in range(nbooks):
|
||||||
|
codebooks[book, index, 0, i] = count * (10**book)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
print("codes shape", codes.shape)
|
||||||
|
|
||||||
|
for i in range(16):
|
||||||
|
for book in range(nbooks):
|
||||||
|
codes[0, i, book] = i
|
||||||
|
codes[0, -i, book] = i
|
||||||
|
|
||||||
|
weights = dequantize_weight(codes, codebooks, None)
|
||||||
|
weights2 = ops.aqlm_dequant(codes, codebooks, parts)
|
||||||
|
|
||||||
|
print("weights shape:", weights.shape)
|
||||||
|
print("weights2 shape:", weights2.shape)
|
||||||
|
|
||||||
|
print("weights are:", weights)
|
||||||
|
print("weights2 are:", weights2)
|
||||||
|
|
||||||
|
print("first 128 weights are", weights[0, 0:128].to(torch.int32))
|
||||||
|
print("first 128 weights2 are:", weights2[0, 0:128].to(torch.int32))
|
||||||
|
|
||||||
|
print("last 128 weights are", weights[0, -128:])
|
||||||
|
print("last 128 weights2 are:", weights2[0, -128:])
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="Benchmark aqlm performance.")
|
||||||
|
|
||||||
|
# Add arguments
|
||||||
|
parser.add_argument("--nbooks",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of codebooks (default: 1)")
|
||||||
|
parser.add_argument("--bits",
|
||||||
|
type=int,
|
||||||
|
default=16,
|
||||||
|
help="Number of bits per code element (default: 16)")
|
||||||
|
parser.add_argument(
|
||||||
|
"--test",
|
||||||
|
type=bool,
|
||||||
|
default=False,
|
||||||
|
help="Run the decompression/dequant tester rather than benchmarking "
|
||||||
|
"(default: False)")
|
||||||
|
|
||||||
|
# Parse the arguments
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Extract values
|
||||||
|
nbooks = args.nbooks
|
||||||
|
bits = args.bits
|
||||||
|
|
||||||
|
if args.test:
|
||||||
|
dequant_test(4096, torch.tensor((4096, )), nbooks, bits)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Otherwise, benchmark.
|
||||||
|
methods = [
|
||||||
|
ops.aqlm_gemm,
|
||||||
|
dequant_out_scale,
|
||||||
|
generic_dequantize_gemm,
|
||||||
|
optimized_dequantize_gemm,
|
||||||
|
dequant_weight_scale,
|
||||||
|
torch_mult,
|
||||||
|
dequant_no_scale,
|
||||||
|
]
|
||||||
|
|
||||||
|
filename = f"./aqlm_benchmark_{nbooks}x{bits}.csv"
|
||||||
|
print(f"writing benchmarks to file {filename}")
|
||||||
|
with open(filename, "w") as f:
|
||||||
|
sys.stdout = f
|
||||||
|
|
||||||
|
print('m | k | n | n parts', end='')
|
||||||
|
for method in methods:
|
||||||
|
print(f" | {method.__name__.replace('_', ' ')} (µs)", end='')
|
||||||
|
print('')
|
||||||
|
|
||||||
|
# These are reasonable prefill sizes.
|
||||||
|
ksandpartions = ((4096, (4096, 4096, 4096)), (4096, (4096, )),
|
||||||
|
(4096, (11008, 11008)), (11008, (4096, )))
|
||||||
|
|
||||||
|
# reasonable ranges for m.
|
||||||
|
for m in [
|
||||||
|
1, 2, 4, 8, 10, 12, 14, 16, 24, 32, 48, 52, 56, 64, 96, 112,
|
||||||
|
128, 256, 512, 1024, 1536, 2048, 3072, 4096
|
||||||
|
]:
|
||||||
|
print(f'{m}', file=sys.__stdout__)
|
||||||
|
for ksp in ksandpartions:
|
||||||
|
run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits,
|
||||||
|
methods)
|
||||||
|
|
||||||
|
sys.stdout = sys.__stdout__
|
||||||
|
|
||||||
|
|
||||||
|
def run_grid(m: int, k: int, parts: torch.tensor, nbooks: int, bits: int,
|
||||||
|
methods):
|
||||||
|
|
||||||
|
# I didn't see visible improvements from increasing these, but feel free :)
|
||||||
|
num_warmup_trials = 1
|
||||||
|
num_trials = 1
|
||||||
|
|
||||||
|
num_calls = 100
|
||||||
|
|
||||||
|
# warmup.
|
||||||
|
for method in methods:
|
||||||
|
for _ in range(num_warmup_trials):
|
||||||
|
run_timing(
|
||||||
|
num_calls=num_calls,
|
||||||
|
m=m,
|
||||||
|
k=k,
|
||||||
|
parts=parts,
|
||||||
|
nbooks=nbooks,
|
||||||
|
bits=bits,
|
||||||
|
method=method,
|
||||||
|
)
|
||||||
|
|
||||||
|
n = parts.sum().item()
|
||||||
|
print(f'{m} | {k} | {n} | {parts.tolist()}', end='')
|
||||||
|
|
||||||
|
for method in methods:
|
||||||
|
best_time_us = 1e20
|
||||||
|
for _ in range(num_trials):
|
||||||
|
kernel_dur_ms = run_timing(
|
||||||
|
num_calls=num_calls,
|
||||||
|
m=m,
|
||||||
|
k=k,
|
||||||
|
parts=parts,
|
||||||
|
nbooks=nbooks,
|
||||||
|
bits=bits,
|
||||||
|
method=method,
|
||||||
|
)
|
||||||
|
|
||||||
|
kernel_dur_us = 1000 * kernel_dur_ms
|
||||||
|
|
||||||
|
if kernel_dur_us < best_time_us:
|
||||||
|
best_time_us = kernel_dur_us
|
||||||
|
|
||||||
|
print(f' | {kernel_dur_us:.0f}', end='')
|
||||||
|
|
||||||
|
print('')
|
||||||
|
|
||||||
|
|
||||||
|
def run_timing(num_calls: int, m: int, k: int, parts: torch.tensor,
|
||||||
|
nbooks: int, bits: int, method) -> float:
|
||||||
|
|
||||||
|
n = parts.sum().item()
|
||||||
|
|
||||||
|
device = torch.device('cuda:0')
|
||||||
|
|
||||||
|
input = torch.randn((1, m, k), dtype=torch.float16, device=device)
|
||||||
|
|
||||||
|
code_range = (1 << bits) // 2
|
||||||
|
ingroups = 8
|
||||||
|
|
||||||
|
codes = torch.randint(-code_range,
|
||||||
|
code_range,
|
||||||
|
size=(n, k // ingroups, nbooks),
|
||||||
|
dtype=get_int_dtype(bits),
|
||||||
|
device=device)
|
||||||
|
|
||||||
|
codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
|
||||||
|
dtype=torch.float16,
|
||||||
|
device=device)
|
||||||
|
|
||||||
|
scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device)
|
||||||
|
|
||||||
|
# for comparison to just a pytorch mult.
|
||||||
|
weights = torch.randn((n, k), dtype=torch.float16, device=device)
|
||||||
|
|
||||||
|
start_event = torch.cuda.Event(enable_timing=True)
|
||||||
|
end_event = torch.cuda.Event(enable_timing=True)
|
||||||
|
|
||||||
|
start_event.record()
|
||||||
|
|
||||||
|
if method is torch_mult:
|
||||||
|
for i in range(num_calls):
|
||||||
|
torch_mult(input, weights, scales)
|
||||||
|
else:
|
||||||
|
for i in range(num_calls):
|
||||||
|
method(input, codes, codebooks, scales, parts, None)
|
||||||
|
|
||||||
|
end_event.record()
|
||||||
|
end_event.synchronize()
|
||||||
|
|
||||||
|
dur_ms = start_event.elapsed_time(end_event) / num_calls
|
||||||
|
return dur_ms
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
@@ -5,7 +5,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm._C import ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
|
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
|
||||||
|
|
||||||
NUM_BLOCKS = 1024
|
NUM_BLOCKS = 1024
|
||||||
@@ -97,6 +97,9 @@ def main(
|
|||||||
torch.cuda.cudart().cudaProfilerStart()
|
torch.cuda.cudart().cudaProfilerStart()
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
# Using default kv_scale
|
||||||
|
kv_scale = 1.0
|
||||||
|
|
||||||
for _ in range(num_iters):
|
for _ in range(num_iters):
|
||||||
if version == "v1":
|
if version == "v1":
|
||||||
ops.paged_attention_v1(
|
ops.paged_attention_v1(
|
||||||
@@ -112,6 +115,7 @@ def main(
|
|||||||
max_context_len,
|
max_context_len,
|
||||||
alibi_slopes,
|
alibi_slopes,
|
||||||
kv_cache_dtype,
|
kv_cache_dtype,
|
||||||
|
kv_scale,
|
||||||
)
|
)
|
||||||
elif version == "v2":
|
elif version == "v2":
|
||||||
ops.paged_attention_v2(
|
ops.paged_attention_v2(
|
||||||
@@ -130,6 +134,7 @@ def main(
|
|||||||
max_context_len,
|
max_context_len,
|
||||||
alibi_slopes,
|
alibi_slopes,
|
||||||
kv_cache_dtype,
|
kv_cache_dtype,
|
||||||
|
kv_scale,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid version: {version}")
|
raise ValueError(f"Invalid version: {version}")
|
||||||
@@ -179,11 +184,13 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--kv-cache-dtype",
|
"--kv-cache-dtype",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["auto", "fp8_e5m2"],
|
choices=["auto", "fp8"],
|
||||||
default="auto",
|
default="auto",
|
||||||
help=
|
help=
|
||||||
'Data type for kv cache storage. If "auto", will use model data type.')
|
'Data type for kv cache storage. If "auto", will use model data type. '
|
||||||
parser.add_argument("--device", type=str, choices=["cuda"], default="cuda")
|
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
|
||||||
|
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
|
||||||
|
'common inference criteria.')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
print(args)
|
print(args)
|
||||||
|
|
||||||
|
|||||||
90
cmake/cpu_extension.cmake
Normal file
90
cmake/cpu_extension.cmake
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||||
|
|
||||||
|
#
|
||||||
|
# Define environment variables for special configurations
|
||||||
|
#
|
||||||
|
if(DEFINED ENV{VLLM_CPU_AVX512BF16})
|
||||||
|
set(ENABLE_AVX512BF16 ON)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
include_directories("${CMAKE_SOURCE_DIR}/csrc")
|
||||||
|
|
||||||
|
#
|
||||||
|
# Check the compile flags
|
||||||
|
#
|
||||||
|
list(APPEND CXX_COMPILE_FLAGS
|
||||||
|
"-fopenmp"
|
||||||
|
"-DVLLM_CPU_EXTENSION")
|
||||||
|
|
||||||
|
execute_process(COMMAND cat /proc/cpuinfo
|
||||||
|
RESULT_VARIABLE CPUINFO_RET
|
||||||
|
OUTPUT_VARIABLE CPUINFO)
|
||||||
|
|
||||||
|
if (NOT CPUINFO_RET EQUAL 0)
|
||||||
|
message(FATAL_ERROR "Failed to check CPU features via /proc/cpuinfo")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
function (find_isa CPUINFO TARGET OUT)
|
||||||
|
string(FIND ${CPUINFO} ${TARGET} ISA_FOUND)
|
||||||
|
if(NOT ISA_FOUND EQUAL -1)
|
||||||
|
set(${OUT} ON PARENT_SCOPE)
|
||||||
|
else()
|
||||||
|
set(${OUT} OFF PARENT_SCOPE)
|
||||||
|
endif()
|
||||||
|
endfunction()
|
||||||
|
|
||||||
|
find_isa(${CPUINFO} "avx512f" AVX512_FOUND)
|
||||||
|
|
||||||
|
if (AVX512_FOUND)
|
||||||
|
list(APPEND CXX_COMPILE_FLAGS
|
||||||
|
"-mavx512f"
|
||||||
|
"-mavx512vl"
|
||||||
|
"-mavx512bw"
|
||||||
|
"-mavx512dq")
|
||||||
|
|
||||||
|
find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND)
|
||||||
|
if (AVX512BF16_FOUND OR ENABLE_AVX512BF16)
|
||||||
|
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
|
||||||
|
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
|
||||||
|
list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16")
|
||||||
|
else()
|
||||||
|
message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3")
|
||||||
|
endif()
|
||||||
|
else()
|
||||||
|
message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.")
|
||||||
|
endif()
|
||||||
|
else()
|
||||||
|
message(FATAL_ERROR "vLLM CPU backend requires AVX512 ISA support.")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
# Define extension targets
|
||||||
|
#
|
||||||
|
|
||||||
|
#
|
||||||
|
# _C extension
|
||||||
|
#
|
||||||
|
set(VLLM_EXT_SRC
|
||||||
|
"csrc/cpu/activation.cpp"
|
||||||
|
"csrc/cpu/attention.cpp"
|
||||||
|
"csrc/cpu/cache.cpp"
|
||||||
|
"csrc/cpu/layernorm.cpp"
|
||||||
|
"csrc/cpu/pos_encoding.cpp"
|
||||||
|
"csrc/cpu/pybind.cpp")
|
||||||
|
|
||||||
|
define_gpu_extension_target(
|
||||||
|
_C
|
||||||
|
DESTINATION vllm
|
||||||
|
LANGUAGE CXX
|
||||||
|
SOURCES ${VLLM_EXT_SRC}
|
||||||
|
COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
|
||||||
|
WITH_SOABI
|
||||||
|
)
|
||||||
|
|
||||||
|
add_custom_target(default)
|
||||||
|
message(STATUS "Enabling C extension.")
|
||||||
|
add_dependencies(default _C)
|
||||||
|
|
||||||
@@ -101,6 +101,13 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
|
|||||||
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
|
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
|
||||||
list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2")
|
list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2")
|
||||||
endif()
|
endif()
|
||||||
|
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
|
||||||
|
list(REMOVE_ITEM GPU_FLAGS
|
||||||
|
"-D__CUDA_NO_HALF_OPERATORS__"
|
||||||
|
"-D__CUDA_NO_HALF_CONVERSIONS__"
|
||||||
|
"-D__CUDA_NO_BFLOAT16_CONVERSIONS__"
|
||||||
|
"-D__CUDA_NO_HALF2_OPERATORS__")
|
||||||
|
endif()
|
||||||
|
|
||||||
elseif(${GPU_LANG} STREQUAL "HIP")
|
elseif(${GPU_LANG} STREQUAL "HIP")
|
||||||
#
|
#
|
||||||
@@ -112,6 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
|
|||||||
|
|
||||||
list(APPEND GPU_FLAGS
|
list(APPEND GPU_FLAGS
|
||||||
"-DUSE_ROCM"
|
"-DUSE_ROCM"
|
||||||
|
"-DENABLE_FP8_E4M3"
|
||||||
"-U__HIP_NO_HALF_CONVERSIONS__"
|
"-U__HIP_NO_HALF_CONVERSIONS__"
|
||||||
"-U__HIP_NO_HALF_OPERATORS__"
|
"-U__HIP_NO_HALF_OPERATORS__"
|
||||||
"-fno-gpu-rdc")
|
"-fno-gpu-rdc")
|
||||||
|
|||||||
@@ -63,6 +63,7 @@ DEFAULT_CONDA_PATTERNS = {
|
|||||||
"magma",
|
"magma",
|
||||||
"triton",
|
"triton",
|
||||||
"optree",
|
"optree",
|
||||||
|
"nccl",
|
||||||
}
|
}
|
||||||
|
|
||||||
DEFAULT_PIP_PATTERNS = {
|
DEFAULT_PIP_PATTERNS = {
|
||||||
@@ -73,6 +74,7 @@ DEFAULT_PIP_PATTERNS = {
|
|||||||
"triton",
|
"triton",
|
||||||
"optree",
|
"optree",
|
||||||
"onnx",
|
"onnx",
|
||||||
|
"nccl",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,4 +4,4 @@
|
|||||||
#include "dtype_float16.cuh"
|
#include "dtype_float16.cuh"
|
||||||
#include "dtype_float32.cuh"
|
#include "dtype_float32.cuh"
|
||||||
#include "dtype_bfloat16.cuh"
|
#include "dtype_bfloat16.cuh"
|
||||||
#include "dtype_fp8_e5m2.cuh"
|
#include "dtype_fp8.cuh"
|
||||||
|
|||||||
@@ -22,12 +22,26 @@
|
|||||||
|
|
||||||
#include "attention_dtypes.h"
|
#include "attention_dtypes.h"
|
||||||
#include "attention_utils.cuh"
|
#include "attention_utils.cuh"
|
||||||
#ifdef ENABLE_FP8_E5M2
|
|
||||||
|
#if defined(ENABLE_FP8_E5M2)
|
||||||
#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
|
#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
|
||||||
|
#elif defined(ENABLE_FP8_E4M3)
|
||||||
|
#include "../quantization/fp8/amd_detail/quant_utils.cuh"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
|
#ifdef USE_ROCM
|
||||||
|
#include <hip/hip_bf16.h>
|
||||||
|
typedef __hip_bfloat16 __nv_bfloat16;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
#define WARP_SIZE 32
|
||||||
|
#else
|
||||||
|
#define WARP_SIZE warpSize
|
||||||
|
#endif
|
||||||
|
|
||||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||||
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
|
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
|
||||||
@@ -78,7 +92,7 @@ template<
|
|||||||
int HEAD_SIZE,
|
int HEAD_SIZE,
|
||||||
int BLOCK_SIZE,
|
int BLOCK_SIZE,
|
||||||
int NUM_THREADS,
|
int NUM_THREADS,
|
||||||
bool IS_FP8_E5M2_KV_CACHE,
|
bool IS_FP8_KV_CACHE,
|
||||||
int PARTITION_SIZE = 0> // Zero means no partitioning.
|
int PARTITION_SIZE = 0> // Zero means no partitioning.
|
||||||
__device__ void paged_attention_kernel(
|
__device__ void paged_attention_kernel(
|
||||||
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||||
@@ -95,7 +109,8 @@ __device__ void paged_attention_kernel(
|
|||||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||||
const int q_stride,
|
const int q_stride,
|
||||||
const int kv_block_stride,
|
const int kv_block_stride,
|
||||||
const int kv_head_stride) {
|
const int kv_head_stride,
|
||||||
|
const float kv_scale) {
|
||||||
const int seq_idx = blockIdx.y;
|
const int seq_idx = blockIdx.y;
|
||||||
const int partition_idx = blockIdx.z;
|
const int partition_idx = blockIdx.z;
|
||||||
const int max_num_partitions = gridDim.z;
|
const int max_num_partitions = gridDim.z;
|
||||||
@@ -142,7 +157,7 @@ __device__ void paged_attention_kernel(
|
|||||||
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
|
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
|
||||||
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
||||||
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
||||||
#ifdef ENABLE_FP8_E5M2
|
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
|
||||||
using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
|
using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@@ -208,11 +223,16 @@ __device__ void paged_attention_kernel(
|
|||||||
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
|
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
|
||||||
const int offset1 = (vec_idx * VEC_SIZE) / x;
|
const int offset1 = (vec_idx * VEC_SIZE) / x;
|
||||||
const int offset2 = (vec_idx * VEC_SIZE) % x;
|
const int offset2 = (vec_idx * VEC_SIZE) % x;
|
||||||
if constexpr (IS_FP8_E5M2_KV_CACHE) {
|
if constexpr (IS_FP8_KV_CACHE) {
|
||||||
#ifdef ENABLE_FP8_E5M2
|
#if defined(ENABLE_FP8_E5M2)
|
||||||
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
||||||
// Vector conversion from Quant_vec to K_vec.
|
// Vector conversion from Quant_vec to K_vec.
|
||||||
k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
|
k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
|
||||||
|
#elif defined(ENABLE_FP8_E4M3)
|
||||||
|
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
||||||
|
// Vector conversion from Quant_vec to K_vec. Use scaled_vec_conversion to convert FP8_E4M3 quantized k
|
||||||
|
// cache vec to k vec in higher precision (FP16, BFloat16, etc.)
|
||||||
|
k_vecs[j] = fp8_e4m3::scaled_vec_conversion<K_vec, Quant_vec>(k_vec_quant, kv_scale);
|
||||||
#else
|
#else
|
||||||
assert(false);
|
assert(false);
|
||||||
#endif
|
#endif
|
||||||
@@ -292,7 +312,7 @@ __device__ void paged_attention_kernel(
|
|||||||
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
|
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
|
||||||
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
|
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
|
||||||
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
|
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
|
||||||
#ifdef ENABLE_FP8_E5M2
|
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
|
||||||
using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
|
using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
|
||||||
#endif
|
#endif
|
||||||
using Float_L_vec = typename FloatVec<L_vec>::Type;
|
using Float_L_vec = typename FloatVec<L_vec>::Type;
|
||||||
@@ -328,11 +348,16 @@ __device__ void paged_attention_kernel(
|
|||||||
if (row_idx < HEAD_SIZE) {
|
if (row_idx < HEAD_SIZE) {
|
||||||
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
|
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
|
||||||
V_vec v_vec;
|
V_vec v_vec;
|
||||||
if constexpr (IS_FP8_E5M2_KV_CACHE) {
|
if constexpr (IS_FP8_KV_CACHE) {
|
||||||
#ifdef ENABLE_FP8_E5M2
|
#if defined(ENABLE_FP8_E5M2)
|
||||||
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
|
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
|
||||||
// Vector conversion from V_quant_vec to V_vec.
|
// Vector conversion from V_quant_vec to V_vec.
|
||||||
v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
|
v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
|
||||||
|
#elif defined(ENABLE_FP8_E4M3)
|
||||||
|
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
|
||||||
|
// Vector conversion from V_quant_vec to V_vec. Use scaled_vec_conversion to convert
|
||||||
|
// FP8_E4M3 quantized v cache vec to v vec in higher precision (FP16, BFloat16, etc.)
|
||||||
|
v_vec = fp8_e4m3::scaled_vec_conversion<V_vec, V_quant_vec>(v_quant_vec, kv_scale);
|
||||||
#else
|
#else
|
||||||
assert(false);
|
assert(false);
|
||||||
#endif
|
#endif
|
||||||
@@ -423,7 +448,7 @@ template<
|
|||||||
int HEAD_SIZE,
|
int HEAD_SIZE,
|
||||||
int BLOCK_SIZE,
|
int BLOCK_SIZE,
|
||||||
int NUM_THREADS,
|
int NUM_THREADS,
|
||||||
bool IS_FP8_E5M2_KV_CACHE>
|
bool IS_FP8_KV_CACHE>
|
||||||
__global__ void paged_attention_v1_kernel(
|
__global__ void paged_attention_v1_kernel(
|
||||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
||||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||||
@@ -437,11 +462,12 @@ __global__ void paged_attention_v1_kernel(
|
|||||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||||
const int q_stride,
|
const int q_stride,
|
||||||
const int kv_block_stride,
|
const int kv_block_stride,
|
||||||
const int kv_head_stride) {
|
const int kv_head_stride,
|
||||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE>(
|
const float kv_scale) {
|
||||||
|
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE>(
|
||||||
/* exp_sums */ nullptr, /* max_logits */ nullptr,
|
/* exp_sums */ nullptr, /* max_logits */ nullptr,
|
||||||
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
|
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
|
||||||
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
|
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Grid: (num_heads, num_seqs, max_num_partitions).
|
// Grid: (num_heads, num_seqs, max_num_partitions).
|
||||||
@@ -451,7 +477,7 @@ template<
|
|||||||
int HEAD_SIZE,
|
int HEAD_SIZE,
|
||||||
int BLOCK_SIZE,
|
int BLOCK_SIZE,
|
||||||
int NUM_THREADS,
|
int NUM_THREADS,
|
||||||
bool IS_FP8_E5M2_KV_CACHE,
|
bool IS_FP8_KV_CACHE,
|
||||||
int PARTITION_SIZE>
|
int PARTITION_SIZE>
|
||||||
__global__ void paged_attention_v2_kernel(
|
__global__ void paged_attention_v2_kernel(
|
||||||
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||||
@@ -468,11 +494,12 @@ __global__ void paged_attention_v2_kernel(
|
|||||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||||
const int q_stride,
|
const int q_stride,
|
||||||
const int kv_block_stride,
|
const int kv_block_stride,
|
||||||
const int kv_head_stride) {
|
const int kv_head_stride,
|
||||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>(
|
const float kv_scale) {
|
||||||
|
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE, PARTITION_SIZE>(
|
||||||
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
|
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
|
||||||
block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
|
block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
|
||||||
q_stride, kv_block_stride, kv_head_stride);
|
q_stride, kv_block_stride, kv_head_stride, kv_scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Grid: (num_heads, num_seqs).
|
// Grid: (num_heads, num_seqs).
|
||||||
@@ -579,9 +606,9 @@ __global__ void paged_attention_v2_reduce_kernel(
|
|||||||
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
|
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
|
||||||
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
|
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
|
||||||
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
||||||
IS_FP8_E5M2_KV_CACHE>), shared_mem_size); \
|
IS_FP8_KV_CACHE>), shared_mem_size); \
|
||||||
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
||||||
IS_FP8_E5M2_KV_CACHE><<<grid, block, shared_mem_size, stream>>>( \
|
IS_FP8_KV_CACHE><<<grid, block, shared_mem_size, stream>>>( \
|
||||||
out_ptr, \
|
out_ptr, \
|
||||||
query_ptr, \
|
query_ptr, \
|
||||||
key_cache_ptr, \
|
key_cache_ptr, \
|
||||||
@@ -594,14 +621,15 @@ __global__ void paged_attention_v2_reduce_kernel(
|
|||||||
alibi_slopes_ptr, \
|
alibi_slopes_ptr, \
|
||||||
q_stride, \
|
q_stride, \
|
||||||
kv_block_stride, \
|
kv_block_stride, \
|
||||||
kv_head_stride);
|
kv_head_stride, \
|
||||||
|
kv_scale);
|
||||||
|
|
||||||
// TODO(woosuk): Tune NUM_THREADS.
|
// TODO(woosuk): Tune NUM_THREADS.
|
||||||
template<
|
template<
|
||||||
typename T,
|
typename T,
|
||||||
typename CACHE_T,
|
typename CACHE_T,
|
||||||
int BLOCK_SIZE,
|
int BLOCK_SIZE,
|
||||||
bool IS_FP8_E5M2_KV_CACHE,
|
bool IS_FP8_KV_CACHE,
|
||||||
int NUM_THREADS = 128>
|
int NUM_THREADS = 128>
|
||||||
void paged_attention_v1_launcher(
|
void paged_attention_v1_launcher(
|
||||||
torch::Tensor& out,
|
torch::Tensor& out,
|
||||||
@@ -613,7 +641,8 @@ void paged_attention_v1_launcher(
|
|||||||
torch::Tensor& block_tables,
|
torch::Tensor& block_tables,
|
||||||
torch::Tensor& context_lens,
|
torch::Tensor& context_lens,
|
||||||
int max_context_len,
|
int max_context_len,
|
||||||
const c10::optional<torch::Tensor>& alibi_slopes) {
|
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
|
float kv_scale) {
|
||||||
int num_seqs = query.size(0);
|
int num_seqs = query.size(0);
|
||||||
int num_heads = query.size(1);
|
int num_heads = query.size(1);
|
||||||
int head_size = query.size(2);
|
int head_size = query.size(2);
|
||||||
@@ -677,8 +706,8 @@ void paged_attention_v1_launcher(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
|
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
|
||||||
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
|
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE>( \
|
||||||
out, \
|
out, \
|
||||||
query, \
|
query, \
|
||||||
key_cache, \
|
key_cache, \
|
||||||
@@ -688,20 +717,21 @@ void paged_attention_v1_launcher(
|
|||||||
block_tables, \
|
block_tables, \
|
||||||
context_lens, \
|
context_lens, \
|
||||||
max_context_len, \
|
max_context_len, \
|
||||||
alibi_slopes);
|
alibi_slopes, \
|
||||||
|
kv_scale);
|
||||||
|
|
||||||
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
||||||
// 1, 2, 4, 64, 128, 256.
|
// 1, 2, 4, 64, 128, 256.
|
||||||
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
|
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \
|
||||||
switch (block_size) { \
|
switch (block_size) { \
|
||||||
case 8: \
|
case 8: \
|
||||||
CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
|
CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \
|
||||||
break; \
|
break; \
|
||||||
case 16: \
|
case 16: \
|
||||||
CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
|
CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \
|
||||||
break; \
|
break; \
|
||||||
case 32: \
|
case 32: \
|
||||||
CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
|
CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \
|
||||||
break; \
|
break; \
|
||||||
default: \
|
default: \
|
||||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||||
@@ -720,7 +750,8 @@ void paged_attention_v1(
|
|||||||
int block_size,
|
int block_size,
|
||||||
int max_context_len,
|
int max_context_len,
|
||||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
const std::string& kv_cache_dtype) {
|
const std::string& kv_cache_dtype,
|
||||||
|
float kv_scale) {
|
||||||
if (kv_cache_dtype == "auto") {
|
if (kv_cache_dtype == "auto") {
|
||||||
if (query.dtype() == at::ScalarType::Float) {
|
if (query.dtype() == at::ScalarType::Float) {
|
||||||
CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
|
CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
|
||||||
@@ -731,7 +762,7 @@ void paged_attention_v1(
|
|||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||||
}
|
}
|
||||||
} else if (kv_cache_dtype == "fp8_e5m2") {
|
} else if (kv_cache_dtype == "fp8") {
|
||||||
if (query.dtype() == at::ScalarType::Float) {
|
if (query.dtype() == at::ScalarType::Float) {
|
||||||
CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
|
CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
|
||||||
} else if (query.dtype() == at::ScalarType::Half) {
|
} else if (query.dtype() == at::ScalarType::Half) {
|
||||||
@@ -748,7 +779,7 @@ void paged_attention_v1(
|
|||||||
|
|
||||||
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
|
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
|
||||||
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
||||||
IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE> \
|
IS_FP8_KV_CACHE, PARTITION_SIZE> \
|
||||||
<<<grid, block, shared_mem_size, stream>>>( \
|
<<<grid, block, shared_mem_size, stream>>>( \
|
||||||
exp_sums_ptr, \
|
exp_sums_ptr, \
|
||||||
max_logits_ptr, \
|
max_logits_ptr, \
|
||||||
@@ -764,7 +795,8 @@ void paged_attention_v1(
|
|||||||
alibi_slopes_ptr, \
|
alibi_slopes_ptr, \
|
||||||
q_stride, \
|
q_stride, \
|
||||||
kv_block_stride, \
|
kv_block_stride, \
|
||||||
kv_head_stride); \
|
kv_head_stride, \
|
||||||
|
kv_scale); \
|
||||||
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
|
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
|
||||||
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
|
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
|
||||||
out_ptr, \
|
out_ptr, \
|
||||||
@@ -778,7 +810,7 @@ template<
|
|||||||
typename T,
|
typename T,
|
||||||
typename CACHE_T,
|
typename CACHE_T,
|
||||||
int BLOCK_SIZE,
|
int BLOCK_SIZE,
|
||||||
bool IS_FP8_E5M2_KV_CACHE,
|
bool IS_FP8_KV_CACHE,
|
||||||
int NUM_THREADS = 128,
|
int NUM_THREADS = 128,
|
||||||
int PARTITION_SIZE = 512>
|
int PARTITION_SIZE = 512>
|
||||||
void paged_attention_v2_launcher(
|
void paged_attention_v2_launcher(
|
||||||
@@ -794,7 +826,8 @@ void paged_attention_v2_launcher(
|
|||||||
torch::Tensor& block_tables,
|
torch::Tensor& block_tables,
|
||||||
torch::Tensor& context_lens,
|
torch::Tensor& context_lens,
|
||||||
int max_context_len,
|
int max_context_len,
|
||||||
const c10::optional<torch::Tensor>& alibi_slopes) {
|
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
|
float kv_scale) {
|
||||||
int num_seqs = query.size(0);
|
int num_seqs = query.size(0);
|
||||||
int num_heads = query.size(1);
|
int num_heads = query.size(1);
|
||||||
int head_size = query.size(2);
|
int head_size = query.size(2);
|
||||||
@@ -864,8 +897,8 @@ void paged_attention_v2_launcher(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
|
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
|
||||||
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
|
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE>( \
|
||||||
out, \
|
out, \
|
||||||
exp_sums, \
|
exp_sums, \
|
||||||
max_logits, \
|
max_logits, \
|
||||||
@@ -878,20 +911,21 @@ void paged_attention_v2_launcher(
|
|||||||
block_tables, \
|
block_tables, \
|
||||||
context_lens, \
|
context_lens, \
|
||||||
max_context_len, \
|
max_context_len, \
|
||||||
alibi_slopes);
|
alibi_slopes, \
|
||||||
|
kv_scale);
|
||||||
|
|
||||||
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
||||||
// 1, 2, 4, 64, 128, 256.
|
// 1, 2, 4, 64, 128, 256.
|
||||||
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
|
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \
|
||||||
switch (block_size) { \
|
switch (block_size) { \
|
||||||
case 8: \
|
case 8: \
|
||||||
CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
|
CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \
|
||||||
break; \
|
break; \
|
||||||
case 16: \
|
case 16: \
|
||||||
CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
|
CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \
|
||||||
break; \
|
break; \
|
||||||
case 32: \
|
case 32: \
|
||||||
CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
|
CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \
|
||||||
break; \
|
break; \
|
||||||
default: \
|
default: \
|
||||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||||
@@ -913,7 +947,8 @@ void paged_attention_v2(
|
|||||||
int block_size,
|
int block_size,
|
||||||
int max_context_len,
|
int max_context_len,
|
||||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
const std::string& kv_cache_dtype) {
|
const std::string& kv_cache_dtype,
|
||||||
|
float kv_scale) {
|
||||||
if (kv_cache_dtype == "auto") {
|
if (kv_cache_dtype == "auto") {
|
||||||
if (query.dtype() == at::ScalarType::Float) {
|
if (query.dtype() == at::ScalarType::Float) {
|
||||||
CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
|
CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
|
||||||
@@ -924,7 +959,7 @@ void paged_attention_v2(
|
|||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||||
}
|
}
|
||||||
} else if (kv_cache_dtype == "fp8_e5m2") {
|
} else if (kv_cache_dtype == "fp8") {
|
||||||
if (query.dtype() == at::ScalarType::Float) {
|
if (query.dtype() == at::ScalarType::Float) {
|
||||||
CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
|
CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
|
||||||
} else if (query.dtype() == at::ScalarType::Half) {
|
} else if (query.dtype() == at::ScalarType::Half) {
|
||||||
|
|||||||
@@ -8,7 +8,7 @@
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
#ifdef ENABLE_FP8_E5M2
|
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
|
||||||
// fp8 vector types for quantization of kv cache
|
// fp8 vector types for quantization of kv cache
|
||||||
|
|
||||||
template<>
|
template<>
|
||||||
@@ -21,9 +21,10 @@ void reshape_and_cache(
|
|||||||
torch::Tensor& key_cache,
|
torch::Tensor& key_cache,
|
||||||
torch::Tensor& value_cache,
|
torch::Tensor& value_cache,
|
||||||
torch::Tensor& slot_mapping,
|
torch::Tensor& slot_mapping,
|
||||||
const std::string& kv_cache_dtype);
|
const std::string& kv_cache_dtype,
|
||||||
|
const float kv_scale);
|
||||||
|
|
||||||
// Just for unittest
|
// Just for unittest
|
||||||
void convert_fp8_e5m2(
|
void convert_fp8(
|
||||||
torch::Tensor& src_cache,
|
torch::Tensor& src_cache,
|
||||||
torch::Tensor& dst_cache);
|
torch::Tensor& dst_cache);
|
||||||
|
|||||||
@@ -4,8 +4,10 @@
|
|||||||
|
|
||||||
#include "cuda_compat.h"
|
#include "cuda_compat.h"
|
||||||
#include "dispatch_utils.h"
|
#include "dispatch_utils.h"
|
||||||
#ifdef ENABLE_FP8_E5M2
|
#if defined(ENABLE_FP8_E5M2)
|
||||||
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
|
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
|
||||||
|
#elif defined(ENABLE_FP8_E4M3)
|
||||||
|
#include "quantization/fp8/amd_detail/quant_utils.cuh"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
@@ -151,7 +153,7 @@ void copy_blocks(
|
|||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
|
||||||
template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache>
|
template<typename scalar_t, typename cache_t, bool is_fp8_kv_cache>
|
||||||
__global__ void reshape_and_cache_kernel(
|
__global__ void reshape_and_cache_kernel(
|
||||||
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||||
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
|
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
|
||||||
@@ -163,7 +165,8 @@ __global__ void reshape_and_cache_kernel(
|
|||||||
const int num_heads,
|
const int num_heads,
|
||||||
const int head_size,
|
const int head_size,
|
||||||
const int block_size,
|
const int block_size,
|
||||||
const int x) {
|
const int x,
|
||||||
|
const float kv_scale) {
|
||||||
const int64_t token_idx = blockIdx.x;
|
const int64_t token_idx = blockIdx.x;
|
||||||
const int64_t slot_idx = slot_mapping[token_idx];
|
const int64_t slot_idx = slot_mapping[token_idx];
|
||||||
if (slot_idx < 0) {
|
if (slot_idx < 0) {
|
||||||
@@ -195,10 +198,13 @@ __global__ void reshape_and_cache_kernel(
|
|||||||
+ block_offset;
|
+ block_offset;
|
||||||
scalar_t tgt_key = key[src_key_idx];
|
scalar_t tgt_key = key[src_key_idx];
|
||||||
scalar_t tgt_value = value[src_value_idx];
|
scalar_t tgt_value = value[src_value_idx];
|
||||||
if constexpr (is_fp8_e5m2_kv_cache) {
|
if constexpr (is_fp8_kv_cache) {
|
||||||
#ifdef ENABLE_FP8_E5M2
|
#if defined(ENABLE_FP8_E5M2)
|
||||||
key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
|
key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
|
||||||
value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
|
value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
|
||||||
|
#elif defined(ENABLE_FP8_E4M3)
|
||||||
|
key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_key, kv_scale);
|
||||||
|
value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_value, kv_scale);
|
||||||
#else
|
#else
|
||||||
assert(false);
|
assert(false);
|
||||||
#endif
|
#endif
|
||||||
@@ -211,8 +217,8 @@ __global__ void reshape_and_cache_kernel(
|
|||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
|
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \
|
||||||
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE><<<grid, block, 0, stream>>>( \
|
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_KV_CACHE><<<grid, block, 0, stream>>>( \
|
||||||
reinterpret_cast<KV_T*>(key.data_ptr()), \
|
reinterpret_cast<KV_T*>(key.data_ptr()), \
|
||||||
reinterpret_cast<KV_T*>(value.data_ptr()), \
|
reinterpret_cast<KV_T*>(value.data_ptr()), \
|
||||||
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||||
@@ -223,7 +229,8 @@ __global__ void reshape_and_cache_kernel(
|
|||||||
num_heads, \
|
num_heads, \
|
||||||
head_size, \
|
head_size, \
|
||||||
block_size, \
|
block_size, \
|
||||||
x);
|
x, \
|
||||||
|
kv_scale);
|
||||||
|
|
||||||
void reshape_and_cache(
|
void reshape_and_cache(
|
||||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||||
@@ -231,7 +238,8 @@ void reshape_and_cache(
|
|||||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||||
torch::Tensor& slot_mapping, // [num_tokens]
|
torch::Tensor& slot_mapping, // [num_tokens]
|
||||||
const std::string& kv_cache_dtype)
|
const std::string& kv_cache_dtype,
|
||||||
|
const float kv_scale)
|
||||||
{
|
{
|
||||||
int num_tokens = key.size(0);
|
int num_tokens = key.size(0);
|
||||||
int num_heads = key.size(1);
|
int num_heads = key.size(1);
|
||||||
@@ -254,7 +262,7 @@ void reshape_and_cache(
|
|||||||
} else if (key.dtype() == at::ScalarType::BFloat16) {
|
} else if (key.dtype() == at::ScalarType::BFloat16) {
|
||||||
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
|
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
|
||||||
}
|
}
|
||||||
} else if (kv_cache_dtype == "fp8_e5m2") {
|
} else if (kv_cache_dtype == "fp8") {
|
||||||
if (key.dtype() == at::ScalarType::Float) {
|
if (key.dtype() == at::ScalarType::Float) {
|
||||||
CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
|
CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
|
||||||
} else if (key.dtype() == at::ScalarType::Half) {
|
} else if (key.dtype() == at::ScalarType::Half) {
|
||||||
@@ -270,15 +278,17 @@ void reshape_and_cache(
|
|||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
|
||||||
template<typename Tout, typename Tin>
|
template<typename Tout, typename Tin>
|
||||||
__global__ void convert_fp8_e5m2_kernel(
|
__global__ void convert_fp8_kernel(
|
||||||
const Tin* __restrict__ src_cache,
|
const Tin* __restrict__ src_cache,
|
||||||
Tout* __restrict__ dst_cache,
|
Tout* __restrict__ dst_cache,
|
||||||
const int64_t block_stride) {
|
const int64_t block_stride) {
|
||||||
const int64_t block_idx = blockIdx.x;
|
const int64_t block_idx = blockIdx.x;
|
||||||
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
|
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
|
||||||
int64_t idx = block_idx * block_stride + i;
|
int64_t idx = block_idx * block_stride + i;
|
||||||
#ifdef ENABLE_FP8_E5M2
|
#if defined(ENABLE_FP8_E5M2)
|
||||||
dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
|
dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
|
||||||
|
#elif defined(ENABLE_FP8_E4M3)
|
||||||
|
dst_cache[idx] = fp8_e4m3::vec_conversion<Tout, Tin>(src_cache[idx]);
|
||||||
#else
|
#else
|
||||||
assert(false);
|
assert(false);
|
||||||
#endif
|
#endif
|
||||||
@@ -287,16 +297,25 @@ __global__ void convert_fp8_e5m2_kernel(
|
|||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
#define CALL_CONVERT_FP8_E5M2(Tout, Tin) \
|
#define CALL_CONVERT_FP8(Tout, Tin) \
|
||||||
vllm::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
|
vllm::convert_fp8_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
|
||||||
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
|
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
|
||||||
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
|
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
|
||||||
block_stride);
|
block_stride);
|
||||||
|
|
||||||
void convert_fp8_e5m2(
|
void convert_fp8(
|
||||||
torch::Tensor& src_cache,
|
torch::Tensor& src_cache,
|
||||||
torch::Tensor& dst_cache)
|
torch::Tensor& dst_cache)
|
||||||
{
|
{
|
||||||
|
torch::Device src_device = src_cache.device();
|
||||||
|
torch::Device dst_device = dst_cache.device();
|
||||||
|
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
|
||||||
|
TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
|
||||||
|
TORCH_CHECK(
|
||||||
|
src_device.index() == dst_device.index(),
|
||||||
|
"src and dst must be on the same GPU");
|
||||||
|
at::cuda::OptionalCUDAGuard device_guard(src_device);
|
||||||
|
|
||||||
int64_t num_blocks = src_cache.size(0);
|
int64_t num_blocks = src_cache.size(0);
|
||||||
int64_t block_stride = src_cache.stride(0);
|
int64_t block_stride = src_cache.stride(0);
|
||||||
|
|
||||||
@@ -305,16 +324,16 @@ void convert_fp8_e5m2(
|
|||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
if (src_cache.dtype() == at::ScalarType::Float) {
|
if (src_cache.dtype() == at::ScalarType::Float) {
|
||||||
CALL_CONVERT_FP8_E5M2(uint8_t, float);
|
CALL_CONVERT_FP8(uint8_t, float);
|
||||||
} else if (src_cache.dtype() == at::ScalarType::Half) {
|
} else if (src_cache.dtype() == at::ScalarType::Half) {
|
||||||
CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
|
CALL_CONVERT_FP8(uint8_t, uint16_t);
|
||||||
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
|
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
|
||||||
CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
|
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16);
|
||||||
} else if (dst_cache.dtype() == at::ScalarType::Float) {
|
} else if (dst_cache.dtype() == at::ScalarType::Float) {
|
||||||
CALL_CONVERT_FP8_E5M2(float, uint8_t);
|
CALL_CONVERT_FP8(float, uint8_t);
|
||||||
} else if (dst_cache.dtype() == at::ScalarType::Half) {
|
} else if (dst_cache.dtype() == at::ScalarType::Half) {
|
||||||
CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
|
CALL_CONVERT_FP8(uint16_t, uint8_t);
|
||||||
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
|
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
|
||||||
CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
|
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
148
csrc/cpu/activation.cpp
Normal file
148
csrc/cpu/activation.cpp
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
#include "cpu_types.hpp"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8 &),
|
||||||
|
bool is_gated>
|
||||||
|
void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input,
|
||||||
|
scalar_t *__restrict__ output) {
|
||||||
|
using scalar_vec_t = vec_op::vec_t<scalar_t>;
|
||||||
|
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
|
||||||
|
|
||||||
|
TORCH_CHECK(d % VEC_ELEM_NUM == 0);
|
||||||
|
|
||||||
|
#pragma omp parallel for
|
||||||
|
for (int i = 0; i < num_tokens; ++i) {
|
||||||
|
for (int j = 0; j < d; j += VEC_ELEM_NUM) {
|
||||||
|
int start = i * d;
|
||||||
|
if constexpr (is_gated) {
|
||||||
|
start *= 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
const scalar_vec_t x(input + start + j);
|
||||||
|
const vec_op::FP32Vec8 f32_x(x);
|
||||||
|
vec_op::FP32Vec8 f32_ans = func(f32_x);
|
||||||
|
|
||||||
|
if constexpr (is_gated) {
|
||||||
|
const scalar_vec_t y(input + start + d + j);
|
||||||
|
const vec_op::FP32Vec8 f32_y(y);
|
||||||
|
f32_ans = f32_y * f32_ans;
|
||||||
|
}
|
||||||
|
|
||||||
|
const scalar_vec_t result(f32_ans);
|
||||||
|
result.save(output + i * d + j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8 &x) {
|
||||||
|
const vec_op::FP32Vec8 zeros(0.0);
|
||||||
|
const vec_op::FP32Vec8 ones(1.0);
|
||||||
|
return x / (ones + (zeros - x).exp());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tanh-based GELU: 0.5 * x * (1 + tanh(0.79788456 * (x + 0.044715 * x^3))).
FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) {
  using Vec = vec_op::FP32Vec8;
  const Vec one(1.0);
  const Vec outer_coeff(0.79788456f);
  const Vec cubic_coeff(0.044715f);
  const Vec half(0.5);

  const Vec x_cubed = x * x * x;
  const Vec inner = outer_coeff * (x + cubic_coeff * x_cubed);
  const Vec tanh_inner = inner.tanh();
  return half * x * (one + tanh_inner);
}
|
||||||
|
|
||||||
|
// Fast GELU: 0.5 * x * (1 + tanh(x * 0.79788456 * (1 + 0.044715 * x^2))).
FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) {
  using Vec = vec_op::FP32Vec8;
  const Vec one(1.0);
  const Vec outer_coeff(0.79788456f);
  const Vec square_coeff(0.044715f);
  const Vec half(0.5);

  const Vec scaled_x = x * outer_coeff;
  const Vec poly = one + x * square_coeff * x;
  const Vec tanh_term = (scaled_x * poly).tanh();
  return half * x * (one + tanh_term);
}
|
||||||
|
|
||||||
|
// GELU built on the vector type's er() member:
// 0.5 * x * (1 + er(x * M_SQRT1_2)).
FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8 &x) {
  using Vec = vec_op::FP32Vec8;
  const Vec one(1.0);
  const Vec inv_sqrt2(M_SQRT1_2);
  const Vec half(0.5);

  Vec scaled_x = x * inv_sqrt2;
  const Vec er_term = scaled_x.er();
  return x * half * (one + er_term);
}
|
||||||
|
|
||||||
|
// Tanh approximation of GELU:
// 0.5 * x * (1 + tanh(beta * (x + 0.044715 * x^3))),
// where beta = M_SQRT2 * M_2_SQRTPI * 0.5.
FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) {
  using Vec = vec_op::FP32Vec8;
  const Vec one(1.0);
  const Vec beta(M_SQRT2 * M_2_SQRTPI * 0.5);
  const Vec half(0.5);
  const Vec cubic_coeff(0.044715);

  const Vec x_cubed = x * x * x;
  const Vec tanh_arg = beta * (x + x_cubed * cubic_coeff);
  const Vec tanh_term = tanh_arg.tanh();
  return x * half * (one + tanh_term);
}
|
||||||
|
}; // namespace
|
||||||
|
|
||||||
|
// Gated SiLU: treats the last dimension of `input` as two halves of size d
// and writes silu(first half) * second half into `out`.
void silu_and_mul(torch::Tensor &out, torch::Tensor &input) {
  const auto last_dim = input.size(-1);
  const int num_tokens = input.numel() / last_dim;
  const int d = last_dim / 2;

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "silu_and_mul_impl", [&] {
        CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
        scalar_t *src = input.data_ptr<scalar_t>();
        scalar_t *dst = out.data_ptr<scalar_t>();
        activation_kernel<scalar_t, silu_act, true>(num_tokens, d, src, dst);
        CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
      });
}
|
||||||
|
|
||||||
|
// Gated GELU: writes gelu_act(input[..., :d]) * input[..., d:] into `out`.
//   out:   [..., d]
//   input: [..., 2 * d]
void gelu_and_mul(torch::Tensor &out, torch::Tensor &input) {
  const auto last_dim = input.size(-1);
  const int num_tokens = input.numel() / last_dim;
  const int d = last_dim / 2;

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "gelu_and_mul_impl", [&] {
        CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
        scalar_t *src = input.data_ptr<scalar_t>();
        scalar_t *dst = out.data_ptr<scalar_t>();
        activation_kernel<scalar_t, gelu_act, true>(num_tokens, d, src, dst);
        CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
      });
}
|
||||||
|
|
||||||
|
// Gated tanh-GELU: writes gelu_tanh_act(input[..., :d]) * input[..., d:]
// into `out`.
//   out:   [..., d]
//   input: [..., 2 * d]
void gelu_tanh_and_mul(torch::Tensor &out, torch::Tensor &input) {
  const auto last_dim = input.size(-1);
  const int num_tokens = input.numel() / last_dim;
  const int d = last_dim / 2;

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "gelu_tanh_and_mul_impl", [&] {
        CPU_KERNEL_GUARD_IN(gelu_tanh_and_mul_impl)
        scalar_t *src = input.data_ptr<scalar_t>();
        scalar_t *dst = out.data_ptr<scalar_t>();
        activation_kernel<scalar_t, gelu_tanh_act, true>(num_tokens, d, src,
                                                         dst);
        CPU_KERNEL_GUARD_OUT(gelu_tanh_and_mul_impl)
      });
}
|
||||||
|
|
||||||
|
// Element-wise gelu_new_act over `input`, written to `out`. The last
// dimension is the row length; all leading dimensions are flattened to rows.
void gelu_new(torch::Tensor &out, torch::Tensor &input) {
  const auto last_dim = input.size(-1);
  const int num_tokens = input.numel() / last_dim;
  const int d = last_dim;

  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_new_impl", [&] {
    CPU_KERNEL_GUARD_IN(gelu_new_impl)
    scalar_t *src = input.data_ptr<scalar_t>();
    scalar_t *dst = out.data_ptr<scalar_t>();
    activation_kernel<scalar_t, gelu_new_act, false>(num_tokens, d, src, dst);
    CPU_KERNEL_GUARD_OUT(gelu_new_impl)
  });
}
|
||||||
|
|
||||||
|
// Element-wise gelu_fast_act over `input`, written to `out`. The last
// dimension is the row length; all leading dimensions are flattened to rows.
void gelu_fast(torch::Tensor &out, torch::Tensor &input) {
  const auto last_dim = input.size(-1);
  const int num_tokens = input.numel() / last_dim;
  const int d = last_dim;

  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_fast_impl", [&] {
    CPU_KERNEL_GUARD_IN(gelu_fast_impl)
    scalar_t *src = input.data_ptr<scalar_t>();
    scalar_t *dst = out.data_ptr<scalar_t>();
    activation_kernel<scalar_t, gelu_fast_act, false>(num_tokens, d, src, dst);
    CPU_KERNEL_GUARD_OUT(gelu_fast_impl)
  });
}
|
||||||
746
csrc/cpu/attention.cpp
Normal file
746
csrc/cpu/attention.cpp
Normal file
@@ -0,0 +1,746 @@
|
|||||||
|
#include "cpu_types.hpp"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Primary template of the trait that maps a scalar type to the vector types
// used by the attention kernels in this file. Every alias is `void` here;
// usable vector types are supplied by the explicit specializations.
//   q_load_vec_type / q_vec_type : types used to load and hold the query.
//   k_load_vec_type / k_vec_type : types used to load and hold the keys.
//   qk_acc_vec_type              : accumulator type for the Q*K products.
//   v_load_vec_type              : type used to load the values.
template <typename scalar_t> struct KernelVecType {
|
||||||
|
using q_load_vec_type = void;
|
||||||
|
using q_vec_type = void;
|
||||||
|
using k_load_vec_type = void;
|
||||||
|
using k_vec_type = void;
|
||||||
|
using qk_acc_vec_type = void;
|
||||||
|
using v_load_vec_type = void;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Specialization for float: the query is loaded as FP32Vec4, and every
// other alias (query, key load/compute, accumulator, value load) is
// FP32Vec16.
template <> struct KernelVecType<float> {
|
||||||
|
using q_load_vec_type = vec_op::FP32Vec4;
|
||||||
|
using q_vec_type = vec_op::FP32Vec16;
|
||||||
|
using k_load_vec_type = vec_op::FP32Vec16;
|
||||||
|
using k_vec_type = vec_op::FP32Vec16;
|
||||||
|
using qk_acc_vec_type = vec_op::FP32Vec16;
|
||||||
|
using v_load_vec_type = vec_op::FP32Vec16;
|
||||||
|
};
|
||||||
|
|
||||||
|
#ifdef __AVX512BF16__
|
||||||
|
template <> struct KernelVecType<c10::BFloat16> {
|
||||||
|
using q_load_vec_type = vec_op::BF16Vec8;
|
||||||
|
using q_vec_type = vec_op::BF16Vec32;
|
||||||
|
using k_load_vec_type = vec_op::BF16Vec32;
|
||||||
|
using k_vec_type = vec_op::BF16Vec32;
|
||||||
|
using qk_acc_vec_type = vec_op::FP32Vec16;
|
||||||
|
using v_load_vec_type = vec_op::BF16Vec16;
|
||||||
|
};
|
||||||
|
#else
|
||||||
|
template <> struct KernelVecType<c10::BFloat16> {
|
||||||
|
using q_load_vec_type = vec_op::BF16Vec8;
|
||||||
|
using q_vec_type = vec_op::FP32Vec16;
|
||||||
|
using k_load_vec_type = vec_op::BF16Vec16;
|
||||||
|
using k_vec_type = vec_op::FP32Vec16;
|
||||||
|
using qk_acc_vec_type = vec_op::FP32Vec16;
|
||||||
|
using v_load_vec_type = vec_op::BF16Vec16;
|
||||||
|
};
|
||||||
|
#endif
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size,
|
||||||
|
const int capacity) {
|
||||||
|
T max = data[0];
|
||||||
|
for (int i = 1; i < size; ++i) {
|
||||||
|
max = max >= data[i] ? max : data[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
T sum = 0;
|
||||||
|
for (int i = 0; i < size; ++i) {
|
||||||
|
data[i] = std::exp(data[i] - max);
|
||||||
|
sum += data[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
int i = 0;
|
||||||
|
for (; i < size; ++i) {
|
||||||
|
data[i] /= sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (; i < capacity; ++i) {
|
||||||
|
data[i] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {max, sum};
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
FORCE_INLINE std::pair<T, T>
|
||||||
|
reduceSoftmaxAlibi(T *data, const int size, const int capacity,
|
||||||
|
const float alibi_slope, const int start_index,
|
||||||
|
const int context_len) {
|
||||||
|
data[0] += alibi_slope * (start_index - context_len + 1);
|
||||||
|
T max = data[0];
|
||||||
|
for (int i = 1; i < size; ++i) {
|
||||||
|
T qk = data[i] + alibi_slope * (start_index + i - context_len + 1);
|
||||||
|
data[i] = qk;
|
||||||
|
max = max >= qk ? max : qk;
|
||||||
|
}
|
||||||
|
|
||||||
|
T sum = 0;
|
||||||
|
for (int i = 0; i < size; ++i) {
|
||||||
|
data[i] = std::exp(data[i] - max);
|
||||||
|
sum += data[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
int i = 0;
|
||||||
|
for (; i < size; ++i) {
|
||||||
|
data[i] /= sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (; i < capacity; ++i) {
|
||||||
|
data[i] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {max, sum};
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
FORCE_INLINE void reducePartitonSoftmax(const T *max_data, T *sum_data,
|
||||||
|
const int size) {
|
||||||
|
T max = max_data[0];
|
||||||
|
for (int i = 1; i < size; ++i) {
|
||||||
|
max = max >= max_data[i] ? max : max_data[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
T rescaled_sum = 0;
|
||||||
|
for (int i = 0; i < size; ++i) {
|
||||||
|
T rescale_factor = std::exp(max_data[i] - max);
|
||||||
|
rescaled_sum += rescale_factor * sum_data[i];
|
||||||
|
sum_data[i] *= rescale_factor;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < size; ++i) {
|
||||||
|
sum_data[i] /= rescaled_sum + 1e-8;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int x>
|
||||||
|
struct reduceQKBlockKernel {
|
||||||
|
using q_load_vec_type = typename KernelVecType<scalar_t>::q_load_vec_type;
|
||||||
|
using q_vec_type = typename KernelVecType<scalar_t>::q_vec_type;
|
||||||
|
using k_load_vec_type = typename KernelVecType<scalar_t>::k_load_vec_type;
|
||||||
|
using k_vec_type = typename KernelVecType<scalar_t>::k_vec_type;
|
||||||
|
using qk_acc_vec_type = typename KernelVecType<scalar_t>::qk_acc_vec_type;
|
||||||
|
|
||||||
|
constexpr static int TOKEN_PER_GROUP = k_load_vec_type::get_elem_num() / x;
|
||||||
|
constexpr static int MAX_GROUP_NUM = 16 / TOKEN_PER_GROUP;
|
||||||
|
constexpr static int UNROLL_GROUP_NUM = MAX_GROUP_NUM / 4;
|
||||||
|
|
||||||
|
static_assert(MAX_GROUP_NUM == 8 || MAX_GROUP_NUM == 4);
|
||||||
|
static_assert(k_load_vec_type::get_elem_num() % x == 0);
|
||||||
|
static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16);
|
||||||
|
|
||||||
|
FORCE_INLINE static void call(const scalar_t *__restrict__ q,
|
||||||
|
const scalar_t *__restrict__ k_block,
|
||||||
|
float *__restrict__ logits, float scale,
|
||||||
|
const int token_num) {
|
||||||
|
const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP;
|
||||||
|
|
||||||
|
qk_acc_vec_type group_accums[MAX_GROUP_NUM];
|
||||||
|
if (token_num == BLOCK_SIZE) {
|
||||||
|
for (int q_offset = 0; q_offset < HEAD_SIZE;
|
||||||
|
q_offset += x, k_block += x * BLOCK_SIZE) {
|
||||||
|
q_load_vec_type q_load_group_vec(q + q_offset);
|
||||||
|
q_vec_type q_group_vec(q_load_group_vec);
|
||||||
|
|
||||||
|
vec_op::unroll_loop<int, MAX_GROUP_NUM>(
|
||||||
|
[k_block, &q_group_vec, &group_accums](int token_group_idx) {
|
||||||
|
k_load_vec_type k_load_group_vec(k_block + token_group_idx * x *
|
||||||
|
TOKEN_PER_GROUP);
|
||||||
|
k_vec_type k_group_vec(k_load_group_vec);
|
||||||
|
vec_op::fma(group_accums[token_group_idx], q_group_vec,
|
||||||
|
k_group_vec);
|
||||||
|
vec_op::prefetch(k_block + x * BLOCK_SIZE +
|
||||||
|
token_group_idx * x * TOKEN_PER_GROUP);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int q_offset = 0; q_offset < HEAD_SIZE;
|
||||||
|
q_offset += x, k_block += x * BLOCK_SIZE) {
|
||||||
|
q_load_vec_type q_load_group_vec(q + q_offset);
|
||||||
|
q_vec_type q_group_vec(q_load_group_vec);
|
||||||
|
for (int token_group_start = 0; token_group_start < group_num;
|
||||||
|
token_group_start += UNROLL_GROUP_NUM) {
|
||||||
|
vec_op::unroll_loop<int, UNROLL_GROUP_NUM>(
|
||||||
|
[token_group_start, k_block, &q_group_vec,
|
||||||
|
&group_accums](int token_group_idx) {
|
||||||
|
token_group_idx += token_group_start;
|
||||||
|
k_load_vec_type k_load_group_vec(k_block + token_group_idx * x *
|
||||||
|
TOKEN_PER_GROUP);
|
||||||
|
k_vec_type k_group_vec(k_load_group_vec);
|
||||||
|
vec_op::fma(group_accums[token_group_idx], q_group_vec,
|
||||||
|
k_group_vec);
|
||||||
|
vec_op::prefetch(k_block + x * BLOCK_SIZE +
|
||||||
|
token_group_idx * x * TOKEN_PER_GROUP);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int token_group_idx = 0; token_group_idx < group_num;
|
||||||
|
++token_group_idx) {
|
||||||
|
vec_op::unroll_loop<int, TOKEN_PER_GROUP>(
|
||||||
|
[&group_accums, logits, scale, token_group_idx](int token_idx) {
|
||||||
|
float dot_v =
|
||||||
|
group_accums[token_group_idx]
|
||||||
|
.template reduce_sub_sum<qk_acc_vec_type::get_elem_num() /
|
||||||
|
TOKEN_PER_GROUP>(token_idx);
|
||||||
|
logits[token_group_idx * TOKEN_PER_GROUP + token_idx] =
|
||||||
|
dot_v * scale;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE,
|
||||||
|
int HEAD_PARTITION_SIZE, typename acc_t>
|
||||||
|
FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block,
|
||||||
|
acc_t &&acc) {
|
||||||
|
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
|
||||||
|
constexpr int ELEM_NUM = v_load_vec_type::get_elem_num();
|
||||||
|
static_assert(BLOCK_SIZE == ELEM_NUM);
|
||||||
|
vec_op::FP32Vec16 prob_vec(prob);
|
||||||
|
|
||||||
|
vec_op::unroll_loop<int, HEAD_PARTITION_SIZE>([&](int head_elem_idx) {
|
||||||
|
v_load_vec_type v_vec(v_block + BLOCK_SIZE * head_elem_idx);
|
||||||
|
vec_op::FP32Vec16 fp32_v_vec(v_vec);
|
||||||
|
acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}; // namespace
|
||||||
|
|
||||||
|
// Paged attention v1
|
||||||
|
namespace {
|
||||||
|
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE>
|
||||||
|
struct paged_attention_v1_impl {
|
||||||
|
static void
|
||||||
|
call(scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size]
|
||||||
|
const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size]
|
||||||
|
const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads,
|
||||||
|
// head_size/x, block_size, x]
|
||||||
|
const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads,
|
||||||
|
// head_size, block_size]
|
||||||
|
const int num_kv_heads, const float scale,
|
||||||
|
const int
|
||||||
|
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||||
|
const int *__restrict__ context_lens, // [num_seqs]
|
||||||
|
const int max_num_blocks_per_seq,
|
||||||
|
const float *__restrict__ alibi_slopes, // [num_heads]
|
||||||
|
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
||||||
|
const int num_seqs, const int num_heads) {
|
||||||
|
constexpr int x = 16 / sizeof(scalar_t);
|
||||||
|
const int num_queries_per_kv = num_heads / num_kv_heads;
|
||||||
|
|
||||||
|
static_assert(BLOCK_SIZE == 16);
|
||||||
|
|
||||||
|
int max_context_len = max_num_blocks_per_seq * BLOCK_SIZE;
|
||||||
|
int max_context_len_padded = (max_context_len + 15) & 0xFFFFFFF0;
|
||||||
|
TORCH_CHECK((max_context_len_padded * sizeof(float)) % 64 == 0);
|
||||||
|
|
||||||
|
const int parallel_work_item_num = omp_get_max_threads();
|
||||||
|
|
||||||
|
size_t logits_bytes =
|
||||||
|
parallel_work_item_num * max_context_len_padded * sizeof(float);
|
||||||
|
float *logits = (float *)std::aligned_alloc(
|
||||||
|
64, logits_bytes); // Cacheline alignment for each context token.
|
||||||
|
// [parallel_work_item_num, max_context_len_padded]
|
||||||
|
|
||||||
|
#pragma omp parallel for collapse(2) schedule(dynamic, 1)
|
||||||
|
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
|
||||||
|
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
|
||||||
|
int context_len = context_lens[seq_idx];
|
||||||
|
const int *seq_block_table =
|
||||||
|
block_tables + max_num_blocks_per_seq * seq_idx;
|
||||||
|
const int block_num = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
|
const int64_t kv_head_idx = head_idx / num_queries_per_kv;
|
||||||
|
const scalar_t *__restrict__ q_vec_ptr =
|
||||||
|
q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
||||||
|
const int last_block_token_num =
|
||||||
|
context_len - (block_num - 1) * BLOCK_SIZE;
|
||||||
|
float *__restrict__ thread_block_logits =
|
||||||
|
logits + omp_get_thread_num() * max_context_len_padded;
|
||||||
|
|
||||||
|
// Compute logits
|
||||||
|
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
|
||||||
|
const int64_t physical_block_idx = seq_block_table[block_idx];
|
||||||
|
const scalar_t *__restrict__ k_block_cache_ptr =
|
||||||
|
k_cache + physical_block_idx * kv_block_stride +
|
||||||
|
kv_head_idx * kv_head_stride;
|
||||||
|
float *__restrict__ head_block_logits =
|
||||||
|
thread_block_logits + block_idx * BLOCK_SIZE;
|
||||||
|
|
||||||
|
reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
|
||||||
|
q_vec_ptr, k_block_cache_ptr, head_block_logits, scale,
|
||||||
|
block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute softmax
|
||||||
|
if (alibi_slopes) {
|
||||||
|
reduceSoftmaxAlibi(thread_block_logits, context_len,
|
||||||
|
block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0,
|
||||||
|
context_len);
|
||||||
|
} else {
|
||||||
|
reduceSoftmax(thread_block_logits, context_len,
|
||||||
|
block_num * BLOCK_SIZE);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute value
|
||||||
|
constexpr int head_elem_num_per_partition = 16;
|
||||||
|
constexpr int head_partition_num =
|
||||||
|
HEAD_SIZE / head_elem_num_per_partition;
|
||||||
|
for (int head_part_idx = 0; head_part_idx < head_partition_num;
|
||||||
|
++head_part_idx) {
|
||||||
|
vec_op::FP32Vec16 accums[head_elem_num_per_partition];
|
||||||
|
scalar_t *__restrict__ out_ptr =
|
||||||
|
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
|
||||||
|
head_part_idx * head_elem_num_per_partition;
|
||||||
|
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
|
||||||
|
const int64_t physical_block_idx = seq_block_table[block_idx];
|
||||||
|
const float *__restrict__ prob_vec_ptr =
|
||||||
|
thread_block_logits + block_idx * BLOCK_SIZE;
|
||||||
|
const scalar_t *__restrict__ v_block_cache_ptr =
|
||||||
|
v_cache + physical_block_idx * kv_block_stride +
|
||||||
|
kv_head_idx * kv_head_stride +
|
||||||
|
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
|
||||||
|
reduceValueBlock<scalar_t, HEAD_SIZE, BLOCK_SIZE,
|
||||||
|
head_elem_num_per_partition>(
|
||||||
|
prob_vec_ptr, v_block_cache_ptr, accums);
|
||||||
|
|
||||||
|
if (block_idx != block_num - 1) {
|
||||||
|
const int64_t next_physical_block_idx =
|
||||||
|
seq_block_table[block_idx + 1];
|
||||||
|
const scalar_t *__restrict__ next_v_block_cache_ptr =
|
||||||
|
v_cache + next_physical_block_idx * kv_block_stride +
|
||||||
|
kv_head_idx * kv_head_stride +
|
||||||
|
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
|
||||||
|
vec_op::unroll_loop<int, head_elem_num_per_partition>(
|
||||||
|
[&](int head_elem_idx) {
|
||||||
|
if (head_elem_idx % 2 == 0) {
|
||||||
|
vec_op::prefetch(next_v_block_cache_ptr +
|
||||||
|
BLOCK_SIZE * head_elem_idx);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
vec_op::unroll_loop<int, head_elem_num_per_partition>(
|
||||||
|
[&](int head_elem_idx) {
|
||||||
|
float value = accums[head_elem_idx].reduce_sum();
|
||||||
|
vec_op::storeFP32(value, out_ptr + head_elem_idx);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::free(logits);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Expands to a call of paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call.
// The following names must be in scope at the expansion site: out_ptr,
// query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale,
// block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq,
// alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs and
// num_heads. The expansion already ends with a semicolon.
#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
|
||||||
|
paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \
|
||||||
|
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
|
||||||
|
block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
|
||||||
|
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \
|
||||||
|
num_heads);
|
||||||
|
|
||||||
|
template <typename T, int BLOCK_SIZE>
|
||||||
|
void paged_attention_v1_impl_launcher(
|
||||||
|
torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache,
|
||||||
|
torch::Tensor &value_cache, int num_kv_heads, float scale,
|
||||||
|
torch::Tensor &block_tables, torch::Tensor &context_lens,
|
||||||
|
int max_context_len, const c10::optional<torch::Tensor> &alibi_slopes) {
|
||||||
|
int num_seqs = query.size(0);
|
||||||
|
int num_heads = query.size(1);
|
||||||
|
int head_size = query.size(2);
|
||||||
|
int max_num_blocks_per_seq = block_tables.size(1);
|
||||||
|
int q_stride = query.stride(0);
|
||||||
|
int kv_block_stride = key_cache.stride(0);
|
||||||
|
int kv_head_stride = key_cache.stride(1);
|
||||||
|
|
||||||
|
// NOTE: alibi_slopes is optional.
|
||||||
|
const float *alibi_slopes_ptr =
|
||||||
|
alibi_slopes
|
||||||
|
? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr())
|
||||||
|
: nullptr;
|
||||||
|
|
||||||
|
T *out_ptr = reinterpret_cast<T *>(out.data_ptr());
|
||||||
|
T *query_ptr = reinterpret_cast<T *>(query.data_ptr());
|
||||||
|
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr());
|
||||||
|
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr());
|
||||||
|
int *block_tables_ptr = block_tables.data_ptr<int>();
|
||||||
|
int *context_lens_ptr = context_lens.data_ptr<int>();
|
||||||
|
|
||||||
|
switch (head_size) {
|
||||||
|
case 64:
|
||||||
|
LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
|
||||||
|
break;
|
||||||
|
case 80:
|
||||||
|
LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
|
||||||
|
break;
|
||||||
|
case 96:
|
||||||
|
LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
|
||||||
|
break;
|
||||||
|
case 112:
|
||||||
|
LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
|
||||||
|
break;
|
||||||
|
case 128:
|
||||||
|
LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
|
||||||
|
break;
|
||||||
|
case 256:
|
||||||
|
LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
TORCH_CHECK(false, "Unsupported head size: ", head_size);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expands to a call of paged_attention_v1_impl_launcher<T, BLOCK_SIZE>.
// The following names must be in scope at the expansion site: out, query,
// key_cache, value_cache, num_kv_heads, scale, block_tables, context_lens,
// max_context_len and alibi_slopes. The expansion already ends with a
// semicolon.
#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
|
||||||
|
paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \
|
||||||
|
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
|
||||||
|
context_lens, max_context_len, alibi_slopes);
|
||||||
|
|
||||||
|
// Dispatches on the runtime variable `block_size` (must be in scope at the
// expansion site) to CALL_V1_KERNEL_LAUNCHER. Only a block size of 16 is
// handled; any other value fails a TORCH_CHECK.
#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
|
||||||
|
switch (block_size) { \
|
||||||
|
case 16: \
|
||||||
|
CALL_V1_KERNEL_LAUNCHER(T, 16); \
|
||||||
|
break; \
|
||||||
|
default: \
|
||||||
|
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||||
|
break; \
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Entry point for paged attention v1.
//
// Dispatches on the scalar type of `query` and on `block_size` to
// paged_attention_v1_impl_launcher.
//
// Parameters:
//   out, query, key_cache, value_cache, block_tables, context_lens,
//   num_kv_heads, scale, max_context_len, alibi_slopes
//                  - forwarded unchanged to the launcher.
//   block_size     - only 16 is handled by the dispatch macro.
//   kv_cache_dtype - accepted for interface compatibility; not read here.
//   kv_scale       - must be exactly 1.0; any other value fails a
//                    TORCH_CHECK.
void paged_attention_v1(torch::Tensor &out, torch::Tensor &query,
                        torch::Tensor &key_cache, torch::Tensor &value_cache,
                        int num_kv_heads, float scale,
                        torch::Tensor &block_tables,
                        torch::Tensor &context_lens, int block_size,
                        int max_context_len,
                        const c10::optional<torch::Tensor> &alibi_slopes,
                        const std::string &kv_cache_dtype, float kv_scale) {
  // This path performs no KV-cache scaling, so reject anything other than
  // the identity scale with a message that names the offending value.
  TORCH_CHECK(kv_scale == 1.0f,
              "paged_attention_v1: kv_scale must be 1.0, got ", kv_scale);
  VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
                               [&] {
                                 CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
                                 CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t);
                                 CPU_KERNEL_GUARD_OUT(paged_attention_v1_impl)
                               });
}
|
||||||
|
|
||||||
|
// Paged attention v2
|
||||||
|
namespace {
|
||||||
|
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int PARTITION_SIZE>
|
||||||
|
struct paged_attention_v2_impl {
|
||||||
|
static void call(
|
||||||
|
scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size]
|
||||||
|
float *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||||
|
float
|
||||||
|
*__restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
|
||||||
|
scalar_t *__restrict__ tmp_out, // [num_seqs, num_heads,
|
||||||
|
// max_num_partitions, head_size]
|
||||||
|
const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size]
|
||||||
|
const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads,
|
||||||
|
// head_size/x, block_size, x]
|
||||||
|
const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads,
|
||||||
|
// head_size, block_size]
|
||||||
|
const int num_kv_heads, const float scale,
|
||||||
|
const int
|
||||||
|
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||||
|
const int *__restrict__ context_lens, // [num_seqs]
|
||||||
|
const int max_num_blocks_per_seq,
|
||||||
|
const float *__restrict__ alibi_slopes, // [num_heads]
|
||||||
|
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
||||||
|
const int num_seqs, const int num_heads, const int max_num_partitions) {
|
||||||
|
constexpr int x = 16 / sizeof(scalar_t);
|
||||||
|
const int num_queries_per_kv = num_heads / num_kv_heads;
|
||||||
|
|
||||||
|
static_assert(BLOCK_SIZE == 16);
|
||||||
|
static_assert(PARTITION_SIZE * sizeof(float) % 64 == 0);
|
||||||
|
static_assert(PARTITION_SIZE % BLOCK_SIZE == 0);
|
||||||
|
|
||||||
|
#pragma omp parallel for collapse(3) schedule(static, 1)
|
||||||
|
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
|
||||||
|
for (int partition_idx = 0; partition_idx < max_num_partitions;
|
||||||
|
++partition_idx) {
|
||||||
|
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
|
||||||
|
const int context_len = context_lens[seq_idx];
|
||||||
|
const int start_token_idx = partition_idx * PARTITION_SIZE;
|
||||||
|
|
||||||
|
if (start_token_idx >= context_len)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
const int partition_num =
|
||||||
|
(context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
|
||||||
|
const bool no_reduce = (partition_num == 1);
|
||||||
|
const int context_token_num =
|
||||||
|
(std::min(context_len, start_token_idx + PARTITION_SIZE) -
|
||||||
|
start_token_idx);
|
||||||
|
const int block_num =
|
||||||
|
(context_token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
|
const int last_block_token_num =
|
||||||
|
context_token_num - (block_num - 1) * BLOCK_SIZE;
|
||||||
|
const int *seq_block_table = block_tables +
|
||||||
|
max_num_blocks_per_seq * seq_idx +
|
||||||
|
start_token_idx / BLOCK_SIZE;
|
||||||
|
const int64_t kv_head_idx = head_idx / num_queries_per_kv;
|
||||||
|
const scalar_t *__restrict__ q_vec_ptr =
|
||||||
|
q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
||||||
|
|
||||||
|
float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0};
|
||||||
|
|
||||||
|
// Compute logits
|
||||||
|
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
|
||||||
|
const int64_t physical_block_idx = seq_block_table[block_idx];
|
||||||
|
const scalar_t *__restrict__ k_block_cache_ptr =
|
||||||
|
k_cache + physical_block_idx * kv_block_stride +
|
||||||
|
kv_head_idx * kv_head_stride;
|
||||||
|
float *__restrict__ head_block_logits =
|
||||||
|
logits + block_idx * BLOCK_SIZE;
|
||||||
|
|
||||||
|
reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
|
||||||
|
q_vec_ptr, k_block_cache_ptr, head_block_logits, scale,
|
||||||
|
block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<float, float> max_and_sum;
|
||||||
|
if (alibi_slopes) {
|
||||||
|
max_and_sum = reduceSoftmaxAlibi(
|
||||||
|
logits, context_token_num, block_num * BLOCK_SIZE,
|
||||||
|
alibi_slopes[head_idx], start_token_idx, context_len);
|
||||||
|
} else {
|
||||||
|
max_and_sum = reduceSoftmax(logits, context_token_num,
|
||||||
|
block_num * BLOCK_SIZE);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto &&[max_logit, exp_sum] = max_and_sum;
|
||||||
|
|
||||||
|
scalar_t *__restrict__ output_buffer = nullptr;
|
||||||
|
if (!no_reduce) {
|
||||||
|
auto idx = seq_idx * num_heads * max_num_partitions +
|
||||||
|
head_idx * max_num_partitions + partition_idx;
|
||||||
|
max_logits[idx] = max_logit;
|
||||||
|
exp_sums[idx] = exp_sum;
|
||||||
|
output_buffer =
|
||||||
|
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
|
||||||
|
head_idx * max_num_partitions * HEAD_SIZE +
|
||||||
|
partition_idx * HEAD_SIZE;
|
||||||
|
} else {
|
||||||
|
output_buffer =
|
||||||
|
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute value
|
||||||
|
constexpr int head_elem_num_per_partition = 16;
|
||||||
|
constexpr int head_partition_num =
|
||||||
|
HEAD_SIZE / head_elem_num_per_partition;
|
||||||
|
for (int head_part_idx = 0; head_part_idx < head_partition_num;
|
||||||
|
++head_part_idx) {
|
||||||
|
vec_op::FP32Vec16 accums[head_elem_num_per_partition];
|
||||||
|
scalar_t *__restrict__ out_ptr =
|
||||||
|
output_buffer + head_part_idx * head_elem_num_per_partition;
|
||||||
|
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
|
||||||
|
const int64_t physical_block_idx = seq_block_table[block_idx];
|
||||||
|
const float *__restrict__ prob_vec_ptr =
|
||||||
|
logits + block_idx * BLOCK_SIZE;
|
||||||
|
const scalar_t *__restrict__ v_block_cache_ptr =
|
||||||
|
v_cache + physical_block_idx * kv_block_stride +
|
||||||
|
kv_head_idx * kv_head_stride +
|
||||||
|
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
|
||||||
|
reduceValueBlock<scalar_t, HEAD_SIZE, BLOCK_SIZE,
|
||||||
|
head_elem_num_per_partition>(
|
||||||
|
prob_vec_ptr, v_block_cache_ptr, accums);
|
||||||
|
|
||||||
|
if (block_idx != block_num - 1) {
|
||||||
|
const int64_t next_physical_block_idx =
|
||||||
|
seq_block_table[block_idx + 1];
|
||||||
|
const scalar_t *__restrict__ next_v_block_cache_ptr =
|
||||||
|
v_cache + next_physical_block_idx * kv_block_stride +
|
||||||
|
kv_head_idx * kv_head_stride +
|
||||||
|
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
|
||||||
|
vec_op::unroll_loop<int, head_elem_num_per_partition>(
|
||||||
|
[&](int head_elem_idx) {
|
||||||
|
if (head_elem_idx % 2 == 0) {
|
||||||
|
vec_op::prefetch(next_v_block_cache_ptr +
|
||||||
|
BLOCK_SIZE * head_elem_idx);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
vec_op::unroll_loop<int, head_elem_num_per_partition>(
|
||||||
|
[&](int head_elem_idx) {
|
||||||
|
float value = accums[head_elem_idx].reduce_sum();
|
||||||
|
vec_op::storeFP32(value, out_ptr + head_elem_idx);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rescale partition softmax and store the factors to exp_sums
|
||||||
|
#pragma omp parallel for collapse(2) schedule(static, 1)
|
||||||
|
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
|
||||||
|
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
|
||||||
|
const int context_len = context_lens[seq_idx];
|
||||||
|
const int partition_num =
|
||||||
|
(context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
|
||||||
|
|
||||||
|
if (partition_num == 1)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
reducePartitonSoftmax(
|
||||||
|
max_logits + seq_idx * num_heads * max_num_partitions +
|
||||||
|
head_idx * max_num_partitions,
|
||||||
|
exp_sums + seq_idx * num_heads * max_num_partitions +
|
||||||
|
head_idx * max_num_partitions,
|
||||||
|
partition_num);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reduce values
|
||||||
|
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
|
||||||
|
static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE);
|
||||||
|
constexpr int head_elem_num_per_group =
|
||||||
|
16; // Note: didn't align with the cacheline size, due to some HEAD_SIZE
|
||||||
|
// didn't align with 64 bytes
|
||||||
|
static_assert(HEAD_SIZE % head_elem_num_per_group == 0);
|
||||||
|
constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group;
|
||||||
|
const float *__restrict__ rescale_factors = exp_sums;
|
||||||
|
#pragma omp parallel for collapse(3) schedule(static, 1)
|
||||||
|
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
|
||||||
|
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
|
||||||
|
for (int group_idx = 0; group_idx < head_group_num; ++group_idx) {
|
||||||
|
const int context_len = context_lens[seq_idx];
|
||||||
|
const int partition_num =
|
||||||
|
(context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
|
||||||
|
|
||||||
|
if (partition_num == 1)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
const float *__restrict__ seq_head_rescale_factors =
|
||||||
|
rescale_factors + seq_idx * num_heads * max_num_partitions +
|
||||||
|
head_idx * max_num_partitions;
|
||||||
|
const scalar_t *__restrict__ seq_head_tmp_out =
|
||||||
|
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
|
||||||
|
head_idx * max_num_partitions * HEAD_SIZE +
|
||||||
|
group_idx * head_elem_num_per_group;
|
||||||
|
scalar_t *__restrict__ seq_head_output =
|
||||||
|
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
|
||||||
|
group_idx * head_elem_num_per_group;
|
||||||
|
|
||||||
|
vec_op::FP32Vec16 acc;
|
||||||
|
for (int i = 0; i < partition_num; ++i) {
|
||||||
|
vec_op::FP32Vec16 rescale_factor(seq_head_rescale_factors[i]);
|
||||||
|
v_load_vec_type value(seq_head_tmp_out + i * HEAD_SIZE);
|
||||||
|
vec_op::FP32Vec16 fp32_value(value);
|
||||||
|
acc = acc + fp32_value * rescale_factor;
|
||||||
|
}
|
||||||
|
v_load_vec_type cast_acc(acc);
|
||||||
|
cast_acc.save(seq_head_output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Expands to a call of paged_attention_v2_impl<...>::call for the given
// scalar type, head size and block size. The macro reads PARTITION_SIZE and
// every *_ptr / stride / count variable from the scope it is expanded in.
#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE)                   \
  paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call(     \
      out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr,                      \
      query_ptr, key_cache_ptr, value_cache_ptr,                               \
      num_kv_heads, scale,                                                     \
      block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq,              \
      alibi_slopes_ptr,                                                        \
      q_stride, kv_block_stride, kv_head_stride,                               \
      num_seqs, num_heads, max_num_partitions);
|
||||||
|
|
||||||
|
template <typename T, int BLOCK_SIZE, int PARTITION_SIZE = 512>
|
||||||
|
void paged_attention_v2_impl_launcher(
|
||||||
|
torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits,
|
||||||
|
torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache,
|
||||||
|
torch::Tensor &value_cache, int num_kv_heads, float scale,
|
||||||
|
torch::Tensor &block_tables, torch::Tensor &context_lens, int block_size,
|
||||||
|
int max_context_len, const c10::optional<torch::Tensor> &alibi_slopes) {
|
||||||
|
int num_seqs = query.size(0);
|
||||||
|
int num_heads = query.size(1);
|
||||||
|
int head_size = query.size(2);
|
||||||
|
int max_num_blocks_per_seq = block_tables.size(1);
|
||||||
|
int q_stride = query.stride(0);
|
||||||
|
int kv_block_stride = key_cache.stride(0);
|
||||||
|
int kv_head_stride = key_cache.stride(1);
|
||||||
|
int max_num_partitions = exp_sums.size(-1);
|
||||||
|
|
||||||
|
// NOTE: alibi_slopes is optional.
|
||||||
|
const float *alibi_slopes_ptr =
|
||||||
|
alibi_slopes
|
||||||
|
? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr())
|
||||||
|
: nullptr;
|
||||||
|
|
||||||
|
T *out_ptr = reinterpret_cast<T *>(out.data_ptr());
|
||||||
|
float *exp_sums_ptr = reinterpret_cast<float *>(exp_sums.data_ptr());
|
||||||
|
float *max_logits_ptr = reinterpret_cast<float *>(max_logits.data_ptr());
|
||||||
|
T *tmp_out_ptr = reinterpret_cast<T *>(tmp_out.data_ptr());
|
||||||
|
T *query_ptr = reinterpret_cast<T *>(query.data_ptr());
|
||||||
|
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr());
|
||||||
|
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr());
|
||||||
|
int *block_tables_ptr = block_tables.data_ptr<int>();
|
||||||
|
int *context_lens_ptr = context_lens.data_ptr<int>();
|
||||||
|
|
||||||
|
switch (head_size) {
|
||||||
|
case 64:
|
||||||
|
LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
|
||||||
|
break;
|
||||||
|
case 80:
|
||||||
|
LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
|
||||||
|
break;
|
||||||
|
case 96:
|
||||||
|
LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
|
||||||
|
break;
|
||||||
|
case 112:
|
||||||
|
LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
|
||||||
|
break;
|
||||||
|
case 128:
|
||||||
|
LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
|
||||||
|
break;
|
||||||
|
case 256:
|
||||||
|
LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
TORCH_CHECK(false, "Unsupported head size: ", head_size);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forwards the arguments of paged_attention_v2 (read from the expansion
// scope) to the launcher instantiated for scalar type T and BLOCK_SIZE.
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE)                                 \
  paged_attention_v2_impl_launcher<T, BLOCK_SIZE>(                             \
      out, exp_sums, max_logits, tmp_out,                                      \
      query, key_cache, value_cache,                                           \
      num_kv_heads, scale,                                                     \
      block_tables, context_lens,                                              \
      block_size, max_context_len, alibi_slopes);

// Maps the runtime `block_size` onto a compile-time BLOCK_SIZE. Only 16 is
// handled; any other value fails through TORCH_CHECK.
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T)                                  \
  switch (block_size) {                                                        \
  case 16:                                                                     \
    CALL_V2_KERNEL_LAUNCHER(T, 16);                                            \
    break;                                                                     \
  default:                                                                     \
    TORCH_CHECK(false, "Unsupported block size: ", block_size);                \
    break;                                                                     \
  }
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Entry point of the CPU paged-attention v2 op.
//
// Rejects any kv_scale other than 1.0f, then dispatches on the scalar type
// of `query` (Float or BFloat16, per VLLM_DISPATCH_FLOATING_TYPES) and on
// `block_size` to the templated launcher. `kv_cache_dtype` is not read here.
void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums,
                        torch::Tensor &max_logits, torch::Tensor &tmp_out,
                        torch::Tensor &query, torch::Tensor &key_cache,
                        torch::Tensor &value_cache, int num_kv_heads,
                        float scale, torch::Tensor &block_tables,
                        torch::Tensor &context_lens, int block_size,
                        int max_context_len,
                        const c10::optional<torch::Tensor> &alibi_slopes,
                        const std::string &kv_cache_dtype, float kv_scale) {
  TORCH_CHECK(kv_scale == 1.0f);

  VLLM_DISPATCH_FLOATING_TYPES(
      query.scalar_type(), "paged_attention_v2_impl", [&] {
        // The guard macros expand to nothing unless CPU_OP_GUARD is defined.
        CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)
        CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t);
        CPU_KERNEL_GUARD_OUT(paged_attention_v2_impl)
      });
}
|
||||||
141
csrc/cpu/cache.cpp
Normal file
141
csrc/cpu/cache.cpp
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "cpu_types.hpp"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template <typename scalar_t>
|
||||||
|
void copy_blocks_cpu_impl(
|
||||||
|
std::vector<torch::Tensor> &key_caches,
|
||||||
|
std::vector<torch::Tensor> &value_caches,
|
||||||
|
const std::vector<std::pair<int64_t, int64_t>> mapping_pairs,
|
||||||
|
const int element_num_per_block, const int layer_num) {
|
||||||
|
const size_t pair_num = mapping_pairs.size();
|
||||||
|
const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
|
||||||
|
#pragma omp parallel for collapse(2)
|
||||||
|
for (int layer = 0; layer < layer_num; ++layer) {
|
||||||
|
for (size_t pair = 0; pair < pair_num; ++pair) {
|
||||||
|
int64_t source_offset = element_num_per_block * mapping_pairs[pair].first;
|
||||||
|
int64_t target_offset =
|
||||||
|
element_num_per_block * mapping_pairs[pair].second;
|
||||||
|
scalar_t *key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
|
||||||
|
scalar_t *source_ptr = key_cache_ptr + source_offset;
|
||||||
|
scalar_t *target_ptr = key_cache_ptr + target_offset;
|
||||||
|
std::memcpy(target_ptr, source_ptr, block_bytes);
|
||||||
|
|
||||||
|
scalar_t *value_cache_ptr = value_caches[layer].data_ptr<scalar_t>();
|
||||||
|
source_ptr = value_cache_ptr + source_offset;
|
||||||
|
target_ptr = value_cache_ptr + target_offset;
|
||||||
|
std::memcpy(target_ptr, source_ptr, block_bytes);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scatters the key/value vectors of each token into the paged caches.
//
// For every (token, head) pair whose slot index is non-negative:
//   * the slot index is split into block_index = slot / block_size and
//     block_offset = slot % block_size;
//   * the key head vector is written in chunks of `x` elements: chunk
//     starting at element k goes to k * block_size + block_offset * x;
//   * value element v goes to v * block_size + block_offset.
// Tokens with a negative slot index are skipped.
//
// key / value:          source tensors; row `t` starts at t * key_stride
//                       (resp. t * value_stride), head `h` at h * head_size.
// key_cache/value_cache: destination caches; one block holds
//                       num_heads * head_size * block_size elements.
// slot_mapping:         one slot index per token.
// x:                    chunk width of the key-cache layout. The inner loop
//                       copies `x` elements per step, so head_size is
//                       presumably a multiple of x -- TODO confirm at caller.
//
// All source and destination offsets are computed in 64 bits so that
// token_idx * stride cannot overflow a 32-bit int for large inputs.
template <typename scalar_t>
void reshape_and_cache_cpu_impl(
    const scalar_t *__restrict__ key, const scalar_t *__restrict__ value,
    scalar_t *__restrict__ key_cache, scalar_t *__restrict__ value_cache,
    const int64_t *__restrict__ slot_mapping, const int num_tokens,
    const int key_stride, const int value_stride, const int num_heads,
    const int head_size, const int block_size, const int x) {
  const int64_t block_elem_num =
      static_cast<int64_t>(num_heads) * head_size * block_size;

#pragma omp parallel for collapse(2)
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
      const int64_t slot_idx = slot_mapping[token_idx];
      if (slot_idx >= 0) {
        const int64_t head_offset =
            static_cast<int64_t>(head_idx) * head_size;
        const int64_t src_key_head_idx =
            static_cast<int64_t>(token_idx) * key_stride + head_offset;
        const int64_t src_value_head_idx =
            static_cast<int64_t>(token_idx) * value_stride + head_offset;
        const scalar_t *src_key_head_ptr = key + src_key_head_idx;
        const scalar_t *src_value_head_ptr = value + src_value_head_idx;

        const int64_t block_index = slot_idx / block_size;
        const int64_t block_offset = slot_idx % block_size;

        // Key and value caches share the same block/head base offset.
        const int64_t cache_head_offset =
            block_elem_num * block_index + head_offset * block_size;
        scalar_t *target_key_head_ptr = key_cache + cache_head_offset;
        scalar_t *target_value_head_ptr = value_cache + cache_head_offset;

        // Keys: copy x contiguous elements at a time.
        for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) {
          const int64_t target_offset =
              static_cast<int64_t>(src_key_idx) * block_size +
              block_offset * x;
          for (int i = 0; i < x; ++i) {
            target_key_head_ptr[target_offset + i] =
                src_key_head_ptr[src_key_idx + i];
          }
        }

        // Values: one element per head dimension, strided by block_size.
        for (int src_value_idx = 0; src_value_idx < head_size;
             ++src_value_idx) {
          const int64_t target_offset =
              static_cast<int64_t>(src_value_idx) * block_size + block_offset;
          target_value_head_ptr[target_offset] =
              src_value_head_ptr[src_value_idx];
        }
      }
    }
  }
}
|
||||||
|
}; // namespace
|
||||||
|
|
||||||
|
// Copies cache blocks according to `block_mapping` in every layer.
//
// `block_mapping` maps a source block index to the list of destination
// block indices that should receive a copy of it. The mapping is flattened
// into (source, destination) pairs and handed to copy_blocks_cpu_impl,
// instantiated for the scalar type of the first key cache. Returns without
// doing anything when there are no layers.
void copy_blocks(std::vector<torch::Tensor> &key_caches,
                 std::vector<torch::Tensor> &value_caches,
                 const std::map<int64_t, std::vector<int64_t>> &block_mapping) {
  int num_layers = key_caches.size();
  TORCH_CHECK(num_layers == value_caches.size());
  if (num_layers == 0) {
    return;
  }

  // Flatten {src -> [dst...]} into a list of (src, dst) pairs.
  std::vector<std::pair<int64_t, int64_t>> flat_pairs;
  flat_pairs.reserve(block_mapping.size());
  for (auto it = block_mapping.begin(); it != block_mapping.end(); ++it) {
    const int64_t src_block = it->first;
    for (const auto &dst_block : it->second) {
      flat_pairs.emplace_back(src_block, dst_block);
    }
  }

  // Number of elements in one block, taken from the first block of the
  // first layer's key cache.
  const int block_numel = key_caches[0][0].numel();

  VLLM_DISPATCH_FLOATING_TYPES(
      key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
        CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
        copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, flat_pairs,
                                       block_numel, num_layers);
        CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
      });
}
|
||||||
|
|
||||||
|
// Writes the per-token key/value tensors into the paged key/value caches.
//
// Sizes come from `key` (dims 0..2) and from `key_cache` (dims 3 and 4);
// the token strides come from dim 0 of `key` and `value`. Only
// kv_scale == 1.0f is accepted; `kv_cache_dtype` is not read here.
void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
                       torch::Tensor &key_cache, torch::Tensor &value_cache,
                       torch::Tensor &slot_mapping,
                       const std::string &kv_cache_dtype, float kv_scale) {
  TORCH_CHECK(kv_scale == 1.0f);

  // Strides between consecutive tokens in the source tensors.
  int key_token_stride = key.stride(0);
  int value_token_stride = value.stride(0);

  // Logical sizes of the source tensor.
  int token_count = key.size(0);
  int head_count = key.size(1);
  int head_dim = key.size(2);

  // Cache layout parameters read from the key cache.
  int cache_block_size = key_cache.size(3);
  int key_chunk_width = key_cache.size(4);

  VLLM_DISPATCH_FLOATING_TYPES(
      key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
        CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
        reshape_and_cache_cpu_impl<scalar_t>(
            key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
            key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
            slot_mapping.data_ptr<int64_t>(), token_count, key_token_stride,
            value_token_stride, head_count, head_dim, cache_block_size,
            key_chunk_width);
        CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
      });
}
|
||||||
|
|
||||||
|
// Block swapping is not implemented for the CPU backend: every call fails
// through TORCH_CHECK with the message below. None of the arguments are read.
void swap_blocks(torch::Tensor &src, torch::Tensor &dst,
                 const std::map<int64_t, int64_t> &block_mapping) {
  // Terminate the statement explicitly instead of relying on the macro
  // expansion happening to form a complete statement.
  TORCH_CHECK(false, "swap_blocks is unsupported on CPU.");
}
|
||||||
352
csrc/cpu/cpu_types.hpp
Normal file
352
csrc/cpu/cpu_types.hpp
Normal file
@@ -0,0 +1,352 @@
|
|||||||
|
|
||||||
|
#ifndef CPU_TYPES_HPP
|
||||||
|
#define CPU_TYPES_HPP
|
||||||
|
|
||||||
|
#include <immintrin.h>
|
||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
namespace vec_op {
|
||||||
|
|
||||||
|
// FIXME: FP16 is not fully supported in Torch-CPU
// Dispatch cases for the scalar types handled by the CPU kernels: Float and
// BFloat16.
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)                                 \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)                         \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

// Runs the trailing callable with `scalar_t` bound to the C++ type matching
// TYPE; NAME is passed through to AT_DISPATCH_SWITCH.
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)                          \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

// Optional kernel tracing: when CPU_OP_GUARD is defined, the guards print
// "<NAME> invoked." / "<NAME> exit." to std::cout; otherwise they expand to
// nothing.
#ifdef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME)                                              \
  std::cout << #NAME << " invoked." << std::endl;
#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl;
#else
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
#endif

// Function specifier requesting unconditional inlining (GCC/Clang attribute).
#define FORCE_INLINE __attribute__((always_inline)) inline
|
||||||
|
|
||||||
|
namespace {
// Calls `body` once for each value in the index pack, in ascending pack
// order, passing the index as a std::integral_constant.
template <typename Index, Index... Is, typename Body>
constexpr void unroll_loop_item(std::integer_sequence<Index, Is...>,
                                Body &&body) {
  // Comma fold: expands to body(c0), body(c1), ... evaluated left to right.
  (body(std::integral_constant<Index, Is>{}), ...);
}
} // namespace

// Compile-time unrolled loop: invokes `f(i)` for i = 0 .. count - 1.
// Participates in overload resolution only if `f` is callable with a T.
template <typename T, T count, typename F,
          typename = std::enable_if_t<std::is_invocable_v<F, T>>>
constexpr void unroll_loop(F &&f) {
  using index_pack = std::make_integer_sequence<T, count>;
  unroll_loop_item(index_pack{}, std::forward<F>(f));
}
|
||||||
|
|
||||||
|
// CRTP base for the vector wrappers: exposes the lane count that the derived
// type declares as its static member VEC_ELEM_NUM.
template <typename Derived> struct Vec {
  static constexpr int get_elem_num() { return Derived::VEC_ELEM_NUM; }
};
|
||||||
|
|
||||||
|
struct FP32Vec8;
|
||||||
|
struct FP32Vec16;
|
||||||
|
|
||||||
|
#ifdef __AVX512FP16__
// Eight half-precision lanes held in one __m128h register. Only compiled
// when the compiler targets AVX512-FP16.
struct FP16Vec8 : public Vec<FP16Vec8> {
  constexpr static int VEC_ELEM_NUM = 8;

  __m128h reg;

  // Wraps an existing register.
  explicit FP16Vec8(__m128h data) : reg(data) {}

  // Broadcasts one scalar to all lanes.
  explicit FP16Vec8(_Float16 v) : reg(_mm_set1_ph(v)) {}

  // Unaligned load of eight halves from `ptr`.
  explicit FP16Vec8(const void *ptr) : reg(_mm_loadu_ph(ptr)) {}

  // Lane-wise arithmetic.
  FP16Vec8 operator+(const FP16Vec8 &rhs) const {
    return FP16Vec8(_mm_add_ph(reg, rhs.reg));
  }

  FP16Vec8 operator-(const FP16Vec8 &rhs) const {
    return FP16Vec8(_mm_sub_ph(reg, rhs.reg));
  }

  FP16Vec8 operator*(const FP16Vec8 &rhs) const {
    return FP16Vec8(_mm_mul_ph(reg, rhs.reg));
  }

  FP16Vec8 operator/(const FP16Vec8 &rhs) const {
    return FP16Vec8(_mm_div_ph(reg, rhs.reg));
  }

  // Unaligned store of all eight lanes to `ptr`.
  void save(void *ptr) const { _mm_storeu_ph(ptr, reg); }
};
#endif
|
||||||
|
|
||||||
|
// Eight bfloat16 lanes stored as raw 16-bit patterns in one __m128i.
struct BF16Vec8 : public Vec<BF16Vec8> {
  constexpr static int VEC_ELEM_NUM = 8;

  __m128i reg;

  // Unaligned load of 16 bytes from `ptr` (const-correct: the source is
  // never written).
  explicit BF16Vec8(const void *ptr)
      : reg(_mm_loadu_si128(reinterpret_cast<const __m128i *>(ptr))) {}

  // Conversion from eight fp32 lanes; defined after FP32Vec8 is complete.
  explicit BF16Vec8(const FP32Vec8 &);

  // Unaligned store of all eight lanes. Dereferencing a __m128i* would
  // require `ptr` to be 16-byte aligned, which callers do not guarantee and
  // which the matching load above does not require either.
  void save(void *ptr) const {
    _mm_storeu_si128(reinterpret_cast<__m128i *>(ptr), reg);
  }
};
|
||||||
|
|
||||||
|
// Sixteen bfloat16 lanes stored as raw 16-bit patterns in one __m256i.
struct BF16Vec16 : public Vec<BF16Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;

  __m256i reg;

  // Unaligned load of 32 bytes from `ptr` (const-correct: the source is
  // never written).
  explicit BF16Vec16(const void *ptr)
      : reg(_mm256_loadu_si256(reinterpret_cast<const __m256i *>(ptr))) {}

  // Conversion from sixteen fp32 lanes; defined after FP32Vec16 is complete.
  explicit BF16Vec16(const FP32Vec16 &);

  // Unaligned store of all sixteen lanes. Dereferencing a __m256i* would
  // require `ptr` to be 32-byte aligned, which the matching load above does
  // not require.
  void save(void *ptr) const {
    _mm256_storeu_si256(reinterpret_cast<__m256i *>(ptr), reg);
  }
};
|
||||||
|
|
||||||
|
// Thirty-two bfloat16 lanes stored as raw 16-bit patterns in one __m512i.
struct BF16Vec32 : public Vec<BF16Vec32> {
  constexpr static int VEC_ELEM_NUM = 32;

  __m512i reg;

  // Unaligned load of 64 bytes from `ptr`.
  explicit BF16Vec32(const void *ptr) : reg(_mm512_loadu_si512(ptr)) {}

  // Wraps an existing register.
  explicit BF16Vec32(__m512i data) : reg(data) {}

  // Replicates the 128-bit contents of `vec8_data` into all four 128-bit
  // lanes of the register. Takes a const reference because the argument is
  // only read; this also lets temporaries and const objects be passed.
  explicit BF16Vec32(const BF16Vec8 &vec8_data)
      : reg(_mm512_inserti32x4(
            _mm512_inserti32x4(
                _mm512_inserti32x4(_mm512_castsi128_si512(vec8_data.reg),
                                   vec8_data.reg, 1),
                vec8_data.reg, 2),
            vec8_data.reg, 3)) {}

  // Unaligned store of all thirty-two lanes. Dereferencing a __m512i* would
  // require `ptr` to be 64-byte aligned, which the matching load above does
  // not require.
  void save(void *ptr) const { _mm512_storeu_si512(ptr, reg); }
};
|
||||||
|
|
||||||
|
// Four fp32 lanes held in one __m128 register.
struct FP32Vec4 : public Vec<FP32Vec4> {
  constexpr static int VEC_ELEM_NUM = 4;

  // Gives per-lane scalar access to a register value.
  union AliasReg {
    __m128 reg;
    float values[VEC_ELEM_NUM];
  };

  __m128 reg;

  // All lanes zero.
  explicit FP32Vec4() : reg(_mm_setzero_ps()) {}

  // Broadcasts `v` to all lanes.
  explicit FP32Vec4(float v) : reg(_mm_set1_ps(v)) {}

  // Unaligned load of four floats.
  explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {}

  // Wraps an existing register.
  explicit FP32Vec4(__m128 data) : reg(data) {}

  // Copy; explicit, so only direct-initialization copies.
  explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {}
};
|
||||||
|
|
||||||
|
// Eight fp32 lanes held in one __m256 register, with lane-wise arithmetic,
// a horizontal sum and scalar-fallback math functions.
struct FP32Vec8 : public Vec<FP32Vec8> {
  constexpr static int VEC_ELEM_NUM = 8;

  // Gives per-lane scalar access to a register value.
  union AliasReg {
    __m256 reg;
    float values[VEC_ELEM_NUM];
  };

  __m256 reg;

  // All lanes zero.
  explicit FP32Vec8() : reg(_mm256_setzero_ps()) {}

  // Broadcasts `v` to all lanes.
  explicit FP32Vec8(float v) : reg(_mm256_set1_ps(v)) {}

  // Unaligned load of eight floats.
  explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {}

  // Wraps an existing register.
  explicit FP32Vec8(__m256 data) : reg(data) {}

  // Copy; explicit, so only direct-initialization copies.
  explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {}

#ifdef __AVX512FP16__
  // Widens eight half-precision lanes to fp32.
  explicit FP32Vec8(__m128h v) : reg(_mm256_cvtph_ps(_mm_castph_si128(v))) {}
#endif

  // Widens eight bf16 lanes: each 16-bit pattern is zero-extended to 32 bits
  // and then moved into the upper half of its 32-bit lane by a 2-byte shift.
  explicit FP32Vec8(const BF16Vec8 &v)
      : reg(_mm256_castsi256_ps(
            _mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {}

  // Sum of all eight lanes, accumulated from lane 0 upwards.
  float reduce_sum() const {
    AliasReg lanes;
    lanes.reg = reg;
    float total = 0;
    unroll_loop<int, VEC_ELEM_NUM>(
        [&total, &lanes](int lane) { total += lanes.values[lane]; });
    return total;
  }

  // Lane-wise expf, evaluated with the scalar library routine.
  FP32Vec8 exp() const {
    AliasReg lanes;
    lanes.reg = reg;
    for (int lane = 0; lane < VEC_ELEM_NUM; ++lane) {
      lanes.values[lane] = expf(lanes.values[lane]);
    }
    return FP32Vec8(lanes.reg);
  }

  // Lane-wise tanhf, evaluated with the scalar library routine.
  FP32Vec8 tanh() const {
    AliasReg lanes;
    lanes.reg = reg;
    for (int lane = 0; lane < VEC_ELEM_NUM; ++lane) {
      lanes.values[lane] = tanhf(lanes.values[lane]);
    }
    return FP32Vec8(lanes.reg);
  }

  // Lane-wise erf, evaluated with the scalar library routine.
  FP32Vec8 er() const {
    AliasReg lanes;
    lanes.reg = reg;
    for (int lane = 0; lane < VEC_ELEM_NUM; ++lane) {
      lanes.values[lane] = erf(lanes.values[lane]);
    }
    return FP32Vec8(lanes.reg);
  }

  // Lane-wise arithmetic.
  FP32Vec8 operator+(const FP32Vec8 &rhs) const {
    return FP32Vec8(_mm256_add_ps(reg, rhs.reg));
  }

  FP32Vec8 operator-(const FP32Vec8 &rhs) const {
    return FP32Vec8(_mm256_sub_ps(reg, rhs.reg));
  }

  FP32Vec8 operator*(const FP32Vec8 &rhs) const {
    return FP32Vec8(_mm256_mul_ps(reg, rhs.reg));
  }

  FP32Vec8 operator/(const FP32Vec8 &rhs) const {
    return FP32Vec8(_mm256_div_ps(reg, rhs.reg));
  }

  // Unaligned store of all eight lanes.
  void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); }
};
|
||||||
|
|
||||||
|
// Sixteen fp32 lanes held in one __m512 register.
struct FP32Vec16 : public Vec<FP32Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;

  // Gives per-lane scalar access to a register value.
  union AliasReg {
    __m512 reg;
    float values[VEC_ELEM_NUM];
  };

  __m512 reg;

  // All lanes zero.
  explicit FP32Vec16() : reg(_mm512_setzero_ps()) {}

  // Broadcasts `v` to all lanes.
  explicit FP32Vec16(float v) : reg(_mm512_set1_ps(v)) {}

  // Unaligned load of sixteen floats.
  explicit FP32Vec16(const float *ptr) : reg(_mm512_loadu_ps(ptr)) {}

  // Wraps an existing register.
  explicit FP32Vec16(__m512 data) : reg(data) {}

  // Copy; explicit, so only direct-initialization copies.
  explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {}

  // Replicates the four lanes of `data` into each 128-bit quarter.
  explicit FP32Vec16(const FP32Vec4 &data)
      : reg((__m512)_mm512_inserti32x4(
            _mm512_inserti32x4(
                _mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg),
                                   (__m128i)data.reg, 1),
                (__m128i)data.reg, 2),
            (__m128i)data.reg, 3)) {}

  // Replicates the eight lanes of `data` into both 256-bit halves.
  explicit FP32Vec16(const FP32Vec8 &data)
      : reg((__m512)_mm512_inserti32x8(
            _mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {}

  // Widens sixteen bf16 lanes: each 16-bit pattern is zero-extended to 32
  // bits and then moved into the upper half of its lane by a 2-byte shift.
  explicit FP32Vec16(const BF16Vec16 &v)
      : reg(_mm512_castsi512_ps(
            _mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {}

  // Widens eight bf16 lanes to fp32, then duplicates them into both halves.
  explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}

  // Lane-wise arithmetic.
  FP32Vec16 operator+(const FP32Vec16 &rhs) const {
    return FP32Vec16(_mm512_add_ps(reg, rhs.reg));
  }

  FP32Vec16 operator-(const FP32Vec16 &rhs) const {
    return FP32Vec16(_mm512_sub_ps(reg, rhs.reg));
  }

  FP32Vec16 operator*(const FP32Vec16 &rhs) const {
    return FP32Vec16(_mm512_mul_ps(reg, rhs.reg));
  }

  FP32Vec16 operator/(const FP32Vec16 &rhs) const {
    return FP32Vec16(_mm512_div_ps(reg, rhs.reg));
  }

  // Sum of all sixteen lanes.
  float reduce_sum() const { return _mm512_reduce_add_ps(reg); }

  // Sum of the `group_size` consecutive lanes that form group number `idx`
  // (lanes idx * group_size .. idx * group_size + group_size - 1).
  template <int group_size> float reduce_sub_sum(int idx) {
    static_assert(VEC_ELEM_NUM % group_size == 0);
    // Low `group_size` bits set; shifted below to select the wanted group.
    constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
    const uint32_t group_bits = base_mask << (idx * group_size);
    __mmask16 mask = _cvtu32_mask16(group_bits);
    return _mm512_mask_reduce_add_ps(mask, reg);
  }

  // Unaligned store of all sixteen lanes.
  void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); }
};
|
||||||
|
|
||||||
|
template <typename T> struct VecType { using vec_type = void; };
|
||||||
|
|
||||||
|
template <typename T> using vec_t = typename VecType<T>::vec_type;
|
||||||
|
|
||||||
|
template <> struct VecType<float> { using vec_type = FP32Vec8; };
|
||||||
|
|
||||||
|
#ifdef __AVX512FP16__
|
||||||
|
template <> struct VecType<c10::Half> { using vec_type = FP16Vec16; };
|
||||||
|
#endif
|
||||||
|
|
||||||
|
template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };
|
||||||
|
|
||||||
|
// Stores the fp32 value `v` into `*ptr`, converting through T's assignment
// from float.
template <typename T> void storeFP32(float v, T *ptr) {
  T &dst = *ptr;
  dst = v;
}

#ifdef __AVX512FP16__
// Half specialization: writes through a _Float16 view of the destination so
// the float -> half conversion is performed by the compiler.
template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
  _Float16 *dst = reinterpret_cast<_Float16 *>(ptr);
  *dst = v;
}
#endif
|
||||||
|
|
||||||
|
// Multiply-accumulate on fp32 vectors: acc <- acc + a * b, built from the
// lane-wise operator* and operator+ of FP32Vec16.
inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
  const FP32Vec16 product(a * b);
  acc = acc + product;
}
|
||||||
|
|
||||||
|
#ifdef __AVX512BF16__
|
||||||
|
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
|
||||||
|
*reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
|
||||||
|
: reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {}
|
||||||
|
|
||||||
|
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
|
||||||
|
: reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {}
|
||||||
|
|
||||||
|
inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) {
|
||||||
|
acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg);
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
|
||||||
|
c10::BFloat16 __attribute__((__may_alias__)) *v_ptr =
|
||||||
|
reinterpret_cast<c10::BFloat16 *>(&v);
|
||||||
|
*ptr = *(v_ptr + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
|
||||||
|
: reg(_mm256_cvtepi32_epi16(
|
||||||
|
_mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {}
|
||||||
|
|
||||||
|
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
|
||||||
|
: reg(_mm512_cvtepi32_epi16(
|
||||||
|
_mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Issues a software prefetch for `addr` using the _MM_HINT_T1 locality hint.
inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); }
|
||||||
|
|
||||||
|
}; // namespace vec_op
|
||||||
|
|
||||||
|
#endif
|
||||||
117
csrc/cpu/layernorm.cpp
Normal file
117
csrc/cpu/layernorm.cpp
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
#include "cpu_types.hpp"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template <typename scalar_t>
|
||||||
|
void rms_norm_impl(scalar_t *__restrict__ out,
|
||||||
|
const scalar_t *__restrict__ input,
|
||||||
|
const scalar_t *__restrict__ weight, const float epsilon,
|
||||||
|
const int num_tokens, const int hidden_size) {
|
||||||
|
using scalar_vec_t = vec_op::vec_t<scalar_t>;
|
||||||
|
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
|
||||||
|
TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);
|
||||||
|
|
||||||
|
#pragma omp parallel for
|
||||||
|
for (int i = 0; i < num_tokens; ++i) {
|
||||||
|
vec_op::FP32Vec8 variance(0.0);
|
||||||
|
auto input_p = input + i * hidden_size;
|
||||||
|
auto output_p = out + i * hidden_size;
|
||||||
|
for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
|
||||||
|
scalar_vec_t x(input_p + j);
|
||||||
|
vec_op::FP32Vec8 fp32_x(x);
|
||||||
|
variance = variance + fp32_x * fp32_x;
|
||||||
|
}
|
||||||
|
|
||||||
|
float s_variance =
|
||||||
|
1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon);
|
||||||
|
vec_op::FP32Vec8 fp32_s_variance(s_variance);
|
||||||
|
|
||||||
|
for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
|
||||||
|
scalar_vec_t x(input_p + j);
|
||||||
|
scalar_vec_t w(weight + j);
|
||||||
|
|
||||||
|
vec_op::FP32Vec8 fp32_x(x);
|
||||||
|
vec_op::FP32Vec8 fp32_w(w);
|
||||||
|
|
||||||
|
vec_op::FP32Vec8 fp32_out = fp32_x * fp32_s_variance * fp32_w;
|
||||||
|
|
||||||
|
scalar_vec_t out(fp32_out);
|
||||||
|
out.save(output_p + j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t>
|
||||||
|
void fused_add_rms_norm_impl(scalar_t *__restrict__ input,
|
||||||
|
scalar_t *__restrict__ residual,
|
||||||
|
const scalar_t *__restrict__ weight,
|
||||||
|
const float epsilon, const int num_tokens,
|
||||||
|
const int hidden_size) {
|
||||||
|
using scalar_vec_t = vec_op::vec_t<scalar_t>;
|
||||||
|
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
|
||||||
|
TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);
|
||||||
|
|
||||||
|
#pragma omp parallel for
|
||||||
|
for (int i = 0; i < num_tokens; ++i) {
|
||||||
|
vec_op::FP32Vec8 variance(0.0);
|
||||||
|
auto input_p = input + i * hidden_size;
|
||||||
|
auto residual_p = residual + i * hidden_size;
|
||||||
|
for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
|
||||||
|
scalar_vec_t x(input_p + j);
|
||||||
|
scalar_vec_t res(residual_p + j);
|
||||||
|
vec_op::FP32Vec8 fp32_x(x);
|
||||||
|
vec_op::FP32Vec8 fp32_res(res);
|
||||||
|
|
||||||
|
fp32_x = fp32_x + fp32_res;
|
||||||
|
variance = variance + fp32_x * fp32_x;
|
||||||
|
scalar_vec_t out(fp32_x);
|
||||||
|
out.save(residual_p + j);
|
||||||
|
}
|
||||||
|
|
||||||
|
float s_variance =
|
||||||
|
1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon);
|
||||||
|
vec_op::FP32Vec8 fp32_s_variance(s_variance);
|
||||||
|
|
||||||
|
for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
|
||||||
|
scalar_vec_t w(weight + j);
|
||||||
|
scalar_vec_t res(residual_p + j);
|
||||||
|
|
||||||
|
vec_op::FP32Vec8 fp32_w(w);
|
||||||
|
vec_op::FP32Vec8 fp32_res(res);
|
||||||
|
|
||||||
|
vec_op::FP32Vec8 fp32_out = fp32_res * fp32_s_variance * fp32_w;
|
||||||
|
|
||||||
|
scalar_vec_t out(fp32_out);
|
||||||
|
out.save(input_p + j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Tensor-level wrapper around rms_norm_impl. The last dimension of `input`
// is the normalized dimension; all leading dimensions are flattened into
// rows. Dispatches on the scalar type of `input`.
void rms_norm(torch::Tensor &out, torch::Tensor &input,
              torch::Tensor &weight, float epsilon) {
  const int feature_dim = input.size(-1);
  const int row_count = input.numel() / feature_dim;

  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] {
    CPU_KERNEL_GUARD_IN(rms_norm_impl)
    rms_norm_impl(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
                  weight.data_ptr<scalar_t>(), epsilon, row_count,
                  feature_dim);
    CPU_KERNEL_GUARD_OUT(rms_norm_impl)
  });
}
|
||||||
|
|
||||||
|
// Tensor-level wrapper around fused_add_rms_norm_impl. The last dimension
// of `input` is the normalized dimension; all leading dimensions are
// flattened into rows. Both `input` and `residual` are passed as mutable
// buffers. Dispatches on the scalar type of `input`.
void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual,
                        torch::Tensor &weight, float epsilon) {
  const int feature_dim = input.size(-1);
  const int row_count = input.numel() / feature_dim;

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "fused_add_rms_norm_impl", [&] {
        CPU_KERNEL_GUARD_IN(fused_add_rms_norm_impl)
        fused_add_rms_norm_impl(
            input.data_ptr<scalar_t>(), residual.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(), epsilon, row_count, feature_dim);
        CPU_KERNEL_GUARD_OUT(fused_add_rms_norm_impl)
      });
}
|
||||||
199
csrc/cpu/pos_encoding.cpp
Normal file
199
csrc/cpu/pos_encoding.cpp
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
|
||||||
|
#include "cpu_types.hpp"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template <typename scalar_t>
|
||||||
|
void rotary_embedding_impl(
|
||||||
|
const int64_t
|
||||||
|
*__restrict__ positions, // [batch_size, seq_len] or [num_tokens]
|
||||||
|
scalar_t
|
||||||
|
*__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or
|
||||||
|
/// [num_tokens, num_heads, head_size]
|
||||||
|
scalar_t
|
||||||
|
*__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or
|
||||||
|
// [num_tokens, num_kv_heads, head_size]
|
||||||
|
const scalar_t
|
||||||
|
*__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
||||||
|
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
|
||||||
|
const int num_heads, const int num_kv_heads, const int head_size,
|
||||||
|
const int num_tokens) {
|
||||||
|
using scalar_vec_t = vec_op::vec_t<scalar_t>;
|
||||||
|
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
|
||||||
|
constexpr int ELEM_SIZE = sizeof(scalar_t);
|
||||||
|
|
||||||
|
const int embed_dim = rot_dim / 2;
|
||||||
|
TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0);
|
||||||
|
|
||||||
|
#pragma omp parallel for
|
||||||
|
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||||
|
int64_t pos = positions[token_idx];
|
||||||
|
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||||
|
|
||||||
|
for (int i = 0; i < num_heads; ++i) {
|
||||||
|
const int head_idx = i;
|
||||||
|
const int64_t token_head =
|
||||||
|
token_idx * query_stride + head_idx * head_size;
|
||||||
|
for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) {
|
||||||
|
const int rot_offset = j;
|
||||||
|
const int x_index = rot_offset;
|
||||||
|
const int y_index = embed_dim + rot_offset;
|
||||||
|
|
||||||
|
const int64_t out_x = token_head + x_index;
|
||||||
|
const int64_t out_y = token_head + y_index;
|
||||||
|
|
||||||
|
const scalar_vec_t cos(cache_ptr + x_index);
|
||||||
|
const scalar_vec_t sin(cache_ptr + y_index);
|
||||||
|
|
||||||
|
const scalar_vec_t q_x(query + out_x);
|
||||||
|
const scalar_vec_t q_y(query + out_y);
|
||||||
|
|
||||||
|
vec_op::FP32Vec8 fp32_cos(cos);
|
||||||
|
vec_op::FP32Vec8 fp32_sin(sin);
|
||||||
|
|
||||||
|
vec_op::FP32Vec8 fp32_q_x(q_x);
|
||||||
|
vec_op::FP32Vec8 fp32_q_y(q_y);
|
||||||
|
|
||||||
|
auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
|
||||||
|
scalar_vec_t(out1).save(query + out_x);
|
||||||
|
|
||||||
|
auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
|
||||||
|
scalar_vec_t(out2).save(query + out_y);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < num_kv_heads; ++i) {
|
||||||
|
const int head_idx = i;
|
||||||
|
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
|
||||||
|
for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) {
|
||||||
|
const int rot_offset = j;
|
||||||
|
const int x_index = rot_offset;
|
||||||
|
const int y_index = embed_dim + rot_offset;
|
||||||
|
|
||||||
|
const int64_t out_x = token_head + x_index;
|
||||||
|
const int64_t out_y = token_head + y_index;
|
||||||
|
|
||||||
|
const scalar_vec_t cos(cache_ptr + x_index);
|
||||||
|
const scalar_vec_t sin(cache_ptr + y_index);
|
||||||
|
|
||||||
|
const scalar_vec_t k_x(key + out_x);
|
||||||
|
const scalar_vec_t k_y(key + out_y);
|
||||||
|
|
||||||
|
vec_op::FP32Vec8 fp32_cos(cos);
|
||||||
|
vec_op::FP32Vec8 fp32_sin(sin);
|
||||||
|
|
||||||
|
vec_op::FP32Vec8 fp32_k_x(k_x);
|
||||||
|
vec_op::FP32Vec8 fp32_k_y(k_y);
|
||||||
|
|
||||||
|
auto out1 = fp32_k_x * fp32_cos - fp32_k_y * fp32_sin;
|
||||||
|
scalar_vec_t(out1).save(key + out_x);
|
||||||
|
auto out2 = fp32_k_y * fp32_cos + fp32_k_x * fp32_sin;
|
||||||
|
scalar_vec_t(out2).save(key + out_y);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Applies GPT-J style rotary embedding in place to `query` and `key`.
//
// Unlike the NeoX variant, the rotated pair is interleaved: within each head,
// element 2*j (x) is paired with element 2*j + 1 (y), for j in
// [0, rot_dim / 2). The cache row selected by `positions[token_idx]` is read
// as rot_dim / 2 cosine values followed by rot_dim / 2 sine values.
//   x' = x * cos - y * sin
//   y' = y * cos + x * sin
// Values are read into float, combined, and stored back as scalar_t.
template <typename scalar_t>
void rotary_embedding_gptj_impl(
    const int64_t
        *__restrict__ positions,  // [batch_size, seq_len] or [num_tokens]
    scalar_t
        *__restrict__ query,  // [batch_size, seq_len, num_heads, head_size] or
                              // [num_tokens, num_heads, head_size]
    scalar_t
        *__restrict__ key,  // [batch_size, seq_len, num_kv_heads, head_size] or
                            // [num_tokens, num_kv_heads, head_size]
    const scalar_t
        *__restrict__ cos_sin_cache,  // [max_position, 2, rot_dim // 2]
    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
    const int num_heads, const int num_kv_heads, const int head_size,
    const int num_tokens) {
  const int embed_dim = rot_dim / 2;

  // Rotates every (token, head) pair of one tensor. The query and key passes
  // are identical apart from base pointer, per-token stride and head count.
  auto rotate_tensor = [&](scalar_t *data, const int64_t token_stride,
                           const int head_count) {
    // The cache lookup stays inside the inner loop so the two loops remain
    // perfectly nested, as collapse(2) requires.
#pragma omp parallel for collapse(2)
    for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
      for (int head_idx = 0; head_idx < head_count; ++head_idx) {
        const int64_t pos = positions[token_idx];
        const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim;
        const scalar_t *cos_cache_ptr = cache_ptr;
        const scalar_t *sin_cache_ptr = cache_ptr + embed_dim;
        const int64_t token_head =
            token_idx * token_stride + head_idx * head_size;
        scalar_t *head_data = data + token_head;
        for (int j = 0; j < embed_dim; j += 1) {
          const int x_index = 2 * j;
          const int y_index = 2 * j + 1;

          const float cos_val = cos_cache_ptr[j];
          const float sin_val = sin_cache_ptr[j];

          const float x = head_data[x_index];
          const float y = head_data[y_index];

          head_data[x_index] = x * cos_val - y * sin_val;
          head_data[y_index] = y * cos_val + x * sin_val;
        }
      }
    }
  };

  rotate_tensor(query, query_stride, num_heads);
  rotate_tensor(key, key_stride, num_kv_heads);
}
|
||||||
|
}; // namespace
|
||||||
|
|
||||||
|
// CPU entry point for rotary embedding. Derives the launch geometry from the
// tensor shapes and strides, then dispatches on the query dtype to either the
// NeoX (`is_neox == true`) or the GPT-J implementation. `query` and `key` are
// modified in place.
void rotary_embedding(torch::Tensor &positions, torch::Tensor &query,
                      torch::Tensor &key, int head_size,
                      torch::Tensor &cos_sin_cache, bool is_neox) {
  // Per-token strides (in elements) of the second-to-last dimension.
  int64_t query_stride = query.stride(-2);
  int64_t key_stride = key.stride(-2);

  // The last dimension of query/key packs all heads of one token.
  int num_heads = query.size(-1) / head_size;
  int num_kv_heads = key.size(-1) / head_size;
  int num_tokens = query.numel() / query.size(-1);

  // Width of one cache row (cos half followed by sin half).
  int rot_dim = cos_sin_cache.size(1);

  VLLM_DISPATCH_FLOATING_TYPES(
      query.scalar_type(), "rotary_embedding_impl", [&] {
        CPU_KERNEL_GUARD_IN(rotary_embedding_impl)
        if (!is_neox) {
          rotary_embedding_gptj_impl(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
              rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
              head_size, num_tokens);
        } else {
          rotary_embedding_impl(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
              rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
              head_size, num_tokens);
        }

        CPU_KERNEL_GUARD_OUT(rotary_embedding_impl)
      });
}
|
||||||
73
csrc/cpu/pybind.cpp
Normal file
73
csrc/cpu/pybind.cpp
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
#include "cache.h"
|
||||||
|
#include "cuda_utils.h"
|
||||||
|
#include "ops.h"
|
||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
// Python bindings for the CPU build. Registers the custom operators under the
// `ops` submodule and the KV-cache helpers under the `cache_ops` submodule.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // vLLM custom ops
  pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");

  // Attention ops
  ops.def(
    "paged_attention_v1",
    &paged_attention_v1,
    "Compute the attention between an input query and the cached keys/values using PagedAttention.");
  ops.def(
    "paged_attention_v2",
    &paged_attention_v2,
    "PagedAttention V2.");

  // Activation ops
  ops.def(
    "silu_and_mul",
    &silu_and_mul,
    "Activation function used in SwiGLU.");
  ops.def(
    "gelu_and_mul",
    &gelu_and_mul,
    "Activation function used in GeGLU with `none` approximation.");
  ops.def(
    "gelu_tanh_and_mul",
    &gelu_tanh_and_mul,
    "Activation function used in GeGLU with `tanh` approximation.");
  ops.def(
    "gelu_new",
    &gelu_new,
    "GELU implementation used in GPT-2.");
  ops.def(
    "gelu_fast",
    &gelu_fast,
    "Approximate GELU implementation.");

  // Layernorm
  ops.def(
    "rms_norm",
    &rms_norm,
    "Apply Root Mean Square (RMS) Normalization to the input tensor.");

  ops.def(
    "fused_add_rms_norm",
    &fused_add_rms_norm,
    "In-place fused Add and RMS Normalization");

  // Rotary embedding
  ops.def(
    "rotary_embedding",
    &rotary_embedding,
    "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");

  // Cache ops
  pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
  cache_ops.def(
    "swap_blocks",
    &swap_blocks,
    "Swap in (out) the cache blocks from src to dst");
  // The address-of operator was garbled into a copyright sign ("&copy" read
  // as an HTML entity); restored so the binding takes &copy_blocks like its
  // siblings.
  cache_ops.def(
    "copy_blocks",
    &copy_blocks,
    "Copy the cache blocks from src to dst");
  cache_ops.def(
    "reshape_and_cache",
    &reshape_and_cache,
    "Reshape the key and value tensors and cache them");
}
|
||||||
@@ -4,6 +4,16 @@
|
|||||||
|
|
||||||
#include "dispatch_utils.h"
|
#include "dispatch_utils.h"
|
||||||
#include "reduction_utils.cuh"
|
#include "reduction_utils.cuh"
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
#include <cuda_bf16.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
#else
|
||||||
|
#include <hip/hip_bf16.h>
|
||||||
|
#include <hip/hip_fp16.h>
|
||||||
|
|
||||||
|
using __nv_bfloat16 = __hip_bfloat16;
|
||||||
|
using __nv_bfloat162 = __hip_bfloat162;
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
|
||||||
@@ -35,9 +45,201 @@ __global__ void rms_norm_kernel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Further optimize this kernel.
|
|
||||||
template<typename scalar_t>
|
/* Converter structs for the conversion from torch types to HIP/CUDA types,
|
||||||
__global__ void fused_add_rms_norm_kernel(
|
and the associated type conversions within HIP/CUDA. These helpers need
|
||||||
|
to be implemented for now because the relevant type conversion
|
||||||
|
operators/constructors are not consistently implemented by HIP/CUDA, so
|
||||||
|
a generic conversion via type casts cannot be implemented.
|
||||||
|
|
||||||
|
Each struct should have the member static constexpr bool `exists`:
|
||||||
|
If false, the optimized kernel is not used for the corresponding torch type.
|
||||||
|
If true, the struct should be fully defined as shown in the examples below.
|
||||||
|
*/
|
||||||
|
/* Primary template: no converter is available for an arbitrary torch type,
   which is signalled through `exists == false`. */
template<typename torch_type>
struct _typeConvert {
  static constexpr bool exists = false;
};

#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion

/* c10::Half <-> float conversions, scalar and packed (two-element). */
template<>
struct _typeConvert<c10::Half> {
  static constexpr bool exists = true;
  using hip_type = __half;
  using packed_hip_type = __half2;

  // half -> float
  __device__ static inline float convert(hip_type x) {
    return __half2float(x);
  }
  // half2 -> float2
  __device__ static inline float2 convert(packed_hip_type x) {
    return __half22float2(x);
  }
  // float -> half
  __device__ static inline hip_type convert(float x) {
    return __float2half_rn(x);
  }
  // float2 -> half2
  __device__ static inline packed_hip_type convert(float2 x) {
    return __float22half2_rn(x);
  }
};

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely

/* c10::BFloat16 <-> float conversions, scalar and packed (two-element). */
template<>
struct _typeConvert<c10::BFloat16> {
  static constexpr bool exists = true;
  using hip_type = __nv_bfloat16;
  using packed_hip_type = __nv_bfloat162;

  // bfloat16 -> float
  __device__ static inline float convert(hip_type x) {
    return __bfloat162float(x);
  }
  // bfloat162 -> float2
  __device__ static inline float2 convert(packed_hip_type x) {
    return __bfloat1622float2(x);
  }
  // float -> bfloat16
  __device__ static inline hip_type convert(float x) {
    return __float2bfloat16(x);
  }
  // float2 -> bfloat162
  __device__ static inline packed_hip_type convert(float2 x) {
    return __float22bfloat162_rn(x);
  }
};
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
|
||||||
|
|
||||||
|
/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
|
||||||
|
for appropriate specializations of fused_add_rms_norm_kernel.
|
||||||
|
Only functions that are necessary in that kernel are implemented.
|
||||||
|
Alignment to 16 bytes is required to use 128-bit global memory ops.
|
||||||
|
*/
|
||||||
|
template<typename scalar_t, int width>
|
||||||
|
struct alignas(16) _f16Vec {
|
||||||
|
/* Not theoretically necessary that width is a power of 2 but should
|
||||||
|
almost always be the case for optimization purposes */
|
||||||
|
static_assert(width > 0 && (width & (width - 1)) == 0,
|
||||||
|
"Width is not a positive power of 2!");
|
||||||
|
using Converter = _typeConvert<scalar_t>;
|
||||||
|
using T1 = typename Converter::hip_type;
|
||||||
|
using T2 = typename Converter::packed_hip_type;
|
||||||
|
T1 data[width];
|
||||||
|
|
||||||
|
__device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
|
||||||
|
if constexpr (width % 2 == 0) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < width; i += 2) {
|
||||||
|
T2 temp{data[i], data[i+1]};
|
||||||
|
temp += T2{other.data[i], other.data[i+1]};
|
||||||
|
data[i] = temp.x;
|
||||||
|
data[i+1] = temp.y;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < width; ++i)
|
||||||
|
data[i] += other.data[i];
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
|
||||||
|
if constexpr (width % 2 == 0) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < width; i += 2) {
|
||||||
|
T2 temp{data[i], data[i+1]};
|
||||||
|
temp *= T2{other.data[i], other.data[i+1]};
|
||||||
|
data[i] = temp.x;
|
||||||
|
data[i+1] = temp.y;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < width; ++i)
|
||||||
|
data[i] *= other.data[i];
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ _f16Vec& operator*=(const float scale) {
|
||||||
|
if constexpr (width % 2 == 0) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < width; i += 2) {
|
||||||
|
float2 temp_f = Converter::convert(T2{data[i], data[i+1]});
|
||||||
|
temp_f.x *= scale;
|
||||||
|
temp_f.y *= scale;
|
||||||
|
T2 temp = Converter::convert(temp_f);
|
||||||
|
data[i] = temp.x;
|
||||||
|
data[i+1] = temp.y;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < width; ++i) {
|
||||||
|
float temp = Converter::convert(data[i]) * scale;
|
||||||
|
data[i] = Converter::convert(temp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ float sum_squares() const {
|
||||||
|
float result = 0.0f;
|
||||||
|
if constexpr (width % 2 == 0) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < width; i += 2) {
|
||||||
|
float2 z = Converter::convert(T2{data[i], data[i+1]});
|
||||||
|
result += z.x * z.x + z.y * z.y;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < width; ++i) {
|
||||||
|
float x = Converter::convert(data[i]);
|
||||||
|
result += x * x;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/* Function specialization in the case of FP16/BF16 tensors.
|
||||||
|
Additional optimizations we can make in this case are
|
||||||
|
packed and vectorized operations, which help with the
|
||||||
|
memory latency bottleneck. */
|
||||||
|
template<typename scalar_t, int width>
|
||||||
|
__global__ std::enable_if_t<
|
||||||
|
(width > 0) && _typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel(
|
||||||
|
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||||
|
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||||
|
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||||
|
const float epsilon,
|
||||||
|
const int num_tokens,
|
||||||
|
const int hidden_size) {
|
||||||
|
// Sanity checks on our vector struct and type-punned pointer arithmetic
|
||||||
|
static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
|
||||||
|
static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
|
||||||
|
|
||||||
|
const int vec_hidden_size = hidden_size / width;
|
||||||
|
__shared__ float s_variance;
|
||||||
|
float variance = 0.0f;
|
||||||
|
/* These and the argument pointers are all declared `restrict` as they are
|
||||||
|
not aliased in practice. Argument pointers should not be dereferenced
|
||||||
|
in this kernel as that would be undefined behavior */
|
||||||
|
auto* __restrict__ input_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
|
||||||
|
auto* __restrict__ residual_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
|
||||||
|
auto* __restrict__ weight_v = reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
|
||||||
|
|
||||||
|
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
|
||||||
|
int id = blockIdx.x * vec_hidden_size + idx;
|
||||||
|
_f16Vec<scalar_t, width> temp = input_v[id];
|
||||||
|
temp += residual_v[id];
|
||||||
|
variance += temp.sum_squares();
|
||||||
|
residual_v[id] = temp;
|
||||||
|
}
|
||||||
|
/* Keep the following if-else block in sync with the
|
||||||
|
calculation of max_block_size in fused_add_rms_norm */
|
||||||
|
if (num_tokens < 256) {
|
||||||
|
variance = blockReduceSum<float, 1024>(variance);
|
||||||
|
} else variance = blockReduceSum<float, 256>(variance);
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
|
||||||
|
int id = blockIdx.x * vec_hidden_size + idx;
|
||||||
|
_f16Vec<scalar_t, width> temp = residual_v[id];
|
||||||
|
temp *= s_variance;
|
||||||
|
temp *= weight_v[idx];
|
||||||
|
input_v[id] = temp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/* Generic fused_add_rms_norm_kernel
|
||||||
|
The width field is not used here but necessary for other specializations.
|
||||||
|
*/
|
||||||
|
template<typename scalar_t, int width>
|
||||||
|
__global__ std::enable_if_t<
|
||||||
|
(width == 0) || !_typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel(
|
||||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||||
scalar_t* __restrict__ residual, // [..., hidden_size]
|
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||||
@@ -48,12 +250,17 @@ __global__ void fused_add_rms_norm_kernel(
|
|||||||
float variance = 0.0f;
|
float variance = 0.0f;
|
||||||
|
|
||||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||||
float x = (float) input[blockIdx.x * hidden_size + idx];
|
scalar_t z = input[blockIdx.x * hidden_size + idx];
|
||||||
x += (float) residual[blockIdx.x * hidden_size + idx];
|
z += residual[blockIdx.x * hidden_size + idx];
|
||||||
|
float x = (float) z;
|
||||||
variance += x * x;
|
variance += x * x;
|
||||||
residual[blockIdx.x * hidden_size + idx] = (scalar_t) x;
|
residual[blockIdx.x * hidden_size + idx] = z;
|
||||||
}
|
}
|
||||||
variance = blockReduceSum<float>(variance);
|
/* Keep the following if-else block in sync with the
|
||||||
|
calculation of max_block_size in fused_add_rms_norm */
|
||||||
|
if (num_tokens < 256) {
|
||||||
|
variance = blockReduceSum<float, 1024>(variance);
|
||||||
|
} else variance = blockReduceSum<float, 256>(variance);
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||||
}
|
}
|
||||||
@@ -93,6 +300,21 @@ void rms_norm(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Dispatches on input.scalar_type() and launches
   vllm::fused_add_rms_norm_kernel<scalar_t, width>. The expansion site must
   have `input`, `residual`, `weight`, `epsilon`, `num_tokens`, `hidden_size`,
   `grid`, `block` and `stream` in scope. */
#define LAUNCH_FUSED_ADD_RMS_NORM(width)                                   \
  VLLM_DISPATCH_FLOATING_TYPES(                                            \
    input.scalar_type(), "fused_add_rms_norm_kernel", [&] {                \
      vllm::fused_add_rms_norm_kernel<scalar_t, width>                     \
        <<<grid, block, 0, stream>>>(                                      \
          input.data_ptr<scalar_t>(), residual.data_ptr<scalar_t>(),       \
          weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);  \
    });
|
||||||
|
|
||||||
void fused_add_rms_norm(
|
void fused_add_rms_norm(
|
||||||
torch::Tensor& input, // [..., hidden_size]
|
torch::Tensor& input, // [..., hidden_size]
|
||||||
torch::Tensor& residual, // [..., hidden_size]
|
torch::Tensor& residual, // [..., hidden_size]
|
||||||
@@ -102,19 +324,29 @@ void fused_add_rms_norm(
|
|||||||
int num_tokens = input.numel() / hidden_size;
|
int num_tokens = input.numel() / hidden_size;
|
||||||
|
|
||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
dim3 block(std::min(hidden_size, 1024));
|
/* This kernel is memory-latency bound in many scenarios.
|
||||||
|
When num_tokens is large, a smaller block size allows
|
||||||
|
for increased block occupancy on CUs and better latency
|
||||||
|
hiding on global mem ops. */
|
||||||
|
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
|
||||||
|
dim3 block(std::min(hidden_size, max_block_size));
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(
|
/*If the tensor types are FP16/BF16, try to use the optimized kernel
|
||||||
input.scalar_type(),
|
with packed + vectorized ops.
|
||||||
"fused_add_rms_norm_kernel",
|
Max optimization is achieved with a width-8 vector of FP16/BF16s
|
||||||
[&] {
|
since we can load at most 128 bits at once in a global memory op.
|
||||||
vllm::fused_add_rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
However, this requires each tensor's data to be aligned to 16
|
||||||
input.data_ptr<scalar_t>(),
|
bytes.
|
||||||
residual.data_ptr<scalar_t>(),
|
*/
|
||||||
weight.data_ptr<scalar_t>(),
|
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
|
||||||
epsilon,
|
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
|
||||||
num_tokens,
|
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
|
||||||
hidden_size);
|
bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 \
|
||||||
});
|
&& wt_ptr % 16 == 0;
|
||||||
|
if (ptrs_are_aligned && hidden_size % 8 == 0) {
|
||||||
|
LAUNCH_FUSED_ADD_RMS_NORM(8);
|
||||||
|
} else {
|
||||||
|
LAUNCH_FUSED_ADD_RMS_NORM(0);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
26
csrc/ops.h
26
csrc/ops.h
@@ -14,7 +14,8 @@ void paged_attention_v1(
|
|||||||
int block_size,
|
int block_size,
|
||||||
int max_context_len,
|
int max_context_len,
|
||||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
const std::string& kv_cache_dtype);
|
const std::string& kv_cache_dtype,
|
||||||
|
float kv_scale);
|
||||||
|
|
||||||
void paged_attention_v2(
|
void paged_attention_v2(
|
||||||
torch::Tensor& out,
|
torch::Tensor& out,
|
||||||
@@ -31,7 +32,8 @@ void paged_attention_v2(
|
|||||||
int block_size,
|
int block_size,
|
||||||
int max_context_len,
|
int max_context_len,
|
||||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
const std::string& kv_cache_dtype);
|
const std::string& kv_cache_dtype,
|
||||||
|
float kv_scale);
|
||||||
|
|
||||||
void rms_norm(
|
void rms_norm(
|
||||||
torch::Tensor& out,
|
torch::Tensor& out,
|
||||||
@@ -84,6 +86,21 @@ void gelu_fast(
|
|||||||
torch::Tensor& input);
|
torch::Tensor& input);
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
|
torch::Tensor aqlm_gemm(
|
||||||
|
const torch::Tensor& input,
|
||||||
|
const torch::Tensor& codes,
|
||||||
|
const torch::Tensor& codebooks,
|
||||||
|
const torch::Tensor& scales,
|
||||||
|
const torch::Tensor& codebook_partition_sizes,
|
||||||
|
const std::optional<torch::Tensor>& bias
|
||||||
|
);
|
||||||
|
|
||||||
|
torch::Tensor aqlm_dequant(
|
||||||
|
const torch::Tensor& codes,
|
||||||
|
const torch::Tensor& codebooks,
|
||||||
|
const torch::Tensor& codebook_partition_sizes
|
||||||
|
);
|
||||||
|
|
||||||
torch::Tensor awq_gemm(
|
torch::Tensor awq_gemm(
|
||||||
torch::Tensor _in_feats,
|
torch::Tensor _in_feats,
|
||||||
torch::Tensor _kernel,
|
torch::Tensor _kernel,
|
||||||
@@ -129,6 +146,11 @@ void gptq_shuffle(
|
|||||||
torch::Tensor q_perm,
|
torch::Tensor q_perm,
|
||||||
int bit);
|
int bit);
|
||||||
|
|
||||||
|
void scaled_fp8_quant(
|
||||||
|
torch::Tensor& out,
|
||||||
|
torch::Tensor& input,
|
||||||
|
torch::Tensor& scale);
|
||||||
|
|
||||||
void moe_align_block_size(
|
void moe_align_block_size(
|
||||||
torch::Tensor topk_ids,
|
torch::Tensor topk_ids,
|
||||||
int num_experts,
|
int num_experts,
|
||||||
|
|||||||
@@ -1,4 +0,0 @@
|
|||||||
#include "bgmv_config.h"
|
|
||||||
#include "bgmv_impl.cuh"
|
|
||||||
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half)
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
#include "bgmv_config.h"
|
|
||||||
#include "bgmv_impl.cuh"
|
|
||||||
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16)
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
#include "bgmv_config.h"
|
|
||||||
#include "bgmv_impl.cuh"
|
|
||||||
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half)
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
#include "bgmv_config.h"
|
|
||||||
#include "bgmv_impl.cuh"
|
|
||||||
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half)
|
|
||||||
@@ -14,6 +14,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
|||||||
f(in_T, out_T, W_T, narrow, 128) \
|
f(in_T, out_T, W_T, narrow, 128) \
|
||||||
f(in_T, out_T, W_T, narrow, 256) \
|
f(in_T, out_T, W_T, narrow, 256) \
|
||||||
f(in_T, out_T, W_T, narrow, 512) \
|
f(in_T, out_T, W_T, narrow, 512) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 640) \
|
||||||
f(in_T, out_T, W_T, narrow, 768) \
|
f(in_T, out_T, W_T, narrow, 768) \
|
||||||
f(in_T, out_T, W_T, narrow, 1024) \
|
f(in_T, out_T, W_T, narrow, 1024) \
|
||||||
f(in_T, out_T, W_T, narrow, 1152) \
|
f(in_T, out_T, W_T, narrow, 1152) \
|
||||||
@@ -46,6 +47,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
|||||||
f(in_T, out_T, W_T, narrow, 13696) \
|
f(in_T, out_T, W_T, narrow, 13696) \
|
||||||
f(in_T, out_T, W_T, narrow, 13824) \
|
f(in_T, out_T, W_T, narrow, 13824) \
|
||||||
f(in_T, out_T, W_T, narrow, 14336) \
|
f(in_T, out_T, W_T, narrow, 14336) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 15360) \
|
||||||
f(in_T, out_T, W_T, narrow, 16384) \
|
f(in_T, out_T, W_T, narrow, 16384) \
|
||||||
f(in_T, out_T, W_T, narrow, 20480) \
|
f(in_T, out_T, W_T, narrow, 20480) \
|
||||||
f(in_T, out_T, W_T, narrow, 22016) \
|
f(in_T, out_T, W_T, narrow, 22016) \
|
||||||
@@ -58,8 +60,19 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
|||||||
f(in_T, out_T, W_T, narrow, 32768) \
|
f(in_T, out_T, W_T, narrow, 32768) \
|
||||||
f(in_T, out_T, W_T, narrow, 33024) \
|
f(in_T, out_T, W_T, narrow, 33024) \
|
||||||
f(in_T, out_T, W_T, narrow, 36864) \
|
f(in_T, out_T, W_T, narrow, 36864) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 43264) \
|
||||||
f(in_T, out_T, W_T, narrow, 49152) \
|
f(in_T, out_T, W_T, narrow, 49152) \
|
||||||
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
|
f(in_T, out_T, W_T, narrow, 64000) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 64256) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 64512) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 102400) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 102656) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 102912) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 128000) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 128256) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 128512) \
|
||||||
|
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
|
||||||
|
// and vllm/tests/lora/test_punica.py
|
||||||
|
|
||||||
// Keep this in sync with vllm/config::LoRAConfig
|
// Keep this in sync with vllm/config::LoRAConfig
|
||||||
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
|
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
|
||||||
|
|||||||
@@ -1,4 +0,0 @@
|
|||||||
#include "bgmv_config.h"
|
|
||||||
#include "bgmv_impl.cuh"
|
|
||||||
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16)
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
#include "bgmv_config.h"
|
|
||||||
#include "bgmv_impl.cuh"
|
|
||||||
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half)
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
#include "bgmv_config.h"
|
|
||||||
#include "bgmv_impl.cuh"
|
|
||||||
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16)
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
#include "bgmv_config.h"
|
|
||||||
#include "bgmv_impl.cuh"
|
|
||||||
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16)
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
#include "bgmv_config.h"
|
|
||||||
#include "bgmv_impl.cuh"
|
|
||||||
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half)
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
#include "bgmv_config.h"
|
|
||||||
#include "bgmv_impl.cuh"
|
|
||||||
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16)
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
#include "bgmv_config.h"
|
|
||||||
#include "bgmv_impl.cuh"
|
|
||||||
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16)
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
#include "bgmv_config.h"
|
|
||||||
#include "bgmv_impl.cuh"
|
|
||||||
|
|
||||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half)
|
|
||||||
@@ -18,6 +18,26 @@ for input_dtype in DTYPES:
|
|||||||
if weight_dtype == "fp32":
|
if weight_dtype == "fp32":
|
||||||
# FP32 weights are not supported.
|
# FP32 weights are not supported.
|
||||||
continue
|
continue
|
||||||
|
if output_dtype == "fp32":
|
||||||
|
# LoRA A matrix.
|
||||||
|
if input_dtype != weight_dtype:
|
||||||
|
# NOTE(woosuk): While Punica supports the case where the
|
||||||
|
# input and weight dtypes are different, we only generate
|
||||||
|
# the kernels the same dtypes to reduce the binary size.
|
||||||
|
continue
|
||||||
|
elif input_dtype == "fp32":
|
||||||
|
# LoRA B matrix.
|
||||||
|
if output_dtype != weight_dtype:
|
||||||
|
# NOTE(woosuk): While Punica supports the case where the
|
||||||
|
# output and weight dtypes are different, we only generate
|
||||||
|
# the kernels the same dtypes to reduce the binary size.
|
||||||
|
continue
|
||||||
|
elif not (input_dtype == output_dtype == weight_dtype):
|
||||||
|
# NOTE(woosuk): While Punica supports mixed data types for
|
||||||
|
# input, output, and weight, we only generate the kernels with
|
||||||
|
# the same data types to reduce the binary size.
|
||||||
|
continue
|
||||||
|
|
||||||
kernel_definition = TEMPLATE.format(
|
kernel_definition = TEMPLATE.format(
|
||||||
input_dtype=DTYPE_MAP[input_dtype],
|
input_dtype=DTYPE_MAP[input_dtype],
|
||||||
output_dtype=DTYPE_MAP[output_dtype],
|
output_dtype=DTYPE_MAP[output_dtype],
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
|
inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) {
|
||||||
return (uint32_t(a) << 16) | uint32_t(b);
|
return (uint64_t(a) << 32) | uint64_t(b);
|
||||||
}
|
}
|
||||||
|
|
||||||
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
|
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
|
||||||
@@ -46,13 +46,30 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
|
|||||||
template <typename in_T, typename out_T, typename W_T>
|
template <typename in_T, typename out_T, typename W_T>
|
||||||
inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
|
inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
|
||||||
const int64_t *lora_indices,
|
const int64_t *lora_indices,
|
||||||
uint16_t in_features, uint16_t out_features,
|
uint32_t in_features, uint32_t out_features,
|
||||||
int64_t y_offset, int64_t full_y_size,
|
int64_t y_offset, int64_t full_y_size,
|
||||||
int64_t batch_size, int64_t num_layers,
|
int64_t batch_size, int64_t num_layers,
|
||||||
int64_t layer_idx, float scale) {
|
int64_t layer_idx, float scale) {
|
||||||
switch (pack_u16(in_features, out_features)) {
|
// NOTE(woosuk): While Punica supports various combinations of input/output
|
||||||
|
// data types, we limit the supported data types to reduce the binary size.
|
||||||
|
constexpr bool is_input_float = std::is_same<in_T, float>::value;
|
||||||
|
constexpr bool is_output_float = std::is_same<out_T, float>::value;
|
||||||
|
if (is_input_float) {
|
||||||
|
if (!std::is_same<out_T, W_T>::value) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else if (is_output_float) {
|
||||||
|
if (!std::is_same<in_T, W_T>::value) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else if (!(std::is_same<in_T, W_T>::value &&
|
||||||
|
std::is_same<out_T, W_T>::value)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (pack_u32(in_features, out_features)) {
|
||||||
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
|
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
|
||||||
case pack_u16(feat_in, feat_out): \
|
case pack_u32(feat_in, feat_out): \
|
||||||
bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \
|
bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \
|
||||||
full_y_size, batch_size, num_layers, \
|
full_y_size, batch_size, num_layers, \
|
||||||
layer_idx, scale); \
|
layer_idx, scale); \
|
||||||
@@ -93,7 +110,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
|||||||
CHECK_EQ(y.size(0), x.size(0));
|
CHECK_EQ(y.size(0), x.size(0));
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||||
bool ok = false;
|
bool ok = false;
|
||||||
if (h_in < 65536 && h_out < 65536) {
|
if (h_in <= 128512 && h_out <= 128512) {
|
||||||
// TODO: See if we can get rid of this massive nested switch
|
// TODO: See if we can get rid of this massive nested switch
|
||||||
switch (x.scalar_type()) {
|
switch (x.scalar_type()) {
|
||||||
case at::ScalarType::Half:
|
case at::ScalarType::Half:
|
||||||
@@ -325,7 +342,7 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
|||||||
CHECK_EQ(y.size(0), x.size(0));
|
CHECK_EQ(y.size(0), x.size(0));
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||||
bool ok = false;
|
bool ok = false;
|
||||||
if (h_in < 65536 && h_out < 65536) {
|
if (h_in <= 128512 && h_out <= 128512) {
|
||||||
// TODO: See if we can get rid of this massive nested switch
|
// TODO: See if we can get rid of this massive nested switch
|
||||||
switch (x.scalar_type()) {
|
switch (x.scalar_type()) {
|
||||||
case at::ScalarType::Half:
|
case at::ScalarType::Half:
|
||||||
|
|||||||
@@ -63,6 +63,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||||||
|
|
||||||
// Quantization ops
|
// Quantization ops
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
|
ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM");
|
||||||
|
ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM");
|
||||||
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
|
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
|
||||||
ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
|
ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
|
||||||
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
|
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
|
||||||
@@ -71,6 +73,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||||||
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
|
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
|
||||||
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
|
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
|
||||||
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
|
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
|
||||||
|
ops.def("scaled_fp8_quant", &scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
|
||||||
ops.def(
|
ops.def(
|
||||||
"moe_align_block_size",
|
"moe_align_block_size",
|
||||||
&moe_align_block_size,
|
&moe_align_block_size,
|
||||||
@@ -91,9 +94,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||||||
&reshape_and_cache,
|
&reshape_and_cache,
|
||||||
"Reshape the key and value tensors and cache them");
|
"Reshape the key and value tensors and cache them");
|
||||||
cache_ops.def(
|
cache_ops.def(
|
||||||
"convert_fp8_e5m2",
|
"convert_fp8",
|
||||||
&convert_fp8_e5m2,
|
&convert_fp8,
|
||||||
"Convert the key and value cache to fp8_e5m2 data type");
|
"Convert the key and value cache to fp8 data type");
|
||||||
|
|
||||||
// Cuda utils
|
// Cuda utils
|
||||||
pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");
|
pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");
|
||||||
|
|||||||
712
csrc/quantization/aqlm/gemm_kernels.cu
Normal file
712
csrc/quantization/aqlm/gemm_kernels.cu
Normal file
@@ -0,0 +1,712 @@
|
|||||||
|
/*
|
||||||
|
* Modified by Neural Magic
|
||||||
|
* Adapted from https://github.com/Vahe1994/AQLM
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <torch/extension.h>
|
||||||
|
#include <c10/cuda/CUDAStream.h>
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <cstdlib>
|
||||||
|
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
namespace aqlm {
|
||||||
|
|
||||||
|
__global__ void Code1x16MatVec(
|
||||||
|
const int4* __restrict__ A,
|
||||||
|
const int4* __restrict__ B,
|
||||||
|
int4* __restrict__ C,
|
||||||
|
const int4* __restrict__ codebook,
|
||||||
|
const int prob_m,
|
||||||
|
const int prob_k,
|
||||||
|
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
|
||||||
|
const int codebook_stride // as int4.
|
||||||
|
) {
|
||||||
|
int a_gl_stride = prob_k / 8 / 8;
|
||||||
|
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||||
|
bool pred = a_gl_rd < prob_m;
|
||||||
|
|
||||||
|
if (pred)
|
||||||
|
{
|
||||||
|
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
|
||||||
|
auto codebook_size = &codebook_a_sizes.x;
|
||||||
|
while (a_gl_rd >= *codebook_size)
|
||||||
|
{
|
||||||
|
codebook += codebook_stride;
|
||||||
|
++codebook_size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int b_gl_rd = 0;
|
||||||
|
int c_gl_wr = a_gl_rd;
|
||||||
|
a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
|
||||||
|
int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
|
||||||
|
|
||||||
|
__shared__ int4 sh_b[32 * 9];
|
||||||
|
float res = 0;
|
||||||
|
|
||||||
|
int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32);
|
||||||
|
while (iters--) {
|
||||||
|
// We pad shared memory to avoid bank conflicts during reads
|
||||||
|
__syncthreads();
|
||||||
|
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
|
||||||
|
if (b_gl_rd + i < prob_k / 8)
|
||||||
|
sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
b_gl_rd += 32 * 8;
|
||||||
|
|
||||||
|
int b_sh_rd = 9 * (threadIdx.x % 32);
|
||||||
|
if (pred && a_gl_rd < a_gl_end) {
|
||||||
|
const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 8; i++) {
|
||||||
|
uint32_t dec[4];
|
||||||
|
// We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't
|
||||||
|
// actually help us; this brings > 2x speedup.
|
||||||
|
asm volatile (
|
||||||
|
"ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
|
||||||
|
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
|
||||||
|
: "l"((void*) &codebook[enc[i]])
|
||||||
|
);
|
||||||
|
half2* a = reinterpret_cast<half2*>(&dec);
|
||||||
|
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
|
||||||
|
half2 res2 = {};
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < 4; j++)
|
||||||
|
res2 = __hfma2(a[j], b[j], res2);
|
||||||
|
res += __half2float(res2.x) + __half2float(res2.y);
|
||||||
|
b_sh_rd++;
|
||||||
|
}
|
||||||
|
a_gl_rd += 32;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (pred) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 16; i > 0; i /= 2)
|
||||||
|
res += __shfl_down_sync(0xffffffff, res, i);
|
||||||
|
if (threadIdx.x % 32 == 0)
|
||||||
|
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void Code2x8MatVec(
|
||||||
|
const int4* __restrict__ A,
|
||||||
|
const int4* __restrict__ B,
|
||||||
|
int4* __restrict__ C,
|
||||||
|
const int4* __restrict__ codebook,
|
||||||
|
int prob_m,
|
||||||
|
int prob_k,
|
||||||
|
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
|
||||||
|
const int codebook_stride // as int4.
|
||||||
|
|
||||||
|
) {
|
||||||
|
int a_gl_stride = prob_k / 8 / 8;
|
||||||
|
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||||
|
bool pred = a_gl_rd < prob_m;
|
||||||
|
|
||||||
|
if (pred)
|
||||||
|
{
|
||||||
|
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
|
||||||
|
auto codebook_size = &codebook_a_sizes.x;
|
||||||
|
while (a_gl_rd >= *codebook_size)
|
||||||
|
{
|
||||||
|
codebook += codebook_stride;
|
||||||
|
++codebook_size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int b_gl_rd = 0;
|
||||||
|
int c_gl_wr = a_gl_rd;
|
||||||
|
a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
|
||||||
|
int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
|
||||||
|
int lane = threadIdx.x % 8;
|
||||||
|
|
||||||
|
extern __shared__ int4 sh[];
|
||||||
|
int4* sh_b = sh;
|
||||||
|
int4* sh_code = sh_b + 32 * 9;
|
||||||
|
int4* sh_code0 = sh_code;
|
||||||
|
int4* sh_code1 = sh_code + 256 * 8;
|
||||||
|
|
||||||
|
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
|
||||||
|
int4 dec = codebook[i];
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < 8; j++)
|
||||||
|
sh_code[8 * i + (j + lane) % 8] = dec;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
float res = 0;
|
||||||
|
|
||||||
|
int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32);
|
||||||
|
while (iters--) {
|
||||||
|
// We pad shared memory to avoid bank conflicts during reads
|
||||||
|
__syncthreads();
|
||||||
|
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
|
||||||
|
if (b_gl_rd + i < prob_k / 8)
|
||||||
|
sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
b_gl_rd += 32 * 8;
|
||||||
|
|
||||||
|
int b_sh_rd = 9 * (threadIdx.x % 32);
|
||||||
|
if (pred && a_gl_rd < a_gl_end) {
|
||||||
|
const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 8; i++) {
|
||||||
|
half2* a0 = reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
|
||||||
|
half2* a1 = reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
|
||||||
|
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
|
||||||
|
half2 res2 = {};
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < 4; j++)
|
||||||
|
res2 = __hfma2(__hadd2(a0[j], a1[j]), b[j], res2);
|
||||||
|
res += __half2float(res2.x) + __half2float(res2.y);
|
||||||
|
b_sh_rd++;
|
||||||
|
}
|
||||||
|
a_gl_rd += 32;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (pred) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 16; i > 0; i /= 2)
|
||||||
|
res += __shfl_down_sync(0xffffffff, res, i);
|
||||||
|
if (threadIdx.x % 32 == 0)
|
||||||
|
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
__global__ void Code1x16Dequant(
|
||||||
|
const int4* __restrict__ A,
|
||||||
|
int4* __restrict__ C,
|
||||||
|
const int4* __restrict__ codebook,
|
||||||
|
int prob_m,
|
||||||
|
int prob_k,
|
||||||
|
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, sums to m.
|
||||||
|
const int codebook_stride // as int4
|
||||||
|
) {
|
||||||
|
int a_gl_stride = prob_k / 8 / 8;
|
||||||
|
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||||
|
bool pred = a_gl_rd < prob_m;
|
||||||
|
|
||||||
|
if (pred)
|
||||||
|
{
|
||||||
|
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
|
||||||
|
auto codebook_size = &codebook_a_sizes.x;
|
||||||
|
while (a_gl_rd >= *codebook_size)
|
||||||
|
{
|
||||||
|
codebook += codebook_stride;
|
||||||
|
++codebook_size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
|
||||||
|
int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
|
||||||
|
|
||||||
|
int c_gl_stride = prob_k / 8;
|
||||||
|
int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||||
|
c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8;
|
||||||
|
|
||||||
|
int iters = (prob_k / 8 - 1) / (8 * 32) + 1;
|
||||||
|
while (iters--) {
|
||||||
|
if (pred && a_gl_rd < a_gl_end) {
|
||||||
|
const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 8; i++) {
|
||||||
|
int4 chunk;
|
||||||
|
auto dec = reinterpret_cast<uint32_t*>(&chunk);
|
||||||
|
// We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't
|
||||||
|
// actually help us; this brings > 2x speedup.
|
||||||
|
asm volatile (
|
||||||
|
"ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
|
||||||
|
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
|
||||||
|
: "l"((void*) &codebook[enc[i]])
|
||||||
|
);
|
||||||
|
|
||||||
|
C[a_gl_rd * 8 + i] = chunk;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
a_gl_rd += 32;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
__global__ void Code2x8Dequant(
|
||||||
|
const int4* __restrict__ A,
|
||||||
|
int4* __restrict__ C,
|
||||||
|
const int4* __restrict__ codebook,
|
||||||
|
int prob_m,
|
||||||
|
int prob_k,
|
||||||
|
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols.
|
||||||
|
const int codebook_stride // as int4
|
||||||
|
) {
|
||||||
|
int a_gl_stride = prob_k / 8 / 8;
|
||||||
|
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||||
|
bool pred = a_gl_rd < prob_m;
|
||||||
|
|
||||||
|
if (pred)
|
||||||
|
{
|
||||||
|
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
|
||||||
|
auto codebook_size = &codebook_a_sizes.x;
|
||||||
|
while (a_gl_rd >= *codebook_size)
|
||||||
|
{
|
||||||
|
codebook += codebook_stride;
|
||||||
|
++codebook_size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
|
||||||
|
int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
|
||||||
|
int lane = threadIdx.x % 8;
|
||||||
|
|
||||||
|
int c_gl_stride = prob_k / 8;
|
||||||
|
int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||||
|
c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8;
|
||||||
|
|
||||||
|
extern __shared__ int4 sh[];
|
||||||
|
int4* sh_code = sh;
|
||||||
|
int4* sh_code0 = sh_code;
|
||||||
|
int4* sh_code1 = sh_code + 256 * 8;
|
||||||
|
|
||||||
|
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
|
||||||
|
int4 dec = codebook[i];
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < 8; j++)
|
||||||
|
sh_code[8 * i + (j + lane) % 8] = dec;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
float res = 0;
|
||||||
|
|
||||||
|
int iters = (prob_k / 8 - 1) / (8 * 32) + 1;
|
||||||
|
while (iters--) {
|
||||||
|
if (pred && a_gl_rd < a_gl_end) {
|
||||||
|
const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 8; i++) {
|
||||||
|
int4 chunk;
|
||||||
|
half2* a0 = reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
|
||||||
|
half2* a1 = reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < 4; j++)
|
||||||
|
reinterpret_cast<half2*>(&chunk)[j] = __hadd2(a0[j], a1[j]);
|
||||||
|
C[a_gl_rd * 8 + i] = chunk;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
a_gl_rd += 32;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline int ceildiv(int a, int b) {
|
||||||
|
return (a + b - 1) / b;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int THREAD_M = 16;
|
||||||
|
|
||||||
|
void code1x16_matvec_cuda(
|
||||||
|
const void* __restrict__ A,
|
||||||
|
const void* __restrict__ B,
|
||||||
|
void* __restrict__ C,
|
||||||
|
const void* __restrict__ codebook,
|
||||||
|
int prob_m,
|
||||||
|
int prob_k,
|
||||||
|
const int4 codebook_a_sizes,
|
||||||
|
const int codebook_stride
|
||||||
|
) {
|
||||||
|
int sms;
|
||||||
|
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
|
||||||
|
int waves = 0;
|
||||||
|
int thread_m;
|
||||||
|
do {
|
||||||
|
waves++;
|
||||||
|
thread_m = ceildiv(prob_m, waves * sms);
|
||||||
|
} while (thread_m > THREAD_M);
|
||||||
|
|
||||||
|
int blocks = ceildiv(prob_m, thread_m);
|
||||||
|
int threads = 32 * thread_m;
|
||||||
|
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||||
|
Code1x16MatVec<<<blocks, threads, 16*32*9, stream>>>(
|
||||||
|
(const int4*) A,
|
||||||
|
(const int4*) B,
|
||||||
|
(int4*) C,
|
||||||
|
(const int4*) codebook,
|
||||||
|
prob_m,
|
||||||
|
prob_k,
|
||||||
|
codebook_a_sizes,
|
||||||
|
codebook_stride
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
void code2x8_matvec_cuda(
|
||||||
|
const void* __restrict__ A,
|
||||||
|
const void* __restrict__ B,
|
||||||
|
void* __restrict__ C,
|
||||||
|
const void* __restrict__ codebook,
|
||||||
|
int prob_m,
|
||||||
|
int prob_k,
|
||||||
|
const int4 codebook_a_sizes,
|
||||||
|
const int codebook_stride
|
||||||
|
) {
|
||||||
|
int sms;
|
||||||
|
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
|
||||||
|
int waves = 0;
|
||||||
|
int thread_m;
|
||||||
|
do {
|
||||||
|
waves++;
|
||||||
|
thread_m = ceildiv(prob_m, waves * sms);
|
||||||
|
} while (thread_m > THREAD_M);
|
||||||
|
|
||||||
|
int blocks = ceildiv(prob_m, thread_m);
|
||||||
|
int threads = 32 * thread_m;
|
||||||
|
int shared = 16 * (2 * 256 * 8 + 32 * 9);
|
||||||
|
cudaFuncSetAttribute(
|
||||||
|
Code2x8MatVec, cudaFuncAttributeMaxDynamicSharedMemorySize, shared
|
||||||
|
);
|
||||||
|
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||||
|
Code2x8MatVec<<<blocks, threads, shared, stream>>>(
|
||||||
|
(const int4*) A,
|
||||||
|
(const int4*) B,
|
||||||
|
(int4*) C,
|
||||||
|
(const int4*) codebook,
|
||||||
|
prob_m,
|
||||||
|
prob_k,
|
||||||
|
codebook_a_sizes,
|
||||||
|
codebook_stride
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
void code1x16_dequant_cuda(
|
||||||
|
const void* __restrict__ A,
|
||||||
|
void* __restrict__ C,
|
||||||
|
const void* __restrict__ codebook,
|
||||||
|
int prob_m,
|
||||||
|
int prob_k,
|
||||||
|
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
|
||||||
|
const int codebook_stride // as int4.
|
||||||
|
) {
|
||||||
|
int sms;
|
||||||
|
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
|
||||||
|
int waves = 0;
|
||||||
|
int thread_m;
|
||||||
|
do {
|
||||||
|
waves++;
|
||||||
|
thread_m = ceildiv(prob_m, waves * sms);
|
||||||
|
} while (thread_m > THREAD_M);
|
||||||
|
|
||||||
|
int blocks = ceildiv(prob_m, thread_m);
|
||||||
|
int threads = 32 * thread_m;
|
||||||
|
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||||
|
Code1x16Dequant<<<blocks, threads, 0, stream>>>(
|
||||||
|
(const int4*) A,
|
||||||
|
(int4*) C,
|
||||||
|
(const int4*) codebook,
|
||||||
|
prob_m,
|
||||||
|
prob_k,
|
||||||
|
codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
|
||||||
|
codebook_stride // as int4.
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dequantizes the code and codebook into weights.
|
||||||
|
void code2x8_dequant_cuda(
|
||||||
|
const void* __restrict__ A,
|
||||||
|
void* __restrict__ C,
|
||||||
|
const void* __restrict__ codebook,
|
||||||
|
int prob_m,
|
||||||
|
int prob_k,
|
||||||
|
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols.
|
||||||
|
const int codebook_stride // as int4
|
||||||
|
) {
|
||||||
|
int sms;
|
||||||
|
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
|
||||||
|
int waves = 0;
|
||||||
|
int thread_m;
|
||||||
|
do {
|
||||||
|
waves++;
|
||||||
|
thread_m = ceildiv(prob_m, waves * sms);
|
||||||
|
} while (thread_m > THREAD_M);
|
||||||
|
|
||||||
|
int blocks = ceildiv(prob_m, thread_m);
|
||||||
|
int threads = 32 * thread_m;
|
||||||
|
int shared = 16 * (2 * 256 * 8 + 32 * 9);
|
||||||
|
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||||
|
|
||||||
|
cudaFuncSetAttribute(
|
||||||
|
Code2x8Dequant, cudaFuncAttributeMaxDynamicSharedMemorySize, shared
|
||||||
|
);
|
||||||
|
Code2x8Dequant<<<blocks, threads, shared, stream>>>(
|
||||||
|
(const int4*) A,
|
||||||
|
(int4*) C,
|
||||||
|
(const int4*) codebook,
|
||||||
|
prob_m,
|
||||||
|
prob_k,
|
||||||
|
codebook_a_sizes,
|
||||||
|
codebook_stride
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
int codebook_stride(const torch::Tensor& codebooks)
|
||||||
|
{
|
||||||
|
return codebooks.stride(0) * codebooks.element_size() / sizeof(int4);
|
||||||
|
}
|
||||||
|
|
||||||
|
void code1x16_matvec(
|
||||||
|
const torch::Tensor& A,
|
||||||
|
const torch::Tensor& B,
|
||||||
|
torch::Tensor& C,
|
||||||
|
const torch::Tensor& codebook,
|
||||||
|
const int4 codebook_a_sizes // cumulative sizes of A spanning each codebook, at most 3 long.
|
||||||
|
) {
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
||||||
|
int prob_m = C.size(0);
|
||||||
|
int prob_k = B.size(0);
|
||||||
|
|
||||||
|
code1x16_matvec_cuda(
|
||||||
|
A.data_ptr(),
|
||||||
|
B.data_ptr(),
|
||||||
|
C.data_ptr(),
|
||||||
|
codebook.data_ptr(),
|
||||||
|
prob_m,
|
||||||
|
prob_k,
|
||||||
|
codebook_a_sizes,
|
||||||
|
codebook_stride(codebook)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
torch::Tensor code1x16_matmat(
|
||||||
|
const torch::Tensor& input,
|
||||||
|
const torch::Tensor& codes,
|
||||||
|
const torch::Tensor& codebooks,
|
||||||
|
const torch::Tensor& scales,
|
||||||
|
const int4 codebook_a_sizes,
|
||||||
|
const std::optional<torch::Tensor>& bias) {
|
||||||
|
auto input_sizes = input.sizes();
|
||||||
|
auto out_features = codes.size(0) * codebooks.size(2);
|
||||||
|
auto flat_input = input.reshape({-1, input.size(-1)});
|
||||||
|
auto flat_output = torch::empty({flat_input.size(0), out_features},
|
||||||
|
torch::TensorOptions()
|
||||||
|
.dtype(input.dtype())
|
||||||
|
.device(input.device())
|
||||||
|
);
|
||||||
|
|
||||||
|
for (int i = 0; i < flat_input.size(0); ++i) {
|
||||||
|
auto input_vec = flat_input.index({i});
|
||||||
|
auto output_vec = flat_output.index({i});
|
||||||
|
code1x16_matvec(
|
||||||
|
codes.squeeze(2),
|
||||||
|
input_vec,
|
||||||
|
output_vec,
|
||||||
|
codebooks,
|
||||||
|
codebook_a_sizes
|
||||||
|
);
|
||||||
|
}
|
||||||
|
flat_output *= scales.flatten().unsqueeze(0);
|
||||||
|
|
||||||
|
if (bias.has_value()) {
|
||||||
|
flat_output += bias->unsqueeze(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto output_sizes = input_sizes.vec();
|
||||||
|
output_sizes.pop_back();
|
||||||
|
output_sizes.push_back(-1);
|
||||||
|
auto output = flat_output.reshape(output_sizes);
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
|
void code2x8_matvec(
|
||||||
|
const torch::Tensor& A,
|
||||||
|
const torch::Tensor& B,
|
||||||
|
torch::Tensor& C,
|
||||||
|
const torch::Tensor& codebook,
|
||||||
|
const int4 codebook_a_sizes
|
||||||
|
) {
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
||||||
|
int prob_m = C.size(0);
|
||||||
|
int prob_k = B.size(0);
|
||||||
|
code2x8_matvec_cuda(
|
||||||
|
A.data_ptr(),
|
||||||
|
B.data_ptr(),
|
||||||
|
C.data_ptr(),
|
||||||
|
codebook.data_ptr(),
|
||||||
|
prob_m,
|
||||||
|
prob_k,
|
||||||
|
codebook_a_sizes,
|
||||||
|
2 * codebook_stride(codebook)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
torch::Tensor code2x8_matmat(
|
||||||
|
const torch::Tensor& input,
|
||||||
|
const torch::Tensor& codes,
|
||||||
|
const torch::Tensor& codebooks,
|
||||||
|
const torch::Tensor& scales,
|
||||||
|
const int4 codebook_a_sizes,
|
||||||
|
const std::optional<torch::Tensor>& bias
|
||||||
|
) {
|
||||||
|
auto input_sizes = input.sizes();
|
||||||
|
auto out_features = codes.size(0) * codebooks.size(2);
|
||||||
|
auto flat_input = input.reshape({-1, input.size(-1)});
|
||||||
|
auto flat_output = torch::empty({flat_input.size(0), out_features},
|
||||||
|
torch::TensorOptions()
|
||||||
|
.dtype(input.dtype())
|
||||||
|
.device(input.device())
|
||||||
|
);
|
||||||
|
|
||||||
|
for (int i = 0; i < flat_input.size(0); ++i) {
|
||||||
|
auto input_vec = flat_input.index({i});
|
||||||
|
auto output_vec = flat_output.index({i});
|
||||||
|
code2x8_matvec(
|
||||||
|
codes.squeeze(2),
|
||||||
|
input_vec,
|
||||||
|
output_vec,
|
||||||
|
codebooks,
|
||||||
|
codebook_a_sizes
|
||||||
|
);
|
||||||
|
}
|
||||||
|
flat_output *= scales.flatten().unsqueeze(0);
|
||||||
|
if (bias.has_value()) {
|
||||||
|
flat_output += bias->unsqueeze(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto output_sizes = input_sizes.vec();
|
||||||
|
output_sizes.pop_back();
|
||||||
|
output_sizes.push_back(-1);
|
||||||
|
auto output = flat_output.reshape(output_sizes);
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accumulate the partition sizes.
|
||||||
|
int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes)
|
||||||
|
{
|
||||||
|
int4 cumulative_sizes;
|
||||||
|
auto cumulative_size = &cumulative_sizes.x;
|
||||||
|
int i = 0;
|
||||||
|
int last = 0;
|
||||||
|
assert(codebook_partition_sizes.size(0) <= 4);
|
||||||
|
for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size)
|
||||||
|
{
|
||||||
|
*cumulative_size = codebook_partition_sizes[i].item<int>() + last;
|
||||||
|
last = *cumulative_size;
|
||||||
|
}
|
||||||
|
// fill in the rest with unreachable.
|
||||||
|
for (; i < 4; ++i, ++cumulative_size)
|
||||||
|
{
|
||||||
|
*cumulative_size = last*10;
|
||||||
|
}
|
||||||
|
return cumulative_sizes;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace aqlm
|
||||||
|
} // namespace vllm
|
||||||
|
|
||||||
|
|
||||||
|
torch::Tensor aqlm_gemm(
|
||||||
|
const torch::Tensor& input,
|
||||||
|
const torch::Tensor& codes,
|
||||||
|
const torch::Tensor& codebooks,
|
||||||
|
const torch::Tensor& scales,
|
||||||
|
const torch::Tensor& codebook_partition_sizes,
|
||||||
|
const std::optional<torch::Tensor>& bias
|
||||||
|
)
|
||||||
|
{
|
||||||
|
int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
|
||||||
|
|
||||||
|
int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
|
||||||
|
int const entries = codebooks.size(1);
|
||||||
|
|
||||||
|
if (nbooks == 1 && entries == (1 << 16))
|
||||||
|
{
|
||||||
|
return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, cumulative_sizes, bias);
|
||||||
|
}
|
||||||
|
if (nbooks == 2 && entries == (1 << 8))
|
||||||
|
{
|
||||||
|
return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, cumulative_sizes, bias);
|
||||||
|
}
|
||||||
|
|
||||||
|
TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.")
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dequantizes AQLM-encoded weights into a dense weight tensor.
//
// Arguments:
//   codes                    - quantized codes; size(0) is taken as the number
//                              of output features and size(1) * 8 as the number
//                              of input features.
//   codebooks                - codebook tensor; the result uses its dtype and
//                              device.
//   codebook_partition_sizes - per-partition output sizes; their sum is
//                              expected to equal the number of output features.
//
// Returns the dense {out_features, in_features} weight tensor. Fails via
// TORCH_CHECK for any codebook layout other than 1x16 or 2x8.
torch::Tensor aqlm_dequant(
    const torch::Tensor& codes,
    const torch::Tensor& codebooks,
    const torch::Tensor& codebook_partition_sizes
)
{
    int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes);

    int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
    int const entries = codebooks.size(1);

    const at::cuda::OptionalCUDAGuard device_guard(device_of(codes));

    auto in_features = codes.size(1) * 8;
    auto out_features = codes.size(0);

    // Sanity check: the partitions must add up to the output dimension.
    // (This previously used '=', which assigned instead of comparing.)
    assert(out_features == codebook_partition_sizes.sum().item<int>());

    auto weights = torch::empty({out_features, in_features},
        torch::TensorOptions()
            .dtype(codebooks.dtype())
            .device(codebooks.device())
    );

    if (nbooks == 1 && entries == (1 << 16))
    {
        vllm::aqlm::code1x16_dequant_cuda(
            codes.data_ptr(),
            weights.data_ptr(),
            codebooks.data_ptr(),
            out_features,
            in_features,
            cumulative_sizes,
            vllm::aqlm::codebook_stride(codebooks));

        // if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation.)
        // weights *= scales.index({"...", 0, 0});

        return weights;
    }

    if (nbooks == 2 && entries == (1 << 8))
    {
        vllm::aqlm::code2x8_dequant_cuda(
            codes.data_ptr(),
            weights.data_ptr(),
            codebooks.data_ptr(),
            out_features,
            in_features,
            cumulative_sizes,
            vllm::aqlm::codebook_stride(codebooks));

        // if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation)
        // weights *= scales.index({"...", 0, 0});

        return weights;
    }

    TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.");
    return {};
}
|
||||||
167
csrc/quantization/fp8/amd_detail/hip_float8.h
Normal file
167
csrc/quantization/fp8/amd_detail/hip_float8.h
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#ifdef __HIPCC__
|
||||||
|
#include <hip/hip_runtime.h>
|
||||||
|
#else
|
||||||
|
#include <type_traits>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include <iostream>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "hip_float8_impl.h"
|
||||||
|
|
||||||
|
struct alignas(1) hip_fp8
|
||||||
|
{
|
||||||
|
struct from_bits_t
|
||||||
|
{
|
||||||
|
};
|
||||||
|
HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { return from_bits_t(); }
|
||||||
|
uint8_t data;
|
||||||
|
|
||||||
|
hip_fp8() = default;
|
||||||
|
HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default;
|
||||||
|
HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete;
|
||||||
|
explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t)
|
||||||
|
: data(v)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef __HIP__MI300__
|
||||||
|
// NOTE: ON-DEVICE... always optimal bias
|
||||||
|
explicit HIP_FP8_DEVICE hip_fp8(float v)
|
||||||
|
: data(hip_fp8_impl::to_fp8_from_fp32(v))
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
explicit HIP_FP8_DEVICE hip_fp8(_Float16 v)
|
||||||
|
: hip_fp8(static_cast<float>(v))
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
// Host only implementation using s/w simulation
|
||||||
|
explicit HIP_FP8_HOST
|
||||||
|
#else // __HIP__MI300__
|
||||||
|
// both Host and DEVICE for non-MI300 using s/w simulation
|
||||||
|
explicit HIP_FP8_HOST_DEVICE
|
||||||
|
#endif // __HIP__MI300__
|
||||||
|
hip_fp8(float v)
|
||||||
|
{
|
||||||
|
data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/, true /*clip*/>(v);
|
||||||
|
}
|
||||||
|
|
||||||
|
explicit HIP_FP8_HOST_DEVICE hip_fp8(double v)
|
||||||
|
: hip_fp8(static_cast<float>(v))
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef __HIP__MI300__
|
||||||
|
// upcast using device specific intrinsic
|
||||||
|
explicit inline HIP_FP8_DEVICE operator float() const
|
||||||
|
{
|
||||||
|
float fval;
|
||||||
|
uint32_t i32val = static_cast<uint32_t>(data);
|
||||||
|
|
||||||
|
// upcast
|
||||||
|
asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
|
||||||
|
|
||||||
|
return fval;
|
||||||
|
}
|
||||||
|
|
||||||
|
explicit inline HIP_FP8_HOST operator float() const
|
||||||
|
#else // __HIP__MI300__
|
||||||
|
explicit inline HIP_FP8_HOST_DEVICE operator float() const
|
||||||
|
#endif // __HIP__MI300__
|
||||||
|
{
|
||||||
|
return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(data);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace std
|
||||||
|
{
|
||||||
|
inline hip_fp8 sin(hip_fp8 a)
|
||||||
|
{
|
||||||
|
return hip_fp8(sinf(float(a)));
|
||||||
|
}
|
||||||
|
inline hip_fp8 cos(hip_fp8 a)
|
||||||
|
{
|
||||||
|
return hip_fp8(cosf(float(a)));
|
||||||
|
}
|
||||||
|
HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a)
|
||||||
|
{
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
} // namespace std
|
||||||
|
|
||||||
|
// Special operator overloading
|
||||||
|
inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8)
|
||||||
|
{
|
||||||
|
return os << float(f8);
|
||||||
|
}
|
||||||
|
|
||||||
|
// all + operator overloading with mixed types
|
||||||
|
// mixed types, always converts to f32, does computation in f32, and returns float
|
||||||
|
inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b)
|
||||||
|
{
|
||||||
|
return (fa + float(b));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb)
|
||||||
|
{
|
||||||
|
return (float(a) + fb);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b)
|
||||||
|
{
|
||||||
|
return hip_fp8(float(a) + float(b));
|
||||||
|
}
|
||||||
|
|
||||||
|
// In-place addition: sums in float, stores the fp8 result in a, returns a.
inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b)
{
    const float sum = static_cast<float>(a) + static_cast<float>(b);
    a = hip_fp8(sum);
    return a;
}
|
||||||
|
|
||||||
|
// overloading multiplication, always returns float,
|
||||||
|
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b)
|
||||||
|
{
|
||||||
|
return float(a) * float(b);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b)
|
||||||
|
{
|
||||||
|
return (a * float(b));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b)
|
||||||
|
{
|
||||||
|
return (float(a) * b);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b)
|
||||||
|
{
|
||||||
|
return ((float)a * float(b));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b)
|
||||||
|
{
|
||||||
|
return ((float)a * float(b));
|
||||||
|
}
|
||||||
|
|
||||||
|
// overloading for compare
|
||||||
|
inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b)
|
||||||
|
{
|
||||||
|
return (a.data == b.data);
|
||||||
|
}
|
||||||
|
inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b)
|
||||||
|
{
|
||||||
|
return (a.data != b.data);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ordering comparisons are evaluated on the float values.
inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b)
{
    const float lhs = static_cast<float>(a);
    const float rhs = static_cast<float>(b);
    return lhs >= rhs;
}

inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b)
{
    const float lhs = static_cast<float>(a);
    const float rhs = static_cast<float>(b);
    return lhs > rhs;
}
|
||||||
316
csrc/quantization/fp8/amd_detail/hip_float8_impl.h
Normal file
316
csrc/quantization/fp8/amd_detail/hip_float8_impl.h
Normal file
@@ -0,0 +1,316 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#if defined(__HIPCC__) && (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||||
|
#define __HIP__MI300__
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef __HIPCC__
|
||||||
|
#define HIP_FP8_HOST_DEVICE __host__ __device__
|
||||||
|
#define HIP_FP8_HOST __host__
|
||||||
|
#define HIP_FP8_DEVICE __device__
|
||||||
|
#else
|
||||||
|
#define HIP_FP8_HOST_DEVICE
|
||||||
|
#define HIP_FP8_HOST
|
||||||
|
#define HIP_FP8_DEVICE
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace hip_fp8_impl
|
||||||
|
{
|
||||||
|
|
||||||
|
#ifdef __HIP__MI300__
// Converts a float to an fp8 byte with the MI300 packed-convert builtin.
// Finite inputs first go through a median-of-three with +/-240 (i.e. they are
// clamped to that range); NaN/Inf bit patterns are passed through unclamped.
HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v)
{
    union {
        float fval;
        uint32_t i32val;
        uint8_t i8val[4]; // NOTE: not endian independent
    } pun;

    pun.fval = v;

    const bool exponent_all_ones = (pun.i32val & 0x7F800000) == 0x7F800000;
    if (!exponent_all_ones) { /// propagate NAN/INF, no clipping
        pun.fval = __builtin_amdgcn_fmed3f(pun.fval, 240.0, -240.0);
    }

    uint32_t packed = 0;
    packed = __builtin_amdgcn_cvt_pk_fp8_f32(pun.fval, pun.fval, packed,
                                             false); // false -> WORD0
    pun.i32val = packed;

    // The converted value is in the lowest byte of the packed word.
    return pun.i8val[0];
}
#endif // __HIP__MI300__
|
||||||
|
|
||||||
|
// Count of leading zero bits in x (host overload, compiler builtin).
// NOTE(review): __builtin_clz is documented as undefined for x == 0; the
// visible caller only passes a non-zero mantissa.
HIP_FP8_HOST inline int clz(uint32_t x)
{
    const int leading_zeros = __builtin_clz(x);
    return leading_zeros;
}
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
// Device overload using the device intrinsic.
HIP_FP8_DEVICE inline int clz(uint32_t x)
{
    const int leading_zeros = __clz(x);
    return leading_zeros;
}
#endif
|
||||||
|
|
||||||
|
template <int we, int wm, typename T, bool negative_zero_nan, bool clip>
|
||||||
|
HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, uint32_t rng = 0)
|
||||||
|
{
|
||||||
|
#ifdef __HIPCC__
|
||||||
|
constexpr bool is_half = std::is_same<T, _Float16>::value;
|
||||||
|
#else
|
||||||
|
constexpr bool is_half = false;
|
||||||
|
#endif
|
||||||
|
constexpr bool is_float = std::is_same<T, float>::value;
|
||||||
|
static_assert(wm + we == 7, "wm+we==7");
|
||||||
|
static_assert(is_half || is_float, "Only half and float can be cast to f8");
|
||||||
|
|
||||||
|
const int mfmt = (sizeof(T) == 4) ? 23 : 10;
|
||||||
|
uint32_t x;
|
||||||
|
if (sizeof(T) == 4) {
|
||||||
|
x = reinterpret_cast<uint32_t&>(_x);
|
||||||
|
} else {
|
||||||
|
x = reinterpret_cast<uint16_t&>(_x);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t head, mantissa;
|
||||||
|
int exponent, bias;
|
||||||
|
uint32_t sign;
|
||||||
|
|
||||||
|
if (sizeof(T) == 4) {
|
||||||
|
head = x & 0xFF800000;
|
||||||
|
mantissa = x & 0x7FFFFF;
|
||||||
|
exponent = (head >> 23) & 0xFF;
|
||||||
|
sign = head >> 31;
|
||||||
|
bias = 127;
|
||||||
|
} else {
|
||||||
|
head = x & 0xFC00;
|
||||||
|
mantissa = x & 0x3FF;
|
||||||
|
exponent = (head >> 10) & 0x1F;
|
||||||
|
sign = head >> 15;
|
||||||
|
bias = 15;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm);
|
||||||
|
|
||||||
|
// Deal with inf and NaNs
|
||||||
|
if (negative_zero_nan) {
|
||||||
|
if (sizeof(T) == 4) {
|
||||||
|
if ((x & 0x7F800000) == 0x7F800000) {
|
||||||
|
return 0x80;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// if(__hisinf(x) || __hisnan(x))
|
||||||
|
if ((x & 0x7C00) == 0x7C00) {
|
||||||
|
return 0x80;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (sizeof(T) == 4) {
|
||||||
|
if ((x & 0x7F800000) == 0x7F800000) {
|
||||||
|
return signed_inf + (mantissa != 0 ? 1 : 0);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if ((x & 0x7C00) == 0x7C00) {
|
||||||
|
return signed_inf + (mantissa != 0 ? 1 : 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (x == 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// First need to check if it is normal or denorm as there is a difference of
|
||||||
|
// implicit 1 Then need to adjust the exponent to align with the F8 exponent,
|
||||||
|
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
|
||||||
|
// to mantissa and truncate. And for RNE, no need to add rng. Then probably
|
||||||
|
// need to check whether there is carry and adjust exponent and mantissa again
|
||||||
|
|
||||||
|
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
|
||||||
|
// bits
|
||||||
|
const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
|
||||||
|
const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
|
||||||
|
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
|
||||||
|
// f8_exponent is the converted f8 exponent with bias encoding
|
||||||
|
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
|
||||||
|
// the difference needs to be adjusted and mantissa shifted
|
||||||
|
int act_exponent, f8_exponent, exponent_diff;
|
||||||
|
|
||||||
|
if (exponent == 0) { // fp32/fp16 is in denormal.
|
||||||
|
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we
|
||||||
|
mostly concern fp16 here. In this case, f8 is usually in denormal. But there
|
||||||
|
could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
|
||||||
|
exponent bias 16. It means that there are some numbers in fp16 denormal but they
|
||||||
|
are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
|
||||||
|
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
|
||||||
|
(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
|
||||||
|
act_exponent = exponent - bias + 1;
|
||||||
|
exponent_diff = f8_denormal_act_exponent - act_exponent; // actual exponent is exponent-bias+1 as it is denormal
|
||||||
|
} else { // fp32/fp16 is normal with implicit 1
|
||||||
|
act_exponent = exponent - bias;
|
||||||
|
if (act_exponent <= f8_denormal_act_exponent) {
|
||||||
|
/* This is the case where fp32/fp16 is normal but it is in f8 denormal
|
||||||
|
range. For example fp8 nanoo mode, denormal exponent is -7, but if the
|
||||||
|
fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1,
|
||||||
|
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
|
||||||
|
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
|
||||||
|
exponent_diff = f8_denormal_act_exponent - act_exponent;
|
||||||
|
} else { // both fp32/fp16 and f8 are in normal range
|
||||||
|
exponent_diff = 0; // exponent_diff=0 does not mean there is no difference
|
||||||
|
// for this case,
|
||||||
|
// act_exponent could be larger. Just that it does not need shift mantissa
|
||||||
|
}
|
||||||
|
mantissa += (1 << mfmt); // Add the implicit 1 into mantissa
|
||||||
|
}
|
||||||
|
|
||||||
|
bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) ==
|
||||||
|
static_cast<uint32_t>(1 << (mfmt - wm + exponent_diff - 1));
|
||||||
|
/* This part is a bit tricky. The judgment of whether it is a tie needs to be
|
||||||
|
done before we shift right as shift right could rip off some residual part
|
||||||
|
and make something not midpoint look like midpoint. For example, the fp16
|
||||||
|
number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after
|
||||||
|
shift right by 4 bits, it would look like midpoint.
|
||||||
|
*/
|
||||||
|
|
||||||
|
if (exponent_diff > 0) {
|
||||||
|
mantissa >>= exponent_diff;
|
||||||
|
} else if (exponent_diff == -1) {
|
||||||
|
mantissa <<= -exponent_diff;
|
||||||
|
}
|
||||||
|
bool implicit_one = mantissa & (1 << mfmt);
|
||||||
|
// if there is no implicit 1, it means the f8 is denormal and need to adjust
|
||||||
|
// to denorm exponent
|
||||||
|
f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
|
||||||
|
|
||||||
|
// Now we have the exponent and mantissa adjusted
|
||||||
|
uint32_t drop_mask = (1 << (mfmt - wm)) - 1;
|
||||||
|
bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit that
|
||||||
|
// is not truncated is 1
|
||||||
|
mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;
|
||||||
|
|
||||||
|
// Now we deal with overflow
|
||||||
|
if (f8_exponent == 0) {
|
||||||
|
if ((1 << mfmt) & mantissa) {
|
||||||
|
f8_exponent = 1; // denormal overflow to become normal, promote exponent
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if ((1 << (mfmt + 1)) & mantissa) {
|
||||||
|
mantissa >>= 1;
|
||||||
|
f8_exponent++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mantissa >>= (mfmt - wm);
|
||||||
|
|
||||||
|
// above range: quantize to maximum possible float of the same sign
|
||||||
|
const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2);
|
||||||
|
if (f8_exponent > max_exp) {
|
||||||
|
if (clip) {
|
||||||
|
mantissa = (1 << wm) - 1;
|
||||||
|
f8_exponent = max_exp;
|
||||||
|
} else {
|
||||||
|
return signed_inf;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (f8_exponent == 0 && mantissa == 0) {
|
||||||
|
return negative_zero_nan ? 0 : (sign << 7);
|
||||||
|
}
|
||||||
|
mantissa &= (1 << wm) - 1;
|
||||||
|
return (sign << 7) | (f8_exponent << wm) | mantissa;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int we, int wm, typename T = float, bool negative_zero_nan = true>
|
||||||
|
inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x)
|
||||||
|
{
|
||||||
|
#ifdef __HIPCC__
|
||||||
|
constexpr bool is_half = std::is_same<T, _Float16>::value;
|
||||||
|
#else
|
||||||
|
constexpr bool is_half = false;
|
||||||
|
#endif
|
||||||
|
constexpr bool is_float = std::is_same<T, float>::value;
|
||||||
|
static_assert(is_half || is_float, "only half and float are supported");
|
||||||
|
|
||||||
|
constexpr int weo = is_half ? 5 : 8;
|
||||||
|
constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
|
||||||
|
|
||||||
|
T fInf, fNegInf, fNaN, fNeg0;
|
||||||
|
|
||||||
|
#ifdef __HIPCC__
|
||||||
|
if (is_half) {
|
||||||
|
const uint16_t ihInf = 0x7C00;
|
||||||
|
const uint16_t ihNegInf = 0xFC00;
|
||||||
|
const uint16_t ihNaN = 0x7C01;
|
||||||
|
const uint16_t ihNeg0 = 0x8000;
|
||||||
|
fInf = reinterpret_cast<const _Float16&>(ihInf);
|
||||||
|
fNegInf = reinterpret_cast<const _Float16&>(ihNegInf);
|
||||||
|
fNaN = reinterpret_cast<const _Float16&>(ihNaN);
|
||||||
|
fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0);
|
||||||
|
} else
|
||||||
|
#endif
|
||||||
|
if (is_float) {
|
||||||
|
const uint32_t ifInf = 0x7F800000;
|
||||||
|
const uint32_t ifNegInf = 0xFF800000;
|
||||||
|
const uint32_t ifNaN = 0x7F800001;
|
||||||
|
const uint32_t ifNeg0 = 0x80000000;
|
||||||
|
fInf = reinterpret_cast<const float&>(ifInf);
|
||||||
|
fNegInf = reinterpret_cast<const float&>(ifNegInf);
|
||||||
|
fNaN = reinterpret_cast<const float&>(ifNaN);
|
||||||
|
fNeg0 = reinterpret_cast<const float&>(ifNeg0);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (x == 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t sign = x >> 7;
|
||||||
|
uint32_t mantissa = x & ((1 << wm) - 1);
|
||||||
|
int exponent = (x & 0x7F) >> wm;
|
||||||
|
if (negative_zero_nan) {
|
||||||
|
if (x == 0x80) {
|
||||||
|
return fNaN;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (x == 0x80) {
|
||||||
|
return fNeg0;
|
||||||
|
}
|
||||||
|
if (exponent == ((1 << we) - 1)) {
|
||||||
|
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
|
||||||
|
if (we == 5 && is_half && !negative_zero_nan) {
|
||||||
|
retval = x << 8;
|
||||||
|
return reinterpret_cast<const T&>(retval);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
|
||||||
|
|
||||||
|
// subnormal input
|
||||||
|
if (exponent == 0) {
|
||||||
|
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
|
||||||
|
int sh = 1 + clz(mantissa) - (32 - wm);
|
||||||
|
mantissa <<= sh;
|
||||||
|
exponent += 1 - sh;
|
||||||
|
mantissa &= ((1 << wm) - 1);
|
||||||
|
}
|
||||||
|
exponent += exp_low_cutoff - 1;
|
||||||
|
mantissa <<= wmo - wm;
|
||||||
|
|
||||||
|
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
|
||||||
|
if (exponent <= 0) {
|
||||||
|
mantissa |= 1 << wmo;
|
||||||
|
mantissa >>= 1 - exponent;
|
||||||
|
exponent = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sizeof(T) == 2) {
|
||||||
|
retval = (sign << 15) | (exponent << 10) | mantissa;
|
||||||
|
} else {
|
||||||
|
retval = (sign << 31) | (exponent << 23) | mantissa;
|
||||||
|
}
|
||||||
|
return reinterpret_cast<const T&>(retval);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace hip_fp8_impl
|
||||||
517
csrc/quantization/fp8/amd_detail/quant_utils.cuh
Normal file
517
csrc/quantization/fp8/amd_detail/quant_utils.cuh
Normal file
@@ -0,0 +1,517 @@
|
|||||||
|
#pragma once
|
||||||
|
#include "hip_float8.h"
|
||||||
|
|
||||||
|
#include <hip/hip_fp16.h>
|
||||||
|
#include <hip/hip_bf16.h>
|
||||||
|
#include <hip/hip_bfloat16.h>
|
||||||
|
|
||||||
|
#include "../../../attention/dtype_float32.cuh"
|
||||||
|
#include "../../../attention/dtype_bfloat16.cuh"
|
||||||
|
|
||||||
|
namespace vllm
|
||||||
|
{
|
||||||
|
namespace fp8_e4m3 {
|
||||||
|
template <typename Tout, typename Tin>
|
||||||
|
__inline__ __device__ Tout vec_conversion(const Tin& x)
|
||||||
|
{
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Tout, typename Tin>
|
||||||
|
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x, const float scale)
|
||||||
|
{
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8 -> half
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a)
|
||||||
|
{
|
||||||
|
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||||
|
__half_raw res;
|
||||||
|
res.data = static_cast<float>(f8);
|
||||||
|
return res.x;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x2 -> half2
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a)
|
||||||
|
{
|
||||||
|
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||||
|
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||||
|
union {
|
||||||
|
__half2_raw h2r;
|
||||||
|
uint32_t ui32;
|
||||||
|
} tmp;
|
||||||
|
tmp.h2r.x.data = f2[0];
|
||||||
|
tmp.h2r.y.data = f2[1];
|
||||||
|
return tmp.ui32;
|
||||||
|
#else
|
||||||
|
union {
|
||||||
|
uint16_t u16[2];
|
||||||
|
uint32_t u32;
|
||||||
|
} tmp;
|
||||||
|
|
||||||
|
tmp.u16[0] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a));
|
||||||
|
tmp.u16[1] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U));
|
||||||
|
return tmp.u32;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x4 -> half2x2
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
|
||||||
|
{
|
||||||
|
union {
|
||||||
|
uint2 u32x2;
|
||||||
|
uint32_t u32[2];
|
||||||
|
} tmp;
|
||||||
|
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
|
||||||
|
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
|
||||||
|
return tmp.u32x2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x8 -> half2x4
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a)
|
||||||
|
{
|
||||||
|
union {
|
||||||
|
uint4 u64x2;
|
||||||
|
uint2 u64[2];
|
||||||
|
} tmp;
|
||||||
|
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
|
||||||
|
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
|
||||||
|
return tmp.u64x2;
|
||||||
|
}
|
||||||
|
|
||||||
|
using __nv_bfloat16 = __hip_bfloat16;
|
||||||
|
|
||||||
|
// fp8 -> __nv_bfloat16
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a)
|
||||||
|
{
|
||||||
|
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||||
|
float f{f8};
|
||||||
|
return __float2bfloat16(f);
|
||||||
|
}
|
||||||
|
|
||||||
|
using __nv_bfloat162 = __hip_bfloat162;
|
||||||
|
|
||||||
|
// fp8x2 -> __nv_bfloat162
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a)
|
||||||
|
{
|
||||||
|
__nv_bfloat162 res;
|
||||||
|
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
|
||||||
|
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x4 -> bf16_4_t
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a)
|
||||||
|
{
|
||||||
|
bf16_4_t res;
|
||||||
|
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
|
||||||
|
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x8 -> bf16_8_t
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
|
||||||
|
{
|
||||||
|
bf16_4_t tmp1, tmp2;
|
||||||
|
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
|
||||||
|
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
|
||||||
|
bf16_8_t res;
|
||||||
|
res.x = tmp1.x;
|
||||||
|
res.y = tmp1.y;
|
||||||
|
res.z = tmp2.x;
|
||||||
|
res.w = tmp2.y;
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8 -> float
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a)
|
||||||
|
{
|
||||||
|
hip_fp8 fp8{a, hip_fp8::from_bits()};
|
||||||
|
return static_cast<float>(fp8);
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x2 -> float2
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
|
||||||
|
{
|
||||||
|
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||||
|
float2 res;
|
||||||
|
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||||
|
res.x = f2[0];
|
||||||
|
res.y = f2[1];
|
||||||
|
return res;
|
||||||
|
#else
|
||||||
|
float2 res;
|
||||||
|
res.x = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a));
|
||||||
|
res.y = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U));
|
||||||
|
return res;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x4 -> float4
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a)
|
||||||
|
{
|
||||||
|
Float4_ res;
|
||||||
|
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
|
||||||
|
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x8 -> float8
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
|
||||||
|
{
|
||||||
|
Float4_ tmp1, tmp2;
|
||||||
|
tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
|
||||||
|
tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
|
||||||
|
Float8_ res;
|
||||||
|
res.x = tmp1.x;
|
||||||
|
res.y = tmp1.y;
|
||||||
|
res.z = tmp2.x;
|
||||||
|
res.w = tmp2.y;
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// half -> fp8
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a)
|
||||||
|
{
|
||||||
|
__half_raw tmp;
|
||||||
|
tmp.x = a;
|
||||||
|
|
||||||
|
hip_fp8 f8{static_cast<float>(tmp.data)};
|
||||||
|
return f8.data;
|
||||||
|
}
|
||||||
|
|
||||||
|
// bf16 -> fp8
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a)
|
||||||
|
{
|
||||||
|
hip_fp8 res{__bfloat162float(a)};
|
||||||
|
return res.data;
|
||||||
|
}
|
||||||
|
|
||||||
|
// float -> fp8
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a)
|
||||||
|
{
|
||||||
|
hip_fp8 f8(a);
|
||||||
|
return f8.data;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x4 -> float4
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
|
||||||
|
{
|
||||||
|
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
|
||||||
|
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// float2 -> half2
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
|
||||||
|
{
|
||||||
|
union {
|
||||||
|
half2 float16;
|
||||||
|
uint32_t uint32;
|
||||||
|
};
|
||||||
|
|
||||||
|
float16 = __float22half2_rn(a);
|
||||||
|
return uint32;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Float4 -> half2x2
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
|
||||||
|
{
|
||||||
|
uint2 b;
|
||||||
|
float2 val;
|
||||||
|
val.x = a.x.x;
|
||||||
|
val.y = a.x.y;
|
||||||
|
b.x = vec_conversion<uint32_t, float2>(val);
|
||||||
|
|
||||||
|
val.x = a.y.x;
|
||||||
|
val.y = a.y.y;
|
||||||
|
b.y = vec_conversion<uint32_t, float2>(val);
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Float4 -> float4
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
|
||||||
|
{
|
||||||
|
float4 b;
|
||||||
|
b.x = a.x.x;
|
||||||
|
b.y = a.x.y;
|
||||||
|
b.z = a.y.x;
|
||||||
|
b.w = a.y.y;
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Float8 -> half2x4
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
|
||||||
|
{
|
||||||
|
uint4 b;
|
||||||
|
b.x = vec_conversion<uint32_t, float2>(a.x);
|
||||||
|
b.y = vec_conversion<uint32_t, float2>(a.y);
|
||||||
|
b.z = vec_conversion<uint32_t, float2>(a.z);
|
||||||
|
b.w = vec_conversion<uint32_t, float2>(a.w);
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
|
||||||
|
// float2 -> bfloat162
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2& a)
|
||||||
|
{
|
||||||
|
__nv_bfloat162 b = __float22bfloat162_rn(a);
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Float4 -> bfloat162x2
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_& a)
|
||||||
|
{
|
||||||
|
bf16_4_t b;
|
||||||
|
b.x = __float22bfloat162_rn(a.x);
|
||||||
|
b.y = __float22bfloat162_rn(a.y);
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Float8 -> bfloat162x4
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_& a)
|
||||||
|
{
|
||||||
|
bf16_8_t b;
|
||||||
|
b.x = __float22bfloat162_rn(a.x);
|
||||||
|
b.y = __float22bfloat162_rn(a.y);
|
||||||
|
b.z = __float22bfloat162_rn(a.z);
|
||||||
|
b.w = __float22bfloat162_rn(a.w);
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/* Scaled and vectorized conversions, for data exchange between high and low precision domains
|
||||||
|
|
||||||
|
Convention of the scale in API, e.g: FP8_data = Quantization( High_Precision_data / scale )
|
||||||
|
s.t.
|
||||||
|
Quantize(HP / scale) => FP8
|
||||||
|
Dequant(FP8) * scale => HP
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
// fp8 -> half
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale)
|
||||||
|
{
|
||||||
|
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||||
|
__half_raw res;
|
||||||
|
res.data = static_cast<float>(f8) * scale;
|
||||||
|
return res.x;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x2 -> half2
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, const float scale)
|
||||||
|
{
|
||||||
|
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||||
|
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||||
|
union {
|
||||||
|
__half2_raw h2r;
|
||||||
|
uint32_t ui32;
|
||||||
|
} tmp;
|
||||||
|
tmp.h2r.x.data = f2[0] * scale;
|
||||||
|
tmp.h2r.y.data = f2[1] * scale;
|
||||||
|
return tmp.ui32;
|
||||||
|
#else
|
||||||
|
union {
|
||||||
|
uint16_t u16[2];
|
||||||
|
uint32_t u32;
|
||||||
|
} tmp;
|
||||||
|
|
||||||
|
tmp.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
|
||||||
|
tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U), scale);
|
||||||
|
return tmp.u32;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x4 -> half2x2
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale)
|
||||||
|
{
|
||||||
|
union {
|
||||||
|
uint2 u32x2;
|
||||||
|
uint32_t u32[2];
|
||||||
|
} tmp;
|
||||||
|
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
|
||||||
|
tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
|
||||||
|
return tmp.u32x2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x8 -> half2x4
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale)
|
||||||
|
{
|
||||||
|
union {
|
||||||
|
uint4 u64x2;
|
||||||
|
uint2 u64[2];
|
||||||
|
} tmp;
|
||||||
|
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
|
||||||
|
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
|
||||||
|
return tmp.u64x2;
|
||||||
|
}
|
||||||
|
|
||||||
|
using __nv_bfloat16 = __hip_bfloat16;
|
||||||
|
|
||||||
|
// fp8 -> __nv_bfloat16
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ __nv_bfloat16 scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, const float scale)
|
||||||
|
{
|
||||||
|
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||||
|
float f{f8};
|
||||||
|
return __float2bfloat16(f * scale);
|
||||||
|
}
|
||||||
|
|
||||||
|
using __nv_bfloat162 = __hip_bfloat162;
|
||||||
|
|
||||||
|
// fp8x2 -> __nv_bfloat162
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ __nv_bfloat162 scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, const float scale)
|
||||||
|
{
|
||||||
|
__nv_bfloat162 res;
|
||||||
|
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
|
||||||
|
res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x4 -> bf16_4_t
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, const float scale)
|
||||||
|
{
|
||||||
|
bf16_4_t res;
|
||||||
|
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
|
||||||
|
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), scale);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x8 -> bf16_8_t
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale)
|
||||||
|
{
|
||||||
|
bf16_4_t tmp1, tmp2;
|
||||||
|
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
|
||||||
|
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
|
||||||
|
bf16_8_t res;
|
||||||
|
res.x = tmp1.x;
|
||||||
|
res.y = tmp1.y;
|
||||||
|
res.z = tmp2.x;
|
||||||
|
res.w = tmp2.y;
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8 -> float
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(const uint8_t& a, const float scale)
|
||||||
|
{
|
||||||
|
hip_fp8 fp8{a, hip_fp8::from_bits()};
|
||||||
|
return static_cast<float>(fp8) * scale;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x2 -> float2
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale)
|
||||||
|
{
|
||||||
|
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||||
|
float2 res;
|
||||||
|
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||||
|
res.x = f2[0] * scale;
|
||||||
|
res.y = f2[1] * scale;
|
||||||
|
return res;
|
||||||
|
#else
|
||||||
|
float2 res;
|
||||||
|
res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale);
|
||||||
|
res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U), scale);
|
||||||
|
return res;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x4 -> float4
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale)
|
||||||
|
{
|
||||||
|
Float4_ res;
|
||||||
|
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
|
||||||
|
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x8 -> float8
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale)
|
||||||
|
{
|
||||||
|
Float4_ tmp1, tmp2;
|
||||||
|
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
|
||||||
|
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
|
||||||
|
Float8_ res;
|
||||||
|
res.x = tmp1.x;
|
||||||
|
res.y = tmp1.y;
|
||||||
|
res.z = tmp2.x;
|
||||||
|
res.w = tmp2.y;
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/* Quantize(HP / scale) => FP8 */
|
||||||
|
|
||||||
|
// TODO(Hai): vectorized to add
|
||||||
|
|
||||||
|
// half -> fp8
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale)
|
||||||
|
{
|
||||||
|
__half_raw tmp;
|
||||||
|
tmp.x = a;
|
||||||
|
|
||||||
|
hip_fp8 f8{static_cast<float>(tmp.data)/scale};
|
||||||
|
return f8.data;
|
||||||
|
}
|
||||||
|
|
||||||
|
// bf16 -> fp8
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a, const float scale)
|
||||||
|
{
|
||||||
|
hip_fp8 res{__bfloat162float(a)/scale};
|
||||||
|
return res.data;
|
||||||
|
}
|
||||||
|
|
||||||
|
// float -> fp8
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(const float& a, const float scale)
|
||||||
|
{
|
||||||
|
hip_fp8 f8(a/scale);
|
||||||
|
return f8.data;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x4 -> float4
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale)
|
||||||
|
{
|
||||||
|
Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
|
||||||
|
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
} // namespace vllm
|
||||||
103
csrc/quantization/fp8/fp8_cuda_kernels.cu
Normal file
103
csrc/quantization/fp8/fp8_cuda_kernels.cu
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <torch/extension.h>
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
#include "cuda_compat.h"
|
||||||
|
#include "dispatch_utils.h"
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
|
||||||
|
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
|
||||||
|
float old;
|
||||||
|
old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) :
|
||||||
|
__uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
|
||||||
|
|
||||||
|
return old;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the absolute maximum m of the input tensor and store
|
||||||
|
// m / float8_e4m3::max() in *scale. Each thread block performs a
|
||||||
|
// reduction tree and the memory in scale is atomically updated.
|
||||||
|
// So to get the right answer, *scale needs to be initialized to
|
||||||
|
// a value <= 0.0 and we need to wait for all thread blocks to
|
||||||
|
// finish before consuming *scale.
|
||||||
|
template<typename scalar_t>
|
||||||
|
__global__ void segmented_max_reduction(
|
||||||
|
float* __restrict__ scale,
|
||||||
|
const scalar_t* __restrict__ input,
|
||||||
|
int64_t num_elems) {
|
||||||
|
__shared__ float cache[1024];
|
||||||
|
int i = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
|
|
||||||
|
// First store maximum for all values processed by
|
||||||
|
// the current thread in cache[threadIdx.x]
|
||||||
|
scalar_t tmp = 0.0;
|
||||||
|
while (i < num_elems) {
|
||||||
|
float x = static_cast<float>(input[i]);
|
||||||
|
tmp = max(tmp, fabs(x));
|
||||||
|
i += blockDim.x * gridDim.x;
|
||||||
|
}
|
||||||
|
cache[threadIdx.x] = tmp;
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// Now perform parallel reduction within the thread block
|
||||||
|
int ib = blockDim.x / 2;
|
||||||
|
while (ib != 0) {
|
||||||
|
if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
|
||||||
|
cache[threadIdx.x] = cache[threadIdx.x + ib];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
ib /= 2;
|
||||||
|
}
|
||||||
|
// Finally, since cache[0] contains the maximum for this thread block,
|
||||||
|
// atomically write the max to the target location
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
atomicMaxFloat(scale, cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename scalar_t>
|
||||||
|
__global__ void scaled_fp8_quant_kernel(
|
||||||
|
c10::Float8_e4m3fn* __restrict__ out,
|
||||||
|
const scalar_t* __restrict__ input,
|
||||||
|
const float* __restrict__ scale,
|
||||||
|
int64_t num_elems) {
|
||||||
|
int i = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
|
while (i < num_elems) {
|
||||||
|
out[i] = static_cast<c10::Float8_e4m3fn>(input[i] / *scale);
|
||||||
|
i += blockDim.x * gridDim.x;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace vllm
|
||||||
|
|
||||||
|
void scaled_fp8_quant(
|
||||||
|
torch::Tensor& out, // [..., d]
|
||||||
|
torch::Tensor& input, // [..., d]
|
||||||
|
torch::Tensor& scale) // [1]
|
||||||
|
{
|
||||||
|
int64_t num_tokens = input.numel() / input.size(-1);
|
||||||
|
int64_t num_elems = input.numel();
|
||||||
|
dim3 grid(num_tokens);
|
||||||
|
dim3 block(1024);
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
|
input.scalar_type(),
|
||||||
|
"scaled_fp8_quant_kernel",
|
||||||
|
[&] {
|
||||||
|
vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
|
||||||
|
scale.data_ptr<float>(),
|
||||||
|
input.data_ptr<scalar_t>(),
|
||||||
|
num_elems);
|
||||||
|
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||||
|
out.data_ptr<c10::Float8_e4m3fn>(),
|
||||||
|
input.data_ptr<scalar_t>(),
|
||||||
|
scale.data_ptr<float>(),
|
||||||
|
num_elems);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
@@ -2067,7 +2067,7 @@ void gptq_shuffle
|
|||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
|
||||||
vllm::gptq::shuffle_exllama_weight(
|
vllm::gptq::shuffle_exllama_weight(
|
||||||
(uint32_t*) q_weight.data_ptr(),
|
(uint32_t*) q_weight.data_ptr(),
|
||||||
q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(),
|
q_perm.device().is_meta() || q_perm.numel() == 0 ? NULL : (int*) q_perm.data_ptr(),
|
||||||
q_weight.size(0) * 32 / bit,
|
q_weight.size(0) * 32 / bit,
|
||||||
q_weight.size(1),
|
q_weight.size(1),
|
||||||
bit
|
bit
|
||||||
|
|||||||
@@ -20,43 +20,45 @@
|
|||||||
#include "cuda_compat.h"
|
#include "cuda_compat.h"
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
template<typename T, int numLanes = WARP_SIZE>
|
||||||
template<typename T>
|
|
||||||
__inline__ __device__ T warpReduceSum(T val) {
|
__inline__ __device__ T warpReduceSum(T val) {
|
||||||
#pragma unroll
|
static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
|
||||||
for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1)
|
"numLanes is not a positive power of 2!");
|
||||||
|
static_assert(numLanes <= WARP_SIZE);
|
||||||
|
#pragma unroll
|
||||||
|
for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
|
||||||
val += VLLM_SHFL_XOR_SYNC(val, mask);
|
val += VLLM_SHFL_XOR_SYNC(val, mask);
|
||||||
return val;
|
return val;
|
||||||
}
|
}
|
||||||
|
|
||||||
__inline__ __device__ constexpr int _calculateLaneMask(int warp_size) {
|
// Helper function to return the next largest power of 2
|
||||||
return warp_size - 1;
|
static constexpr int _nextPow2(unsigned int num) {
|
||||||
}
|
if (num <= 1) return num;
|
||||||
|
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
|
||||||
__inline__ __device__ constexpr int _calculateWidShift(int warp_size) {
|
|
||||||
return 5 + (warp_size >> 6);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Calculate the sum of all elements in a block */
|
/* Calculate the sum of all elements in a block */
|
||||||
template<typename T>
|
template<typename T, int maxBlockSize = 1024>
|
||||||
__inline__ __device__ T blockReduceSum(T val) {
|
__inline__ __device__ T blockReduceSum(T val) {
|
||||||
static __shared__ T shared[WARP_SIZE];
|
static_assert(maxBlockSize <= 1024);
|
||||||
constexpr auto LANE_MASK = _calculateLaneMask(WARP_SIZE);
|
if constexpr (maxBlockSize > WARP_SIZE) {
|
||||||
constexpr auto WID_SHIFT = _calculateWidShift(WARP_SIZE);
|
val = warpReduceSum<T>(val);
|
||||||
int lane = threadIdx.x & LANE_MASK;
|
// Calculates max number of lanes that need to participate in the last warpReduce
|
||||||
int wid = threadIdx.x >> WID_SHIFT;
|
constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
|
||||||
|
static __shared__ T shared[maxActiveLanes];
|
||||||
|
int lane = threadIdx.x % WARP_SIZE;
|
||||||
|
int wid = threadIdx.x / WARP_SIZE;
|
||||||
|
if (lane == 0)
|
||||||
|
shared[wid] = val;
|
||||||
|
|
||||||
val = warpReduceSum<T>(val);
|
__syncthreads();
|
||||||
|
|
||||||
if (lane == 0)
|
val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane] : (T)(0.0f);
|
||||||
shared[wid] = val;
|
val = warpReduceSum<T, _nextPow2(maxActiveLanes)>(val);
|
||||||
|
} else {
|
||||||
__syncthreads();
|
// A single warpReduce is equal to blockReduce
|
||||||
|
val = warpReduceSum<T, _nextPow2(maxBlockSize)>(val);
|
||||||
// Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
|
}
|
||||||
// blockDim.x is not divided by 32
|
|
||||||
val = (threadIdx.x < (blockDim.x / (WARP_SIZE * 1.0f))) ? shared[lane] : (T)(0.0f);
|
|
||||||
val = warpReduceSum<T>(val);
|
|
||||||
return val;
|
return val;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,3 +8,5 @@ sphinx-argparse
|
|||||||
pydantic
|
pydantic
|
||||||
-f https://download.pytorch.org/whl/cpu
|
-f https://download.pytorch.org/whl/cpu
|
||||||
torch
|
torch
|
||||||
|
py-cpuinfo
|
||||||
|
transformers
|
||||||
|
|||||||
@@ -13,12 +13,12 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from sphinx.ext import autodoc
|
from sphinx.ext import autodoc
|
||||||
|
|
||||||
sys.path.insert(0, os.path.abspath(os.path.join('..', '..')))
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
sys.path.append(os.path.abspath("../.."))
|
||||||
|
|
||||||
# -- Project information -----------------------------------------------------
|
# -- Project information -----------------------------------------------------
|
||||||
|
|
||||||
@@ -48,7 +48,7 @@ templates_path = ['_templates']
|
|||||||
# List of patterns, relative to source directory, that match files and
|
# List of patterns, relative to source directory, that match files and
|
||||||
# directories to ignore when looking for source files.
|
# directories to ignore when looking for source files.
|
||||||
# This pattern also affects html_static_path and html_extra_path.
|
# This pattern also affects html_static_path and html_extra_path.
|
||||||
exclude_patterns = []
|
exclude_patterns: List[str] = ["**/*.template.rst"]
|
||||||
|
|
||||||
# Exclude the prompt "$" when copying code
|
# Exclude the prompt "$" when copying code
|
||||||
copybutton_prompt_text = r"\$ "
|
copybutton_prompt_text = r"\$ "
|
||||||
@@ -73,8 +73,16 @@ html_theme_options = {
|
|||||||
# so a file named "default.css" will overwrite the builtin "default.css".
|
# so a file named "default.css" will overwrite the builtin "default.css".
|
||||||
# html_static_path = ['_static']
|
# html_static_path = ['_static']
|
||||||
|
|
||||||
|
|
||||||
|
# Generate additional rst documentation here.
|
||||||
|
def setup(app):
|
||||||
|
from docs.source.generate_examples import generate_examples
|
||||||
|
generate_examples()
|
||||||
|
|
||||||
|
|
||||||
# Mock out external dependencies here.
|
# Mock out external dependencies here.
|
||||||
autodoc_mock_imports = [
|
autodoc_mock_imports = [
|
||||||
|
"cpuinfo",
|
||||||
"torch",
|
"torch",
|
||||||
"transformers",
|
"transformers",
|
||||||
"psutil",
|
"psutil",
|
||||||
@@ -84,6 +92,7 @@ autodoc_mock_imports = [
|
|||||||
"vllm._C",
|
"vllm._C",
|
||||||
"numpy",
|
"numpy",
|
||||||
"tqdm",
|
"tqdm",
|
||||||
|
"tensorizer",
|
||||||
]
|
]
|
||||||
|
|
||||||
for mock_target in autodoc_mock_imports:
|
for mock_target in autodoc_mock_imports:
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
|
|
||||||
AsyncLLMEngine
|
AsyncLLMEngine
|
||||||
=================================
|
=================================
|
||||||
|
|
||||||
.. autoclass:: vllm.engine.async_llm_engine.AsyncLLMEngine
|
.. autoclass:: vllm.AsyncLLMEngine
|
||||||
:members: generate, abort
|
:members:
|
||||||
:show-inheritance:
|
:show-inheritance:
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
LLMEngine
|
LLMEngine
|
||||||
=================================
|
=================================
|
||||||
|
|
||||||
.. autoclass:: vllm.engine.llm_engine.LLMEngine
|
.. autoclass:: vllm.LLMEngine
|
||||||
:members: add_request, abort_request, step
|
:members:
|
||||||
:show-inheritance:
|
:show-inheritance:
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
Sampling Params
|
Sampling Params
|
||||||
===============
|
===============
|
||||||
|
|
||||||
.. automodule:: vllm.sampling_params.SamplingParams
|
.. autoclass:: vllm.SamplingParams
|
||||||
|
:members:
|
||||||
|
|||||||
61
docs/source/generate_examples.py
Normal file
61
docs/source/generate_examples.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def fix_case(text: str) -> str:
|
||||||
|
subs = [
|
||||||
|
("api", "API"),
|
||||||
|
("llm", "LLM"),
|
||||||
|
("vllm", "vLLM"),
|
||||||
|
("openai", "OpenAI"),
|
||||||
|
("multilora", "MultiLoRA"),
|
||||||
|
]
|
||||||
|
for sub in subs:
|
||||||
|
text = re.sub(*sub, text, flags=re.IGNORECASE)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def underline(title: str, character: str = "=") -> str:
|
||||||
|
return f"{title}\n{character * len(title)}"
|
||||||
|
|
||||||
|
|
||||||
|
def generate_title(filename: str) -> str:
|
||||||
|
# Turn filename into a title
|
||||||
|
title = filename.replace("_", " ").title()
|
||||||
|
# Handle acronyms and names
|
||||||
|
title = fix_case(title)
|
||||||
|
# Underline title
|
||||||
|
title = underline(title)
|
||||||
|
return title
|
||||||
|
|
||||||
|
|
||||||
|
def generate_examples():
|
||||||
|
root_dir = Path(__file__).parent.parent.parent.resolve()
|
||||||
|
|
||||||
|
# Source paths
|
||||||
|
script_dir = root_dir / "examples"
|
||||||
|
script_paths = sorted(script_dir.glob("*.py"))
|
||||||
|
|
||||||
|
# Destination paths
|
||||||
|
doc_dir = root_dir / "docs/source/getting_started/examples"
|
||||||
|
doc_paths = [doc_dir / f"{path.stem}.rst" for path in script_paths]
|
||||||
|
|
||||||
|
# Generate the example docs for each example script
|
||||||
|
for script_path, doc_path in zip(script_paths, doc_paths):
|
||||||
|
script_url = f"https://github.com/vllm-project/vllm/blob/main/examples/{script_path.name}"
|
||||||
|
# Make script_path relative to doc_path and call it include_path
|
||||||
|
include_path = '../../../..' / script_path.relative_to(root_dir)
|
||||||
|
content = (f"{generate_title(doc_path.stem)}\n\n"
|
||||||
|
f"Source {script_url}.\n\n"
|
||||||
|
f".. literalinclude:: {include_path}\n"
|
||||||
|
" :language: python\n"
|
||||||
|
" :linenos:\n")
|
||||||
|
with open(doc_path, "w+") as f:
|
||||||
|
f.write(content)
|
||||||
|
|
||||||
|
# Generate the toctree for the example scripts
|
||||||
|
with open(doc_dir / "examples_index.template.rst") as f:
|
||||||
|
examples_index = f.read()
|
||||||
|
with open(doc_dir / "examples_index.rst", "w+") as f:
|
||||||
|
example_docs = "\n ".join(path.stem for path in script_paths)
|
||||||
|
f.write(examples_index.replace(r"%EXAMPLE_DOCS%", example_docs))
|
||||||
87
docs/source/getting_started/cpu-installation.rst
Normal file
87
docs/source/getting_started/cpu-installation.rst
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
.. _installation_cpu:
|
||||||
|
|
||||||
|
Installation with CPU
|
||||||
|
========================
|
||||||
|
|
||||||
|
vLLM initially supports basic model inferencing and serving on the x86 CPU platform, with data types FP32 and BF16.
|
||||||
|
|
||||||
|
Table of contents:
|
||||||
|
|
||||||
|
#. :ref:`Requirements <cpu_backend_requirements>`
|
||||||
|
#. :ref:`Quick start using Dockerfile <cpu_backend_quick_start_dockerfile>`
|
||||||
|
#. :ref:`Build from source <build_cpu_backend_from_source>`
|
||||||
|
#. :ref:`Performance tips <cpu_backend_performance_tips>`
|
||||||
|
|
||||||
|
.. _cpu_backend_requirements:
|
||||||
|
|
||||||
|
Requirements
|
||||||
|
------------
|
||||||
|
|
||||||
|
* OS: Linux
|
||||||
|
* Compiler: gcc/g++>=12.3.0 (recommended)
|
||||||
|
* Instruction set architecture (ISA) requirement: AVX512 is required.
|
||||||
|
|
||||||
|
.. _cpu_backend_quick_start_dockerfile:
|
||||||
|
|
||||||
|
Quick start using Dockerfile
|
||||||
|
----------------------------
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ docker build -f Dockerfile.cpu -t vllm-cpu-env --shm-size=4g .
|
||||||
|
$ docker run -it \
|
||||||
|
--rm \
|
||||||
|
--network=host \
|
||||||
|
--cpuset-cpus=<cpu-id-list, optional> \
|
||||||
|
--cpuset-mems=<memory-node, optional> \
|
||||||
|
vllm-cpu-env
|
||||||
|
|
||||||
|
.. _build_cpu_backend_from_source:
|
||||||
|
|
||||||
|
Build from source
|
||||||
|
-----------------
|
||||||
|
|
||||||
|
- First, install the required compiler. We recommend using ``gcc/g++ >= 12.3.0`` as the default compiler to avoid potential problems. For example, on Ubuntu 22.04, you can run:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ sudo apt-get update -y
|
||||||
|
$ sudo apt-get install -y gcc-12 g++-12
|
||||||
|
$ sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
|
||||||
|
|
||||||
|
- Second, install Python packages for vLLM CPU backend building:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ pip install --upgrade pip
|
||||||
|
$ pip install wheel packaging ninja setuptools>=49.4.0 numpy
|
||||||
|
$ pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
|
||||||
|
|
||||||
|
- Finally, build and install vLLM CPU backend:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ VLLM_TARGET_DEVICE=cpu python setup.py install
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
- BF16 is the default data type in the current CPU backend (that means the backend will cast FP16 to BF16), and is compatible with all CPUs with AVX512 ISA support.
|
||||||
|
|
||||||
|
- AVX512_BF16 is an ISA extension that provides native BF16 data type conversion and vector product instructions, which brings some performance improvement compared with pure AVX512. The CPU backend build script will check the host CPU flags to determine whether to enable AVX512_BF16.
|
||||||
|
|
||||||
|
- If you want to force-enable AVX512_BF16 for cross-compilation, please set the environment variable VLLM_CPU_AVX512BF16=1 before building.
|
||||||
|
|
||||||
|
.. _cpu_backend_performance_tips:
|
||||||
|
|
||||||
|
Performance tips
|
||||||
|
-----------------
|
||||||
|
|
||||||
|
- vLLM CPU backend uses the environment variable ``VLLM_CPU_KVCACHE_SPACE`` to specify the KV Cache size (e.g., ``VLLM_CPU_KVCACHE_SPACE=40`` means 40 GB space for KV cache); a larger setting allows vLLM to run more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users.
|
||||||
|
|
||||||
|
- vLLM CPU backend uses OpenMP for thread-parallel computation. If you want the best performance on CPU, it is critical to isolate the CPU cores used by OpenMP threads from other thread pools (like a web-service event loop), to avoid CPU oversubscription.
|
||||||
|
|
||||||
|
- If using vLLM CPU backend on a bare-metal machine, it is recommended to disable hyper-threading.
|
||||||
|
|
||||||
|
- If using vLLM CPU backend on a multi-socket machine with NUMA, make sure to bind CPU cores and memory nodes to avoid remote memory node access. ``numactl`` is a useful tool for CPU core and memory binding on NUMA platforms. Besides, the ``--cpuset-cpus`` and ``--cpuset-mems`` arguments of ``docker run`` are also useful.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
Examples
|
||||||
|
=================================
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:maxdepth: 1
|
||||||
|
:caption: Scripts
|
||||||
|
|
||||||
|
%EXAMPLE_DOCS%
|
||||||
@@ -19,7 +19,7 @@ You can install vLLM using pip:
|
|||||||
|
|
||||||
.. code-block:: console
|
.. code-block:: console
|
||||||
|
|
||||||
$ # (Optional) Create a new conda environment.
|
$ # (Recommended) Create a new conda environment.
|
||||||
$ conda create -n myenv python=3.9 -y
|
$ conda create -n myenv python=3.9 -y
|
||||||
$ conda activate myenv
|
$ conda activate myenv
|
||||||
|
|
||||||
@@ -28,24 +28,19 @@ You can install vLLM using pip:
|
|||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
|
||||||
As of now, vLLM's binaries are compiled on CUDA 12.1 by default.
|
As of now, vLLM's binaries are compiled with CUDA 12.1 and public PyTorch release versions by default.
|
||||||
However, you can install vLLM with CUDA 11.8 by running:
|
We also provide vLLM binaries compiled with CUDA 11.8 and public PyTorch release versions:
|
||||||
|
|
||||||
.. code-block:: console
|
.. code-block:: console
|
||||||
|
|
||||||
$ # Install vLLM with CUDA 11.8.
|
$ # Install vLLM with CUDA 11.8.
|
||||||
$ export VLLM_VERSION=0.2.4
|
$ export VLLM_VERSION=0.4.0
|
||||||
$ export PYTHON_VERSION=39
|
$ export PYTHON_VERSION=39
|
||||||
$ pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux1_x86_64.whl
|
$ pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux1_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118
|
||||||
|
|
||||||
$ # Re-install PyTorch with CUDA 11.8.
|
In order to be performant, vLLM has to compile many CUDA kernels. The compilation unfortunately introduces binary incompatibility with other CUDA versions and PyTorch versions, even for the same PyTorch version with different build configurations.
|
||||||
$ pip uninstall torch -y
|
|
||||||
$ pip install torch --upgrade --index-url https://download.pytorch.org/whl/cu118
|
|
||||||
|
|
||||||
$ # Re-install xFormers with CUDA 11.8.
|
|
||||||
$ pip uninstall xformers -y
|
|
||||||
$ pip install --upgrade xformers --index-url https://download.pytorch.org/whl/cu118
|
|
||||||
|
|
||||||
|
Therefore, it is recommended to install vLLM in a **fresh new** conda environment. If you either have a different CUDA version or want to use an existing PyTorch installation, you need to build vLLM from source. See below for instructions.
|
||||||
|
|
||||||
.. _build_from_source:
|
.. _build_from_source:
|
||||||
|
|
||||||
@@ -77,12 +72,16 @@ You can also build and install vLLM from source:
|
|||||||
$ # Use `--ipc=host` to make sure the shared memory is large enough.
|
$ # Use `--ipc=host` to make sure the shared memory is large enough.
|
||||||
$ docker run --gpus all -it --rm --ipc=host nvcr.io/nvidia/pytorch:23.10-py3
|
$ docker run --gpus all -it --rm --ipc=host nvcr.io/nvidia/pytorch:23.10-py3
|
||||||
|
|
||||||
.. note::
|
If you don't want to use docker, it is recommended to have a full installation of CUDA Toolkit. You can download and install it from `the official website <https://developer.nvidia.com/cuda-toolkit-archive>`_. After installation, set the environment variable ``CUDA_HOME`` to the installation path of CUDA Toolkit, and make sure that the ``nvcc`` compiler is in your ``PATH``, e.g.:
|
||||||
If you are developing the C++ backend of vLLM, consider building vLLM with
|
|
||||||
|
|
||||||
.. code-block:: console
|
.. code-block:: console
|
||||||
|
|
||||||
$ python setup.py develop
|
$ export CUDA_HOME=/usr/local/cuda
|
||||||
|
$ export PATH="${CUDA_HOME}/bin:$PATH"
|
||||||
|
|
||||||
since it will give you incremental builds. The downside is that this method
|
Here is a sanity check to verify that the CUDA Toolkit is correctly installed:
|
||||||
is `deprecated by setuptools <https://github.com/pypa/setuptools/issues/917>`_.
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ nvcc --version # verify that nvcc is in your PATH
|
||||||
|
$ ${CUDA_HOME}/bin/nvcc --version # verify that nvcc is in your CUDA_HOME
|
||||||
|
|||||||
@@ -63,7 +63,9 @@ Documentation
|
|||||||
getting_started/installation
|
getting_started/installation
|
||||||
getting_started/amd-installation
|
getting_started/amd-installation
|
||||||
getting_started/neuron-installation
|
getting_started/neuron-installation
|
||||||
|
getting_started/cpu-installation
|
||||||
getting_started/quickstart
|
getting_started/quickstart
|
||||||
|
getting_started/examples/examples_index
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:maxdepth: 1
|
:maxdepth: 1
|
||||||
@@ -90,7 +92,8 @@ Documentation
|
|||||||
:caption: Quantization
|
:caption: Quantization
|
||||||
|
|
||||||
quantization/auto_awq
|
quantization/auto_awq
|
||||||
quantization/fp8_e5m2_kv_cache
|
quantization/fp8_e5m2_kvcache
|
||||||
|
quantization/fp8_e4m3_kvcache
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:maxdepth: 2
|
:maxdepth: 2
|
||||||
|
|||||||
@@ -21,6 +21,8 @@ This document provides a high-level guide on integrating a `HuggingFace Transfor
|
|||||||
Start by forking our `GitHub`_ repository and then :ref:`build it from source <build_from_source>`.
|
Start by forking our `GitHub`_ repository and then :ref:`build it from source <build_from_source>`.
|
||||||
This gives you the ability to modify the codebase and test your model.
|
This gives you the ability to modify the codebase and test your model.
|
||||||
|
|
||||||
|
.. tip::
|
||||||
|
If you don't want to fork the repository and modify vLLM's codebase, please refer to the "Out-of-Tree Model Integration" section below.
|
||||||
|
|
||||||
1. Bring your model code
|
1. Bring your model code
|
||||||
------------------------
|
------------------------
|
||||||
@@ -93,4 +95,29 @@ This method should load the weights from the HuggingFace's checkpoint file and a
|
|||||||
5. Register your model
|
5. Register your model
|
||||||
----------------------
|
----------------------
|
||||||
|
|
||||||
Finally, include your :code:`*ForCausalLM` class in `vllm/model_executor/models/__init__.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/__init__.py>`_ and register it to the :code:`_MODEL_REGISTRY` in `vllm/model_executor/model_loader.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/model_loader.py>`_.
|
Finally, register your :code:`*ForCausalLM` class to the :code:`_MODELS` in `vllm/model_executor/models/__init__.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/__init__.py>`_.
|
||||||
|
|
||||||
|
6. Out-of-Tree Model Integration
|
||||||
|
--------------------------------------------
|
||||||
|
|
||||||
|
We also provide a way to integrate a model without modifying the vLLM codebase. Steps 2, 3, and 4 are still required, but you can skip steps 1 and 5.
|
||||||
|
|
||||||
|
Just add the following lines in your code:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from vllm import ModelRegistry
|
||||||
|
from your_code import YourModelForCausalLM
|
||||||
|
ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)
|
||||||
|
|
||||||
|
If you are running the API server with ``python -m vllm.entrypoints.openai.api_server args``, you can wrap the entrypoint with the following code:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from vllm import ModelRegistry
|
||||||
|
from your_code import YourModelForCausalLM
|
||||||
|
ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)
|
||||||
|
import runpy
|
||||||
|
runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')
|
||||||
|
|
||||||
|
Save the above code in a file and run it with ``python your_file.py args``.
|
||||||
|
|||||||
@@ -5,116 +5,19 @@ Engine Arguments
|
|||||||
|
|
||||||
Below, you can find an explanation of every engine argument for vLLM:
|
Below, you can find an explanation of every engine argument for vLLM:
|
||||||
|
|
||||||
.. option:: --model <model_name_or_path>
|
.. argparse::
|
||||||
|
:module: vllm.engine.arg_utils
|
||||||
|
:func: _engine_args_parser
|
||||||
|
:prog: -m vllm.entrypoints.openai.api_server
|
||||||
|
:nodefaultconst:
|
||||||
|
|
||||||
Name or path of the huggingface model to use.
|
Async Engine Arguments
|
||||||
|
----------------------
|
||||||
|
|
||||||
.. option:: --tokenizer <tokenizer_name_or_path>
|
Below are the additional arguments related to the asynchronous engine:
|
||||||
|
|
||||||
Name or path of the huggingface tokenizer to use.
|
.. argparse::
|
||||||
|
:module: vllm.engine.arg_utils
|
||||||
.. option:: --revision <revision>
|
:func: _async_engine_args_parser
|
||||||
|
:prog: -m vllm.entrypoints.openai.api_server
|
||||||
The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.
|
:nodefaultconst:
|
||||||
|
|
||||||
.. option:: --tokenizer-revision <revision>
|
|
||||||
|
|
||||||
The specific tokenizer version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.
|
|
||||||
|
|
||||||
.. option:: --tokenizer-mode {auto,slow}
|
|
||||||
|
|
||||||
The tokenizer mode.
|
|
||||||
|
|
||||||
* "auto" will use the fast tokenizer if available.
|
|
||||||
* "slow" will always use the slow tokenizer.
|
|
||||||
|
|
||||||
.. option:: --trust-remote-code
|
|
||||||
|
|
||||||
Trust remote code from huggingface.
|
|
||||||
|
|
||||||
.. option:: --download-dir <directory>
|
|
||||||
|
|
||||||
Directory to download and load the weights, default to the default cache dir of huggingface.
|
|
||||||
|
|
||||||
.. option:: --load-format {auto,pt,safetensors,npcache,dummy}
|
|
||||||
|
|
||||||
The format of the model weights to load.
|
|
||||||
|
|
||||||
* "auto" will try to load the weights in the safetensors format and fall back to the pytorch bin format if safetensors format is not available.
|
|
||||||
* "pt" will load the weights in the pytorch bin format.
|
|
||||||
* "safetensors" will load the weights in the safetensors format.
|
|
||||||
* "npcache" will load the weights in pytorch format and store a numpy cache to speed up the loading.
|
|
||||||
* "dummy" will initialize the weights with random values, mainly for profiling.
|
|
||||||
|
|
||||||
.. option:: --dtype {auto,half,float16,bfloat16,float,float32}
|
|
||||||
|
|
||||||
Data type for model weights and activations.
|
|
||||||
|
|
||||||
* "auto" will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models.
|
|
||||||
* "half" for FP16. Recommended for AWQ quantization.
|
|
||||||
* "float16" is the same as "half".
|
|
||||||
* "bfloat16" for a balance between precision and range.
|
|
||||||
* "float" is shorthand for FP32 precision.
|
|
||||||
* "float32" for FP32 precision.
|
|
||||||
|
|
||||||
.. option:: --max-model-len <length>
|
|
||||||
|
|
||||||
Model context length. If unspecified, will be automatically derived from the model config.
|
|
||||||
|
|
||||||
.. option:: --worker-use-ray
|
|
||||||
|
|
||||||
Use Ray for distributed serving, will be automatically set when using more than 1 GPU.
|
|
||||||
|
|
||||||
.. option:: --pipeline-parallel-size (-pp) <size>
|
|
||||||
|
|
||||||
Number of pipeline stages.
|
|
||||||
|
|
||||||
.. option:: --tensor-parallel-size (-tp) <size>
|
|
||||||
|
|
||||||
Number of tensor parallel replicas.
|
|
||||||
|
|
||||||
.. option:: --max-parallel-loading-workers <workers>
|
|
||||||
|
|
||||||
Load model sequentially in multiple batches, to avoid RAM OOM when using tensor parallel and large models.
|
|
||||||
|
|
||||||
.. option:: --block-size {8,16,32}
|
|
||||||
|
|
||||||
Token block size for contiguous chunks of tokens.
|
|
||||||
|
|
||||||
.. option:: --enable-prefix-caching
|
|
||||||
|
|
||||||
Enables automatic prefix caching
|
|
||||||
|
|
||||||
.. option:: --seed <seed>
|
|
||||||
|
|
||||||
Random seed for operations.
|
|
||||||
|
|
||||||
.. option:: --swap-space <size>
|
|
||||||
|
|
||||||
CPU swap space size (GiB) per GPU.
|
|
||||||
|
|
||||||
.. option:: --gpu-memory-utilization <fraction>
|
|
||||||
|
|
||||||
The fraction of GPU memory to be used for the model executor, which can range from 0 to 1.
|
|
||||||
For example, a value of 0.5 would imply 50% GPU memory utilization.
|
|
||||||
If unspecified, will use the default value of 0.9.
|
|
||||||
|
|
||||||
.. option:: --max-num-batched-tokens <tokens>
|
|
||||||
|
|
||||||
Maximum number of batched tokens per iteration.
|
|
||||||
|
|
||||||
.. option:: --max-num-seqs <sequences>
|
|
||||||
|
|
||||||
Maximum number of sequences per iteration.
|
|
||||||
|
|
||||||
.. option:: --max-paddings <paddings>
|
|
||||||
|
|
||||||
Maximum number of paddings in a batch.
|
|
||||||
|
|
||||||
.. option:: --disable-log-stats
|
|
||||||
|
|
||||||
Disable logging statistics.
|
|
||||||
|
|
||||||
.. option:: --quantization (-q) {awq,squeezellm,None}
|
|
||||||
|
|
||||||
Method used to quantize the weights.
|
|
||||||
@@ -80,16 +80,20 @@ Alongside each architecture, we include some popular models that use it.
|
|||||||
- :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc.
|
- :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc.
|
||||||
-
|
-
|
||||||
* - :code:`LlamaForCausalLM`
|
* - :code:`LlamaForCausalLM`
|
||||||
- LLaMA, LLaMA-2, Vicuna, Alpaca, Yi
|
- LLaMA, Llama 2, Meta Llama 3, Vicuna, Alpaca, Yi
|
||||||
- :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
|
- :code:`meta-llama/Meta-Llama-3-8B-Instruct`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
|
||||||
- ✅︎
|
- ✅︎
|
||||||
|
* - :code:`MiniCPMForCausalLM`
|
||||||
|
- MiniCPM
|
||||||
|
- :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc.
|
||||||
|
-
|
||||||
* - :code:`MistralForCausalLM`
|
* - :code:`MistralForCausalLM`
|
||||||
- Mistral, Mistral-Instruct
|
- Mistral, Mistral-Instruct
|
||||||
- :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc.
|
- :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc.
|
||||||
- ✅︎
|
- ✅︎
|
||||||
* - :code:`MixtralForCausalLM`
|
* - :code:`MixtralForCausalLM`
|
||||||
- Mixtral-8x7B, Mixtral-8x7B-Instruct
|
- Mixtral-8x7B, Mixtral-8x7B-Instruct
|
||||||
- :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.
|
- :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, :code:`mistral-community/Mixtral-8x22B-v0.1`, etc.
|
||||||
- ✅︎
|
- ✅︎
|
||||||
* - :code:`MPTForCausalLM`
|
* - :code:`MPTForCausalLM`
|
||||||
- MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
|
- MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
|
||||||
@@ -164,3 +168,29 @@ Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-pr
|
|||||||
llm = LLM(model=..., revision=..., trust_remote_code=True) # Name or path of your model
|
llm = LLM(model=..., revision=..., trust_remote_code=True) # Name or path of your model
|
||||||
output = llm.generate("Hello, my name is")
|
output = llm.generate("Hello, my name is")
|
||||||
print(output)
|
print(output)
|
||||||
|
|
||||||
|
Model Support Policy
|
||||||
|
---------------------
|
||||||
|
|
||||||
|
At vLLM, we are committed to facilitating the integration and support of third-party models within our ecosystem. Our approach is designed to balance the need for robustness and the practical limitations of supporting a wide range of models. Here’s how we manage third-party model support:
|
||||||
|
|
||||||
|
1. **Community-Driven Support**: We encourage community contributions for adding new models. When a user requests support for a new model, we welcome pull requests (PRs) from the community. These contributions are evaluated primarily on the sensibility of the output they generate, rather than strict consistency with existing implementations such as those in transformers. **Call for contribution:** PRs coming directly from model vendors are greatly appreciated!
|
||||||
|
|
||||||
|
2. **Best-Effort Consistency**: While we aim to maintain a level of consistency between the models implemented in vLLM and other frameworks like transformers, complete alignment is not always feasible. Factors like acceleration techniques and the use of low-precision computations can introduce discrepancies. Our commitment is to ensure that the implemented models are functional and produce sensible results.
|
||||||
|
|
||||||
|
3. **Issue Resolution and Model Updates**: Users are encouraged to report any bugs or issues they encounter with third-party models. Proposed fixes should be submitted via PRs, with a clear explanation of the problem and the rationale behind the proposed solution. If a fix for one model impacts another, we rely on the community to highlight and address these cross-model dependencies. Note: for bugfix PRs, it is good etiquette to inform the original author to seek their feedback.
|
||||||
|
|
||||||
|
4. **Monitoring and Updates**: Users interested in specific models should monitor the commit history for those models (e.g., by tracking changes in the main/vllm/model_executor/models directory). This proactive approach helps users stay informed about updates and changes that may affect the models they use.
|
||||||
|
|
||||||
|
5. **Selective Focus**: Our resources are primarily directed towards models with significant user interest and impact. Models that are less frequently used may receive less attention, and we rely on the community to play a more active role in their upkeep and improvement.
|
||||||
|
|
||||||
|
Through this approach, vLLM fosters a collaborative environment where both the core development team and the broader community contribute to the robustness and diversity of the third-party models supported in our ecosystem.
|
||||||
|
|
||||||
|
Note that, as an inference engine, vLLM does not introduce new models. Therefore, all models supported by vLLM are third-party models in this regard.
|
||||||
|
|
||||||
|
We have the following levels of testing for models:
|
||||||
|
|
||||||
|
1. **Strict Consistency**: We compare the output of the model with the output of the model in the HuggingFace Transformers library under greedy decoding. This is the most stringent test. Please refer to `test_models.py <https://github.com/vllm-project/vllm/blob/main/tests/models/test_models.py>`_ and `test_big_models.py <https://github.com/vllm-project/vllm/blob/main/tests/models/test_big_models.py>`_ for the models that have passed this test.
|
||||||
|
2. **Output Sensibility**: We check if the output of the model is sensible and coherent, by measuring the perplexity of the output and checking for any obvious errors. This is a less stringent test.
|
||||||
|
3. **Runtime Functionality**: We check if the model can be loaded and run without errors. This is the least stringent test. Please refer to `functionality tests <https://github.com/vllm-project/vllm/tree/main/tests>`_ and `examples <https://github.com/vllm-project/vllm/tree/main/examples>`_ for the models that have passed this test.
|
||||||
|
4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category.
|
||||||
|
|||||||
49
docs/source/quantization/fp8_e4m3_kvcache.rst
Normal file
49
docs/source/quantization/fp8_e4m3_kvcache.rst
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
.. _fp8_e4m3_kvcache:
|
||||||
|
|
||||||
|
FP8 E4M3 KV Cache
|
||||||
|
==================
|
||||||
|
|
||||||
|
Quantizing the KV cache to FP8 reduces its memory footprint. This increases the number of tokens that can be stored in the cache,
|
||||||
|
improving throughput. OCP (Open Compute Project www.opencompute.org) specifies two common 8-bit floating point data formats: E5M2
|
||||||
|
(5 exponent bits and 2 mantissa bits) and E4M3FN (4 exponent bits and 3 mantissa bits), often shortened as E4M3. One benefit of
|
||||||
|
the E4M3 format over E5M2 is that floating point numbers are represented in higher precision. However, the small dynamic range of
|
||||||
|
FP8 E4M3 (±240.0 can be represented) typically necessitates the use of a higher-precision (typically FP32) scaling factor alongside
|
||||||
|
each quantized tensor. For now, only per-tensor (scalar) scaling factors are supported. Development is ongoing to support scaling
|
||||||
|
factors of a finer granularity (e.g. per-channel).
|
||||||
|
|
||||||
|
These scaling factors can be specified by passing an optional quantization param JSON to the LLM engine at load time. If
|
||||||
|
this JSON is not specified, scaling factors default to 1.0. These scaling factors are typically obtained when running an
|
||||||
|
unquantized model through a quantizer tool (e.g. AMD quantizer or NVIDIA AMMO).
|
||||||
|
|
||||||
|
To install AMMO (AlgorithMic Model Optimization):
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-ammo
|
||||||
|
|
||||||
|
Studies have shown that FP8 E4M3 quantization typically only minimally degrades inference accuracy. The most recent silicon
|
||||||
|
offerings, e.g. AMD MI300 and NVIDIA Hopper or later, support native hardware conversion to and from fp32, fp16, bf16, etc.
|
||||||
|
Thus, LLM inference is greatly accelerated with minimal accuracy loss.
|
||||||
|
|
||||||
|
|
||||||
|
Here is an example of how to enable this feature:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
# two float8_e4m3fn kv cache scaling factor files are provided under tests/fp8_kv, please refer to
|
||||||
|
# https://github.com/vllm-project/vllm/blob/main/examples/fp8/README.md to generate kv_cache_scales.json of your own.
|
||||||
|
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
sampling_params = SamplingParams(temperature=1.3, top_p=0.8)
|
||||||
|
llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
|
||||||
|
kv_cache_dtype="fp8",
|
||||||
|
quantization_param_path="./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json")
|
||||||
|
prompt = "London is the capital of"
|
||||||
|
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
|
||||||
|
print(out)
|
||||||
|
|
||||||
|
# output w/ scaling factors: England, the United Kingdom, and one of the world's leading financial,
|
||||||
|
# output w/o scaling factors: England, located in the southeastern part of the country. It is known
|
||||||
|
|
||||||
|
Note: prefix caching currently does not work when the FP8 KV cache is enabled; the forward_prefix kernel would need to handle differing KV and cache types.
|
||||||
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
.. _fp8_e5m2_kv_cache:
|
.. _fp8_kv_cache:
|
||||||
|
|
||||||
FP8 E5M2 KV Cache
|
FP8 E5M2 KV Cache
|
||||||
==================
|
==================
|
||||||
@@ -21,7 +21,7 @@ Here is an example of how to enable this feature:
|
|||||||
# Create a sampling params object.
|
# Create a sampling params object.
|
||||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||||
# Create an LLM.
|
# Create an LLM.
|
||||||
llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8_e5m2")
|
llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8")
|
||||||
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
||||||
# that contain the prompt, generated text, and other information.
|
# that contain the prompt, generated text, and other information.
|
||||||
outputs = llm.generate(prompts, sampling_params)
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
@@ -31,3 +31,6 @@ Here is an example of how to enable this feature:
|
|||||||
generated_text = output.outputs[0].text
|
generated_text = output.outputs[0].text
|
||||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
|
|
||||||
|
|
||||||
|
Note: prefix caching currently does not work when the FP8 KV cache is enabled; the forward_prefix kernel would need to handle differing KV and cache types.
|
||||||
|
|
||||||
@@ -4,7 +4,7 @@ vLLM provides an HTTP server that implements OpenAI's [Completions](https://plat
|
|||||||
|
|
||||||
You can start the server using Python, or using [Docker](deploying_with_docker.rst):
|
You can start the server using Python, or using [Docker](deploying_with_docker.rst):
|
||||||
```bash
|
```bash
|
||||||
python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-hf --dtype float32 --api-key token-abc123
|
python -m vllm.entrypoints.openai.api_server --model mistralai/Mistral-7B-Instruct-v0.2 --dtype auto --api-key token-abc123
|
||||||
```
|
```
|
||||||
|
|
||||||
To call the server, you can use the official OpenAI Python client library, or any other HTTP client.
|
To call the server, you can use the official OpenAI Python client library, or any other HTTP client.
|
||||||
@@ -16,9 +16,8 @@ client = OpenAI(
|
|||||||
)
|
)
|
||||||
|
|
||||||
completion = client.chat.completions.create(
|
completion = client.chat.completions.create(
|
||||||
model="meta-llama/Llama-2-7b-hf",
|
model="mistralai/Mistral-7B-Instruct-v0.2",
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
|
||||||
{"role": "user", "content": "Hello!"}
|
{"role": "user", "content": "Hello!"}
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -38,9 +37,8 @@ Or directly merge them into the JSON payload if you are using HTTP call directly
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
completion = client.chat.completions.create(
|
completion = client.chat.completions.create(
|
||||||
model="meta-llama/Llama-2-7b-hf",
|
model="mistralai/Mistral-7B-Instruct-v0.2",
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
|
||||||
{"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"}
|
{"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"}
|
||||||
],
|
],
|
||||||
extra_body={
|
extra_body={
|
||||||
@@ -89,7 +87,7 @@ In order for the language model to support chat protocol, vLLM requires the mode
|
|||||||
a chat template in its tokenizer configuration. The chat template is a Jinja2 template that
|
a chat template in its tokenizer configuration. The chat template is a Jinja2 template that
|
||||||
specifies how roles, messages, and other chat-specific tokens are encoded in the input.
|
specifies how roles, messages, and other chat-specific tokens are encoded in the input.
|
||||||
|
|
||||||
An example chat template for `meta-llama/Llama-2-7b-chat-hf` can be found [here](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/09bd0f49e16738cdfaa6e615203e126038736eb0/tokenizer_config.json#L12)
|
An example chat template for `mistralai/Mistral-7B-Instruct-v0.2` can be found [here](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2#instruction-format)
|
||||||
|
|
||||||
Some models do not provide a chat template even though they are instruction/chat fine-tuned. For those models,
|
Some models do not provide a chat template even though they are instruction/chat fine-tuned. For those models,
|
||||||
you can manually specify their chat template in the `--chat-template` parameter with the file path to the chat
|
you can manually specify their chat template in the `--chat-template` parameter with the file path to the chat
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
.. _on_cloud:
|
.. _on_cloud:
|
||||||
|
|
||||||
Running on clouds with SkyPilot
|
Deploying and scaling up with SkyPilot
|
||||||
===============================
|
================================================
|
||||||
|
|
||||||
.. raw:: html
|
.. raw:: html
|
||||||
|
|
||||||
@@ -9,51 +9,75 @@ Running on clouds with SkyPilot
|
|||||||
<img src="https://imgur.com/yxtzPEu.png" alt="vLLM"/>
|
<img src="https://imgur.com/yxtzPEu.png" alt="vLLM"/>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
vLLM can be run on the cloud to scale to multiple GPUs with `SkyPilot <https://github.com/skypilot-org/skypilot>`__, an open-source framework for running LLMs on any cloud.
|
vLLM can be **run and scaled to multiple service replicas on clouds and Kubernetes** with `SkyPilot <https://github.com/skypilot-org/skypilot>`__, an open-source framework for running LLMs on any cloud. More examples for various open models, such as Llama-3, Mixtral, etc, can be found in `SkyPilot AI gallery <https://skypilot.readthedocs.io/en/latest/gallery/index.html>`__.
|
||||||
|
|
||||||
To install SkyPilot and setup your cloud credentials, run:
|
|
||||||
|
Prerequisites
|
||||||
|
-------------
|
||||||
|
|
||||||
|
- Go to the `HuggingFace model page <https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct>`__ and request access to the model :code:`meta-llama/Meta-Llama-3-8B-Instruct`.
|
||||||
|
- Check that you have installed SkyPilot (`docs <https://skypilot.readthedocs.io/en/latest/getting-started/installation.html>`__).
|
||||||
|
- Check that :code:`sky check` shows clouds or Kubernetes are enabled.
|
||||||
|
|
||||||
.. code-block:: console
|
.. code-block:: console
|
||||||
|
|
||||||
$ pip install skypilot
|
pip install skypilot-nightly
|
||||||
$ sky check
|
sky check
|
||||||
|
|
||||||
|
|
||||||
|
Run on a single instance
|
||||||
|
------------------------
|
||||||
|
|
||||||
See the vLLM SkyPilot YAML for serving, `serving.yaml <https://github.com/skypilot-org/skypilot/blob/master/llm/vllm/serve.yaml>`__.
|
See the vLLM SkyPilot YAML for serving, `serving.yaml <https://github.com/skypilot-org/skypilot/blob/master/llm/vllm/serve.yaml>`__.
|
||||||
|
|
||||||
.. code-block:: yaml
|
.. code-block:: yaml
|
||||||
|
|
||||||
resources:
|
resources:
|
||||||
accelerators: A100
|
accelerators: {L4, A10g, A10, L40, A40, A100, A100-80GB} # We can use cheaper accelerators for 8B model.
|
||||||
|
use_spot: True
|
||||||
|
disk_size: 512 # Ensure model checkpoints can fit.
|
||||||
|
disk_tier: best
|
||||||
|
ports: 8081 # Expose to internet traffic.
|
||||||
|
|
||||||
envs:
|
envs:
|
||||||
MODEL_NAME: decapoda-research/llama-13b-hf
|
MODEL_NAME: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
TOKENIZER: hf-internal-testing/llama-tokenizer
|
HF_TOKEN: <your-huggingface-token> # Change to your own huggingface token, or use --env to pass.
|
||||||
|
|
||||||
setup: |
|
setup: |
|
||||||
conda create -n vllm python=3.9 -y
|
conda create -n vllm python=3.10 -y
|
||||||
conda activate vllm
|
conda activate vllm
|
||||||
git clone https://github.com/vllm-project/vllm.git
|
|
||||||
cd vllm
|
pip install vllm==0.4.0.post1
|
||||||
pip install .
|
# Install Gradio for web UI.
|
||||||
pip install gradio
|
pip install gradio openai
|
||||||
|
pip install flash-attn==2.5.7
|
||||||
|
|
||||||
run: |
|
run: |
|
||||||
conda activate vllm
|
conda activate vllm
|
||||||
echo 'Starting vllm api server...'
|
echo 'Starting vllm api server...'
|
||||||
python -u -m vllm.entrypoints.api_server \
|
python -u -m vllm.entrypoints.openai.api_server \
|
||||||
--model $MODEL_NAME \
|
--port 8081 \
|
||||||
--tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE \
|
--model $MODEL_NAME \
|
||||||
--tokenizer $TOKENIZER 2>&1 | tee api_server.log &
|
--trust-remote-code \
|
||||||
|
--tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE \
|
||||||
|
2>&1 | tee api_server.log &
|
||||||
|
|
||||||
echo 'Waiting for vllm api server to start...'
|
echo 'Waiting for vllm api server to start...'
|
||||||
while ! `cat api_server.log | grep -q 'Uvicorn running on'`; do sleep 1; done
|
while ! `cat api_server.log | grep -q 'Uvicorn running on'`; do sleep 1; done
|
||||||
echo 'Starting gradio server...'
|
|
||||||
python vllm/examples/gradio_webserver.py
|
|
||||||
|
|
||||||
Start the serving the LLaMA-13B model on an A100 GPU:
|
echo 'Starting gradio server...'
|
||||||
|
git clone https://github.com/vllm-project/vllm.git || true
|
||||||
|
python vllm/examples/gradio_openai_chatbot_webserver.py \
|
||||||
|
-m $MODEL_NAME \
|
||||||
|
--port 8811 \
|
||||||
|
--model-url http://localhost:8081/v1 \
|
||||||
|
--stop-token-ids 128009,128001
|
||||||
|
|
||||||
|
Start serving the Llama-3 8B model on any of the candidate GPUs listed (L4, A10g, ...):
|
||||||
|
|
||||||
.. code-block:: console
|
.. code-block:: console
|
||||||
|
|
||||||
$ sky launch serving.yaml
|
HF_TOKEN="your-huggingface-token" sky launch serving.yaml --env HF_TOKEN
|
||||||
|
|
||||||
Check the output of the command. There will be a shareable gradio link (like the last line of the following). Open it in your browser to use the LLaMA model to do the text completion.
|
Check the output of the command. There will be a shareable gradio link (like the last line of the following). Open it in your browser to use the LLaMA model to do the text completion.
|
||||||
|
|
||||||
@@ -61,9 +85,226 @@ Check the output of the command. There will be a shareable gradio link (like the
|
|||||||
|
|
||||||
(task, pid=7431) Running on public URL: https://<gradio-hash>.gradio.live
|
(task, pid=7431) Running on public URL: https://<gradio-hash>.gradio.live
|
||||||
|
|
||||||
**Optional**: Serve the 65B model instead of the default 13B and use more GPU:
|
**Optional**: Serve the 70B model instead of the default 8B and use more GPU:
|
||||||
|
|
||||||
.. code-block:: console
|
.. code-block:: console
|
||||||
|
|
||||||
sky launch -c vllm-serve-new -s serve.yaml --gpus A100:8 --env MODEL_NAME=decapoda-research/llama-65b-hf
|
HF_TOKEN="your-huggingface-token" sky launch serving.yaml --gpus A100:8 --env HF_TOKEN --env MODEL_NAME=meta-llama/Meta-Llama-3-70B-Instruct
|
||||||
|
|
||||||
|
|
||||||
|
Scale up to multiple replicas
|
||||||
|
-----------------------------
|
||||||
|
|
||||||
|
SkyPilot can scale up the service to multiple service replicas with built-in autoscaling, load-balancing and fault-tolerance. You can do it by adding a ``service`` section to the YAML file.
|
||||||
|
|
||||||
|
.. code-block:: yaml
|
||||||
|
|
||||||
|
service:
|
||||||
|
replicas: 2
|
||||||
|
# An actual request for readiness probe.
|
||||||
|
readiness_probe:
|
||||||
|
path: /v1/chat/completions
|
||||||
|
post_data:
|
||||||
|
model: $MODEL_NAME
|
||||||
|
messages:
|
||||||
|
- role: user
|
||||||
|
content: Hello! What is your name?
|
||||||
|
max_tokens: 1
|
||||||
|
|
||||||
|
.. raw:: html
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Click to see the full recipe YAML</summary>
|
||||||
|
|
||||||
|
|
||||||
|
.. code-block:: yaml
|
||||||
|
|
||||||
|
service:
|
||||||
|
replicas: 2
|
||||||
|
# An actual request for readiness probe.
|
||||||
|
readiness_probe:
|
||||||
|
path: /v1/chat/completions
|
||||||
|
post_data:
|
||||||
|
model: $MODEL_NAME
|
||||||
|
messages:
|
||||||
|
- role: user
|
||||||
|
content: Hello! What is your name?
|
||||||
|
max_tokens: 1
|
||||||
|
|
||||||
|
resources:
|
||||||
|
accelerators: {L4, A10g, A10, L40, A40, A100, A100-80GB} # We can use cheaper accelerators for 8B model.
|
||||||
|
use_spot: True
|
||||||
|
disk_size: 512 # Ensure model checkpoints can fit.
|
||||||
|
disk_tier: best
|
||||||
|
ports: 8081 # Expose to internet traffic.
|
||||||
|
|
||||||
|
envs:
|
||||||
|
MODEL_NAME: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
|
HF_TOKEN: <your-huggingface-token> # Change to your own huggingface token, or use --env to pass.
|
||||||
|
|
||||||
|
setup: |
|
||||||
|
conda create -n vllm python=3.10 -y
|
||||||
|
conda activate vllm
|
||||||
|
|
||||||
|
pip install vllm==0.4.0.post1
|
||||||
|
# Install Gradio for web UI.
|
||||||
|
pip install gradio openai
|
||||||
|
pip install flash-attn==2.5.7
|
||||||
|
|
||||||
|
run: |
|
||||||
|
conda activate vllm
|
||||||
|
echo 'Starting vllm api server...'
|
||||||
|
python -u -m vllm.entrypoints.openai.api_server \
|
||||||
|
--port 8081 \
|
||||||
|
--model $MODEL_NAME \
|
||||||
|
--trust-remote-code \
|
||||||
|
--tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE \
|
||||||
|
2>&1 | tee api_server.log &
|
||||||
|
|
||||||
|
echo 'Waiting for vllm api server to start...'
|
||||||
|
while ! `cat api_server.log | grep -q 'Uvicorn running on'`; do sleep 1; done
|
||||||
|
|
||||||
|
echo 'Starting gradio server...'
|
||||||
|
git clone https://github.com/vllm-project/vllm.git || true
|
||||||
|
python vllm/examples/gradio_openai_chatbot_webserver.py \
|
||||||
|
-m $MODEL_NAME \
|
||||||
|
--port 8811 \
|
||||||
|
--model-url http://localhost:8081/v1 \
|
||||||
|
--stop-token-ids 128009,128001
|
||||||
|
|
||||||
|
.. raw:: html
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
Start serving the Llama-3 8B model on multiple replicas:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
HF_TOKEN="your-huggingface-token" sky serve up -n vllm serving.yaml --env HF_TOKEN
|
||||||
|
|
||||||
|
|
||||||
|
Wait until the service is ready:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
watch -n10 sky serve status vllm
|
||||||
|
|
||||||
|
|
||||||
|
.. raw:: html
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Example outputs:</summary>
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
Services
|
||||||
|
NAME VERSION UPTIME STATUS REPLICAS ENDPOINT
|
||||||
|
vllm 1 35s READY 2/2 xx.yy.zz.100:30001
|
||||||
|
|
||||||
|
Service Replicas
|
||||||
|
SERVICE_NAME ID VERSION IP LAUNCHED RESOURCES STATUS REGION
|
||||||
|
vllm 1 1 xx.yy.zz.121 18 mins ago 1x GCP({'L4': 1}) READY us-east4
|
||||||
|
vllm 2 1 xx.yy.zz.245 18 mins ago 1x GCP({'L4': 1}) READY us-east4
|
||||||
|
|
||||||
|
.. raw:: html
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
After the service is READY, you can find a single endpoint for the service and access the service with the endpoint:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
ENDPOINT=$(sky serve status --endpoint 8081 vllm)
|
||||||
|
curl -L http://$ENDPOINT/v1/chat/completions \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "meta-llama/Meta-Llama-3-8B-Instruct",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a helpful assistant."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Who are you?"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"stop_token_ids": [128009, 128001]
|
||||||
|
}'
|
||||||
|
|
||||||
|
To enable autoscaling, you could specify additional configs in `services`:
|
||||||
|
|
||||||
|
.. code-block:: yaml
|
||||||
|
|
||||||
|
services:
|
||||||
|
replica_policy:
|
||||||
|
min_replicas: 0
|
||||||
|
max_replicas: 3
|
||||||
|
target_qps_per_replica: 2
|
||||||
|
|
||||||
|
This will scale the service up when the QPS exceeds 2 for each replica.
|
||||||
|
|
||||||
|
|
||||||
|
**Optional**: Connect a GUI to the endpoint
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
|
||||||
|
It is also possible to access the Llama-3 service with a separate GUI frontend, so that user requests sent to the GUI are load-balanced across replicas.
|
||||||
|
|
||||||
|
.. raw:: html
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Click to see the full GUI YAML</summary>
|
||||||
|
|
||||||
|
.. code-block:: yaml
|
||||||
|
|
||||||
|
envs:
|
||||||
|
MODEL_NAME: meta-llama/Meta-Llama-3-70B-Instruct
|
||||||
|
ENDPOINT: x.x.x.x:3031 # Address of the API server running vllm.
|
||||||
|
|
||||||
|
resources:
|
||||||
|
cpus: 2
|
||||||
|
|
||||||
|
setup: |
|
||||||
|
conda activate vllm
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
conda create -n vllm python=3.10 -y
|
||||||
|
conda activate vllm
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Install Gradio for web UI.
|
||||||
|
pip install gradio openai
|
||||||
|
|
||||||
|
run: |
|
||||||
|
conda activate vllm
|
||||||
|
export PATH=$PATH:/sbin
|
||||||
|
WORKER_IP=$(hostname -I | cut -d' ' -f1)
|
||||||
|
CONTROLLER_PORT=21001
|
||||||
|
WORKER_PORT=21002
|
||||||
|
|
||||||
|
echo 'Starting gradio server...'
|
||||||
|
git clone https://github.com/vllm-project/vllm.git || true
|
||||||
|
python vllm/examples/gradio_openai_chatbot_webserver.py \
|
||||||
|
-m $MODEL_NAME \
|
||||||
|
--port 8811 \
|
||||||
|
--model-url http://$ENDPOINT/v1 \
|
||||||
|
--stop-token-ids 128009,128001 | tee ~/gradio.log
|
||||||
|
|
||||||
|
.. raw:: html
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
1. Start the chat web UI:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
sky launch -c gui ./gui.yaml --env ENDPOINT=$(sky serve status --endpoint vllm)
|
||||||
|
|
||||||
|
|
||||||
|
2. Then, we can access the GUI at the returned gradio link:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
| INFO | stdout | Running on public URL: https://6141e84201ce0bb4ed.gradio.live
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
46
examples/aqlm_example.py
Normal file
46
examples/aqlm_example.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
import argparse
|
||||||
|
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run a short greedy generation with an AQLM-quantized model.

    The model is taken from ``--model`` when given; otherwise ``--choice``
    indexes into the list of known-good checkpoints below. The generated
    continuation of a fixed prompt is printed to stdout.
    """
    # Known-good AQLM checkpoints, selectable by index through --choice.
    known_models = [
        "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf",
        "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-2x8-hf",
        "ISTA-DASLab/Llama-2-13b-AQLM-2Bit-1x16-hf",
        "ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf",
        "BlackSamorez/TinyLlama-1_1B-Chat-v1_0-AQLM-2Bit-1x16-hf",
    ]

    cli = argparse.ArgumentParser(description='AQLM examples')
    cli.add_argument('--model',
                     '-m',
                     type=str,
                     default=None,
                     help='model path, as for HF')
    cli.add_argument('--choice',
                     '-c',
                     type=int,
                     default=0,
                     help='known good models by index, [0-4]')
    cli.add_argument('--tensor_parallel_size',
                     '-t',
                     type=int,
                     default=1,
                     help='tensor parallel size')
    options = cli.parse_args()

    # An explicit --model always wins over the indexed default.
    if options.model is not None:
        model_path = options.model
    else:
        model_path = known_models[options.choice]

    llm = LLM(model_path, tensor_parallel_size=options.tensor_parallel_size)

    greedy_params = SamplingParams(max_tokens=100, temperature=0)
    results = llm.generate("Hello my name is", sampling_params=greedy_params)
    print(results[0].outputs[0].text)


if __name__ == '__main__':
    main()
|
||||||
96
examples/fp8/README.md
Normal file
96
examples/fp8/README.md
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
# FP8 KV Cache
|
||||||
|
|
||||||
|
This utility extracts the KV cache scaling factors from a quantized HF (Hugging Face) model. The extracted scaling factors are saved to a JSON file, which can later be used by vLLM (variable-length language model) during runtime. This tool is particularly useful when the KV cache data type is FP8 and is intended for use on ROCm (AMD GPU) platforms.
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
- Python 3.x
|
||||||
|
- PyTorch
|
||||||
|
- NumPy
|
||||||
|
- Hugging Face Transformers
|
||||||
|
- Hugging Face Hub
|
||||||
|
- AMMO
|
||||||
|
|
||||||
|
Before incorporating the FP8 datatype for inference workloads, you must adhere to the following steps:
|
||||||
|
1. Install all necessary prerequisites and dependencies.
|
||||||
|
2. Convert HF model into a quantized HF model.
|
||||||
|
3. Extract KV Cache Scaling Factors from quantized HF model.
|
||||||
|
4. Load KV Cache Scaling Factors into VLLM.
|
||||||
|
|
||||||
|
### 2. Convert HF model into a quantized HF model.
|
||||||
|
Note: The following steps are adapted from the [TensorRT-LLM repository](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/README.md).
|
||||||
|
|
||||||
|
`quantize.py` (examples/fp8/quantizer/quantize.py) uses the quantization toolkit (AMMO) to calibrate the PyTorch models and export TensorRT-LLM checkpoints. Each TensorRT-LLM checkpoint contains a config file (in .json format) and one or several rank weight files (in .safetensors format).
|
||||||
|
|
||||||
|
The detailed quantization toolkit (AMMO) conversion guide for FP8 can be found at `examples/fp8/quantizer/README.md`.
|
||||||
|
|
||||||
|
### 3. Extract KV Cache Scaling Factors from quantized HF model.
|
||||||
|
`extract_scales.py` (examples/fp8/extract_scales.py) can be utilized to extract the KV cache scaling factors from your quantized HF model; however, at the moment this tool exclusively supports Llama 2 models. It is also important to note the following:
|
||||||
|
1. **File Structure**: The utility operates under the assumption that all parameters, including KV cache scaling factors, corresponding to a particular Tensor Parallelism (TP) rank are stored in a single file. These files must adhere to a specific naming convention where the TP rank is immediately identified after a specific keyword (e.g., "rank") in the filename.
|
||||||
|
|
||||||
|
2. **TP Decomposition**: The utility assumes consistency between the TP decomposition employed by the quantizer tool and that used by vLLM.
|
||||||
|
|
||||||
|
3. **AMMO Compatibility**: Currently, the generated KV cache scaling factors for AMMO remain uniform across all TP ranks.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# prerequisites:
|
||||||
|
# - Quantized HF LLaMa 2 model
|
||||||
|
python3 examples/fp8/extract_scales.py --help
|
||||||
|
Usage: extract_scales.py [-h] --quantized_model QUANTIZED_MODEL [--load_format {auto,safetensors,npz,pt}] [--output_dir OUTPUT_DIR] [--output_name OUTPUT_NAME] [--tp_size TP_SIZE]
|
||||||
|
|
||||||
|
KV Scale Extraction Example
|
||||||
|
|
||||||
|
required arguments:
|
||||||
|
--quantized_model: Specify either the local path to, or name of, a quantized HF model. It is expected that the quantization format is FP8_E4M3, for use on ROCm (AMD GPU).
|
||||||
|
Optional arguments:
|
||||||
|
--cache_dir: Specify a cache directory to use in the event of a HF model download. (Default: None)
|
||||||
|
--load_format: Specify the format of the model's tensor files containing the KV cache scaling factors. (Choices: auto, safetensors, npz, pt; Default: auto)
|
||||||
|
--revision: Specify the model's revision number. (Default: None)
|
||||||
|
--output_dir: Specify the output directory. By default the KV cache scaling factors will be saved in the model directory. (Default: None)
|
||||||
|
--output_name: Specify the output filename. (Default: kv_cache_scales.json)
|
||||||
|
--tp_size: Specify the tensor-parallel (TP) size that the quantized model should correspond to. If specified, during KV cache scaling factor extraction the observed TP size will be checked against this and an error will be raised if there is a mismatch. (Default: None)
|
||||||
|
```
|
||||||
|
```python
|
||||||
|
Example:
|
||||||
|
python3 examples/fp8/extract_scales.py --quantized_model <QUANTIZED_MODEL_DIR> --tp_size <TENSOR_PARALLEL_SIZE> --output_dir <PATH_TO_OUTPUT_DIR>
|
||||||
|
```
|
||||||
|
### 4. Load KV Cache Scaling Factors into VLLM.
|
||||||
|
This script evaluates the inference throughput of language models using various backends such as vLLM. It measures the time taken to process a given number of prompts and generate sequences for each prompt. The recently generated KV cache scaling factors are now integrated into the benchmarking process and allow for KV cache scaling factors to be utilized for FP8.
|
||||||
|
```python
|
||||||
|
# prerequisites:
|
||||||
|
# - LLaMa 2 kv_cache_scales.json file
|
||||||
|
|
||||||
|
python3 benchmarks/benchmark_throughput.py --help
|
||||||
|
usage: benchmark_throughput.py [-h] [--backend {vllm,hf,mii}] [--dataset DATASET] [--input-len INPUT_LEN] [--output-len OUTPUT_LEN] [--model MODEL]
|
||||||
|
[--tokenizer TOKENIZER] [--quantization {awq,gptq,squeezellm,None}] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--n N]
|
||||||
|
[--use-beam-search] [--num-prompts NUM_PROMPTS] [--seed SEED] [--hf-max-batch-size HF_MAX_BATCH_SIZE] [--trust-remote-code]
|
||||||
|
[--max-model-len MAX_MODEL_LEN] [--dtype {auto,half,float16,bfloat16,float,float32}] [--enforce-eager] [--kv-cache-dtype {auto,fp8}]
|
||||||
|
[--quantization-param-path KV_CACHE_quantization_param_path]
|
||||||
|
|
||||||
|
Benchmark Throughput Example
|
||||||
|
optional arguments:
|
||||||
|
-h, --help show this help message and exit
|
||||||
|
--backend {vllm,hf,mii}
|
||||||
|
--dataset DATASET Path to the dataset.
|
||||||
|
--input-len INPUT_LEN Input prompt length for each request
|
||||||
|
--output-len OUTPUT_LEN Output length for each request. Overrides the output length from the dataset.
|
||||||
|
--model MODEL
|
||||||
|
--tokenizer TOKENIZER
|
||||||
|
--quantization {awq,gptq,squeezellm,None}, -q {awq,gptq,squeezellm,None}
|
||||||
|
--tensor-parallel-size TENSOR_PARALLEL_SIZE, -tp TENSOR_PARALLEL_SIZE
|
||||||
|
--n N Number of generated sequences per prompt.
|
||||||
|
--use-beam-search
|
||||||
|
--num-prompts NUM_PROMPTS Number of prompts to process.
|
||||||
|
--seed SEED
|
||||||
|
--hf-max-batch-size HF_MAX_BATCH_SIZE Maximum batch size for HF backend.
|
||||||
|
--trust-remote-code trust remote code from huggingface
|
||||||
|
--max-model-len MAX_MODEL_LEN Maximum length of a sequence (including prompt and output). If None, will be derived from the model.
|
||||||
|
--dtype {auto,half,float16,bfloat16,float,float32} data type for model weights and activations. The "auto" option will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models.
|
||||||
|
--enforce-eager enforce eager execution
|
||||||
|
--kv-cache-dtype {auto,fp8} Data type for kv cache storage. If "auto", will use model data type. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria.
|
||||||
|
--quantization-param-path QUANT_PARAM_JSON Path to the JSON file containing the KV cache scaling factors. This should generally be supplied, when KV cache dtype is FP8. Otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria.
|
||||||
|
```
|
||||||
|
```
|
||||||
|
Example:
|
||||||
|
python3 benchmarks/benchmark_throughput.py --input-len <INPUT_LEN> --output-len <OUTPUT_LEN> -tp <TENSOR_PARALLEL_SIZE> --kv-cache-dtype fp8 --quantization-param-path <path/to/kv_cache_scales.json> --model <path-to-llama2>
|
||||||
|
```
|
||||||
367
examples/fp8/extract_scales.py
Normal file
367
examples/fp8/extract_scales.py
Normal file
@@ -0,0 +1,367 @@
|
|||||||
|
import argparse
|
||||||
|
import glob
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from safetensors.torch import safe_open
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from vllm/model_executor/model_loader/weight_utils.py
|
||||||
|
# The main differences are that we add the NPZ format and simplify
|
||||||
|
# its functionality drastically for our purposes (e.g. we assume that
|
||||||
|
# the quantized model exists locally and there is no need to download it)
|
||||||
|
def _prepare_hf_weights(
|
||||||
|
quantized_model_dir: str,
|
||||||
|
load_format: str = "auto",
|
||||||
|
fall_back_to_pt: bool = True,
|
||||||
|
) -> Tuple[str, List[str], bool]:
|
||||||
|
if not os.path.isdir(quantized_model_dir):
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"The quantized model directory `{quantized_model_dir}` "
|
||||||
|
"does not exist.")
|
||||||
|
use_safetensors = False
|
||||||
|
# Some quantized models use .pt files for storing the weights.
|
||||||
|
if load_format == "auto":
|
||||||
|
allow_patterns = ["*.safetensors", "*.bin"]
|
||||||
|
elif load_format == "safetensors":
|
||||||
|
use_safetensors = True
|
||||||
|
allow_patterns = ["*.safetensors"]
|
||||||
|
elif load_format == "pt":
|
||||||
|
allow_patterns = ["*.pt"]
|
||||||
|
elif load_format == "npz":
|
||||||
|
allow_patterns = ["*.npz"]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown load_format: {load_format}")
|
||||||
|
if fall_back_to_pt:
|
||||||
|
allow_patterns += ["*.pt"]
|
||||||
|
|
||||||
|
hf_weights_files: List[str] = []
|
||||||
|
for pattern in allow_patterns:
|
||||||
|
hf_weights_files += glob.glob(
|
||||||
|
os.path.join(quantized_model_dir, pattern))
|
||||||
|
if len(hf_weights_files) > 0:
|
||||||
|
if pattern == "*.safetensors":
|
||||||
|
use_safetensors = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not use_safetensors:
|
||||||
|
# Exclude files that are not needed for inference.
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
|
||||||
|
blacklist = [
|
||||||
|
"training_args.bin",
|
||||||
|
"optimizer.bin",
|
||||||
|
"optimizer.pt",
|
||||||
|
"scheduler.pt",
|
||||||
|
"scaler.pt",
|
||||||
|
]
|
||||||
|
hf_weights_files = [
|
||||||
|
f for f in hf_weights_files
|
||||||
|
if not any(f.endswith(x) for x in blacklist)
|
||||||
|
]
|
||||||
|
|
||||||
|
if len(hf_weights_files) == 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot find any model weights with `{quantized_model_dir}`")
|
||||||
|
|
||||||
|
return hf_weights_files, use_safetensors
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from vllm/model_executor/model_loader/weight_utils.py
|
||||||
|
def _hf_tensorfile_iterator(filename: str, load_format: str,
|
||||||
|
use_safetensors: bool):
|
||||||
|
if load_format == "npz":
|
||||||
|
assert not use_safetensors
|
||||||
|
with np.load(filename) as data:
|
||||||
|
for name in data.files:
|
||||||
|
param = torch.from_numpy(data[name])
|
||||||
|
yield name, param
|
||||||
|
elif use_safetensors:
|
||||||
|
with safe_open(filename, framework="pt") as f:
|
||||||
|
for name in f.keys(): # NOQA: SIM118
|
||||||
|
param = f.get_tensor(name)
|
||||||
|
yield name, param
|
||||||
|
else:
|
||||||
|
state = torch.load(filename, map_location="cpu")
|
||||||
|
for name, param in state.items():
|
||||||
|
yield name, param
|
||||||
|
del state
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
def _kv_scales_extractor(
|
||||||
|
hf_tensor_files: Iterable[str],
|
||||||
|
use_safetensors: bool,
|
||||||
|
rank_keyword: str = "rank",
|
||||||
|
expected_tp_size: Optional[int] = None) -> Dict[int, Dict[int, float]]:
|
||||||
|
"""
|
||||||
|
Given a list of files containing tensor data, attempt to extract KV cache
|
||||||
|
scales from these files. Intended as a helper function taking in the output
|
||||||
|
from _prepare_hf_weights.
|
||||||
|
Args:
|
||||||
|
rank_keyword Matches the number immediately after this keyword in the
|
||||||
|
tensor filename to determine the TP rank corresponding
|
||||||
|
to said tensor file
|
||||||
|
expected_tp_size If specified, the TP size of the tensor files is checked
|
||||||
|
against this and an error is raised if they don't match.
|
||||||
|
Returns a dictionary mapping TP ranks to their relevant KV cache scales.
|
||||||
|
The per-rank scales are themselves represented as a dictionary of layer
|
||||||
|
indices to the respective per-layer scale.
|
||||||
|
"""
|
||||||
|
for char in rank_keyword:
|
||||||
|
assert not char.isdecimal(
|
||||||
|
), f"Rank keyword {rank_keyword} contains a numeric character!"
|
||||||
|
rank_scales_map = {}
|
||||||
|
for tensor_file in hf_tensor_files:
|
||||||
|
try:
|
||||||
|
rank_idx = tensor_file.find(rank_keyword)
|
||||||
|
if rank_idx != -1:
|
||||||
|
start_idx = rank_idx + len(rank_keyword)
|
||||||
|
stop_idx = start_idx
|
||||||
|
while stop_idx < len(
|
||||||
|
tensor_file) and tensor_file[stop_idx].isdecimal():
|
||||||
|
stop_idx += 1
|
||||||
|
if stop_idx == start_idx:
|
||||||
|
raise RuntimeError("Did not find rank # in filename.")
|
||||||
|
rank = int(tensor_file[start_idx:stop_idx])
|
||||||
|
elif len(hf_tensor_files) == 1:
|
||||||
|
# Since there is only one tensor file, we can assume
|
||||||
|
# that it's intended for TP rank 0
|
||||||
|
rank = 0
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Filename does not contain '{rank_keyword}'.")
|
||||||
|
except RuntimeError:
|
||||||
|
print("Unable to determine TP rank "
|
||||||
|
f"corresponding to file '{tensor_file}'")
|
||||||
|
raise
|
||||||
|
|
||||||
|
if rank not in rank_scales_map:
|
||||||
|
layer_scales_map = {}
|
||||||
|
rank_scales_map[rank] = layer_scales_map
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Tensor file '{tensor_file}' shares TP rank {rank} "
|
||||||
|
"with another tensor file.")
|
||||||
|
|
||||||
|
module_delimiter = ":" if args.load_format == "npz" else "."
|
||||||
|
for name, param in _hf_tensorfile_iterator(tensor_file,
|
||||||
|
args.load_format,
|
||||||
|
use_safetensors):
|
||||||
|
if "kv_cache_scaling_factor" in name:
|
||||||
|
nums = [
|
||||||
|
int(s) for s in name.split(module_delimiter)
|
||||||
|
if s.isdecimal()
|
||||||
|
]
|
||||||
|
assert len(
|
||||||
|
nums) == 1, f"Could not determine layer idx for {name}"
|
||||||
|
layer_idx = nums[0]
|
||||||
|
assert layer_idx not in layer_scales_map, f"Duplicate scaling"\
|
||||||
|
f" factor corresponding to layer {layer_idx}"
|
||||||
|
try:
|
||||||
|
layer_scales_map[layer_idx] = param.item()
|
||||||
|
except RuntimeError:
|
||||||
|
print(
|
||||||
|
"This utility supports only per-tensor scalar scales "
|
||||||
|
f"for now. The tensor\n {name} = {param} \nis an "
|
||||||
|
"invalid scale factor.")
|
||||||
|
raise
|
||||||
|
|
||||||
|
if all(
|
||||||
|
len(layer_scales_map) == 0
|
||||||
|
for layer_scales_map in rank_scales_map.values()):
|
||||||
|
# Note: this is true even if the rank_scales_map is empty
|
||||||
|
print("WARNING: No KV cache scale factors found. No output saved.")
|
||||||
|
return None
|
||||||
|
empirical_tp_world_size = max(rank_scales_map.keys()) + 1
|
||||||
|
if expected_tp_size is not None:
|
||||||
|
assert expected_tp_size == empirical_tp_world_size, \
|
||||||
|
f"User expected TP world size = {expected_tp_size} " \
|
||||||
|
"from model but tool is expecting TP world size = " \
|
||||||
|
f"{empirical_tp_world_size} from model instead."
|
||||||
|
for i in range(empirical_tp_world_size):
|
||||||
|
assert i in rank_scales_map, "Expected TP world size = "\
|
||||||
|
f"{empirical_tp_world_size} but did not find KV " \
|
||||||
|
f"cache scaling factors for TP rank {i}"
|
||||||
|
print(f"Found TP world size = {empirical_tp_world_size} "
|
||||||
|
"when extracting KV cache scales!")
|
||||||
|
return rank_scales_map
|
||||||
|
|
||||||
|
|
||||||
|
def _metadata_extractor(quantized_model_dir: str,
|
||||||
|
metadata_extract_fns: \
|
||||||
|
Dict[str, Callable[[Dict[str, Any]], Any]]) \
|
||||||
|
-> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Given a directory containing quantized model files, this function
|
||||||
|
aims to extract metadata from the JSON files within this directory.
|
||||||
|
Each JSON file is expected to represent a dictionary in JSON
|
||||||
|
format (referred to as a "JSON-dictionary"). Metadata extraction is
|
||||||
|
defined by a dictionary called metadata_extract_fns, where each
|
||||||
|
metadata field name is mapped to an extraction function.
|
||||||
|
|
||||||
|
These extraction functions are designed to take a JSON-dictionary
|
||||||
|
as their only argument and return the corresponding metadata.
|
||||||
|
While extraction functions are permitted to raise exceptions, they
|
||||||
|
should only raise a KeyError or ValueError if the metadata field
|
||||||
|
cannot be extracted from the current JSON-dictionary, yet there's
|
||||||
|
a possibility of finding it in another JSON-dictionary.
|
||||||
|
|
||||||
|
The function returns a dictionary that maps metadata fields to
|
||||||
|
their extracted data. The keys of this dictionary correspond exactly
|
||||||
|
to those in metadata_extract_fns. If any fields fail to be extracted,
|
||||||
|
their corresponding values are set to None, and a warning is printed.
|
||||||
|
"""
|
||||||
|
if not os.path.isdir(quantized_model_dir):
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"The quantized model directory `{quantized_model_dir}` "
|
||||||
|
"does not exist.")
|
||||||
|
metadata_files = glob.glob(os.path.join(quantized_model_dir, "*.json"))
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
for file in metadata_files:
|
||||||
|
with open(file) as f:
|
||||||
|
try:
|
||||||
|
metadata = json.load(f)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
print(f"Could not parse `{file}` as a valid metadata file,"
|
||||||
|
" skipping it.")
|
||||||
|
continue
|
||||||
|
if not isinstance(metadata, dict):
|
||||||
|
print(f"The file `{file}` does not correspond to a "
|
||||||
|
"JSON-serialized dictionary, skipping it.")
|
||||||
|
continue
|
||||||
|
for metadata_name, extract_fn in metadata_extract_fns.items():
|
||||||
|
try:
|
||||||
|
metadata_info = extract_fn(metadata)
|
||||||
|
if metadata_name not in result:
|
||||||
|
result[metadata_name] = metadata_info
|
||||||
|
elif metadata_info != result[metadata_name]:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Metadata mismatch! Originally found "
|
||||||
|
f"{metadata_name} = {result[metadata_name]} but "
|
||||||
|
f"now found {metadata_name} = {metadata_info} in "
|
||||||
|
f"`{file}`")
|
||||||
|
except KeyError:
|
||||||
|
# It is possible that a given file does not contain some
|
||||||
|
# of our selected metadata as it could be located in some
|
||||||
|
# other metadata file.
|
||||||
|
# 'EFINAE': extract_fn failure is not an error.
|
||||||
|
pass
|
||||||
|
except ValueError:
|
||||||
|
# See above.
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Warn if we cannot find any of the requested metadata
|
||||||
|
for metadata_name in metadata_extract_fns:
|
||||||
|
if metadata_name not in result:
|
||||||
|
print("WARNING: Unable to find requested metadata field "
|
||||||
|
f"`{metadata_name}`, setting it to None.")
|
||||||
|
result[metadata_name] = None
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
    """Extract KV cache scaling factors and write them to a JSON file.

    Args:
        args: Parsed command-line namespace providing `quantized_model`,
            `load_format`, `output_dir`, `output_name` and `tp_size`.

    Metadata (model type, TP size, model dtype) is recovered from the JSON
    files in the model directory, the per-rank KV cache scales are read from
    the tensor files, and the result is serialized via `QuantParamSchema`.
    If no scaling factors are found, nothing is written and the function
    returns early. An AssertionError is raised when the user-supplied TP
    size disagrees with the TP size found in the metadata.
    """
    metadata_extract_fns = {
        "model_type": lambda json_dict: json_dict["layers"][0]["decoder_type"],
        "tp_size": lambda json_dict: int(json_dict["tensor_parallel"]),
        "model_dtype": lambda json_dict: json_dict["dtype"]
    }
    recovered_metadata = _metadata_extractor(args.quantized_model,
                                             metadata_extract_fns)
    if args.tp_size is not None:
        metadata_tp_size = recovered_metadata["tp_size"]
        if metadata_tp_size is not None:
            assert args.tp_size == metadata_tp_size, \
                f"User expected TP world size = {args.tp_size} " \
                f"but found TP world size = {metadata_tp_size} from metadata!"
    # The user-supplied value takes precedence over the metadata value.
    expected_tp_size = args.tp_size or recovered_metadata["tp_size"]
    rank_keyword = "rank"
    hf_tensor_files, use_safetensors = _prepare_hf_weights(
        args.quantized_model, args.load_format)
    rank_scales_map = _kv_scales_extractor(hf_tensor_files, use_safetensors,
                                           rank_keyword, expected_tp_size)
    if rank_scales_map is None:
        # The extractor already warned that no scales were found and that no
        # output is saved; continuing would fail on `None.items()` below.
        return

    # Postprocess: formatting to the current schema. Consider pulling it
    # out into a dedicated function should it ever become more complicated.
    rank_scales_map = {
        rank: {k: scale[k]
               for k in sorted(scale.keys())}
        for rank, scale in rank_scales_map.items()
    }
    # TODO: Expand this with activation and weights scaling factors when
    # they are used in the future
    schema = QuantParamSchema(
        model_type=recovered_metadata["model_type"],
        kv_cache={
            "dtype": ("float8_e4m3fn" if len(rank_scales_map) > 0 else
                      recovered_metadata["model_dtype"]),
            "scaling_factor":
            rank_scales_map
        },
    )

    if args.output_dir is None:
        output_file = os.path.join(args.quantized_model, args.output_name)
    else:
        os.makedirs(args.output_dir, exist_ok=True)
        output_file = os.path.join(args.output_dir, args.output_name)

    with open(output_file, 'w') as f:
        f.write(schema.model_dump_json(indent=4))
        print(f"Completed! KV cache scaling factors saved to {output_file}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, run main().
    parser = argparse.ArgumentParser(
        description=("This simple utility extracts the KV cache scaling "
                     "factors from a quantized HF model and saves them to a "
                     "JSON file compatible with later use by vLLM (pass this "
                     "file to the appropriate runtime typically using the "
                     "argument --quantization-param-path <filename>). This "
                     "is only used if the KV cache dtype is FP8 and on ROCm "
                     "(AMD GPU)."))

    # The only mandatory option: where the quantized model lives.
    parser.add_argument(
        "--quantized_model",
        required=True,
        help=("Specify the directory containing a single quantized HF "
              "model. It is expected that the quantization format is "
              "FP8_E4M3, for use on ROCm (AMD GPU)."))
    parser.add_argument(
        "--load_format",
        default="auto",
        choices=["auto", "safetensors", "npz", "pt"],
        help=("Optionally specify the format of the model's tensor files "
              "containing the KV cache scaling factors."))
    parser.add_argument(
        "--output_dir",
        default=None,
        help=("Optionally specify the output directory. By default the KV "
              "cache scaling factors will be saved in the model directory, "
              "however you can override this behavior here."))
    # TODO: Change this default once additional scaling factors are enabled
    parser.add_argument("--output_name",
                        default="kv_cache_scales.json",
                        help="Optionally specify the output filename.")
    parser.add_argument(
        "--tp_size",
        type=int,
        default=None,
        help=("Optionally specify the tensor-parallel (TP) size that the "
              "quantized model should correspond to. If specified, during "
              "KV cache scaling factor extraction the observed TP size will "
              "be checked against this and an error will be raised if there "
              "is a mismatch. If not specified, the quantized model's "
              "expected TP size is instead inferred from the largest TP "
              "rank observed. The expected TP size is cross-checked against "
              "the TP ranks observed in the quantized model and an error is "
              "raised if any discrepancies are found."))

    # Kept as a module-level name: other code in this file reads `args`.
    args = parser.parse_args()

    main(args)
|
||||||
32
examples/fp8/quantizer/README.md
Normal file
32
examples/fp8/quantizer/README.md
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
### Quantizer Utilities
|
||||||
|
`quantize.py`: NVIDIA Quantization utilities using AMMO, ported from TensorRT-LLM:
|
||||||
|
`https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/quantize.py`
|
||||||
|
|
||||||
|
### Prerequisite
|
||||||
|
|
||||||
|
#### AMMO (AlgorithMic Model Optimization) Installation: nvidia-ammo 0.7.1 or later
|
||||||
|
`pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-ammo`
|
||||||
|
|
||||||
|
#### AMMO Download (code and docs)
|
||||||
|
`https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.5.0.tar.gz`
|
||||||
|
`https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.7.1.tar.gz`
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
#### Run on H100 system for speed if FP8; number of GPUs depends on the model size
|
||||||
|
|
||||||
|
#### Example: quantize Llama2-7b model from HF to FP8 with FP8 KV Cache:
|
||||||
|
`python quantize.py --model_dir ./ll2-7b --dtype float16 --qformat fp8 --kv_cache_dtype fp8 --output_dir ./ll2_7b_fp8 --calib_size 512 --tp_size 1`
|
||||||
|
|
||||||
|
Outputs: the model structure and the quantized model & parameters (with scaling factors) are written in JSON and Safetensors formats (the npz file is generated for reference only).
|
||||||
|
```
|
||||||
|
# ll ./ll2_7b_fp8/
|
||||||
|
total 19998244
|
||||||
|
drwxr-xr-x 2 root root 4096 Feb 7 01:08 ./
|
||||||
|
drwxrwxr-x 8 1060 1061 4096 Feb 7 01:08 ../
|
||||||
|
-rw-r--r-- 1 root root 176411 Feb 7 01:08 llama_tp1.json
|
||||||
|
-rw-r--r-- 1 root root 13477087480 Feb 7 01:09 llama_tp1_rank0.npz
|
||||||
|
-rw-r--r-- 1 root root 7000893272 Feb 7 01:08 rank0.safetensors
|
||||||
|
#
|
||||||
|
```
|
||||||
|
|
||||||
367
examples/fp8/quantizer/quantize.py
Normal file
367
examples/fp8/quantizer/quantize.py
Normal file
@@ -0,0 +1,367 @@
|
|||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # noqa: E501
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Adapted from examples/quantization/hf_ptq.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
|
||||||
|
import ammo.torch.quantization as atq
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from ammo.torch.export import export_model_config
|
||||||
|
from datasets import load_dataset
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
# Fixed random seed constant.
RAND_SEED = 1234
# Default maximum sequence length handed to the tokenizer.
MAX_SEQ_LEN = 2048

# Quantization config in which every listed quantizer pattern is disabled.
EMPTY_CFG = {
    "quant_cfg": {
        pattern: {
            "enable": False
        }
        for pattern in (
            "*weight_quantizer",
            "*input_quantizer",
            "*lm_head*",
            "*output_layer*",
            "default",
        )
    },
    "algorithm": "max",
}

# Enabled 8-bit output quantizers, one independent entry per module-name
# pattern listed below.
KV_CACHE_CFG = {
    f"*.{module}.output_quantizer": {
        "num_bits": 8,
        "axis": None,
        "enable": True
    }
    for module in (
        "query_key_value",
        "Wqkv",
        "W_pack",
        "c_attn",
        "k_proj",
        "v_proj",
    )
}
|
||||||
|
|
||||||
|
# Quantization format name -> quantization config. The weight-only and
# full-precision formats all share the single EMPTY_CFG object.
QUANT_CFG_CHOICES = dict(
    int8_sq=atq.INT8_SMOOTHQUANT_CFG,
    fp8=atq.FP8_DEFAULT_CFG,
    int4_awq=atq.INT4_AWQ_CFG,
    w4a8_awq=atq.W4A8_AWQ_BETA_CFG,
    int8_wo=EMPTY_CFG,
    int4_wo=EMPTY_CFG,
    full_prec=EMPTY_CFG,
)

# Model-name pattern -> model type string.
MODEL_NAME_PATTERN_MAP = dict(
    GPT2="gpt2",
    Xverse="llama",
    Llama="llama",
    Mistral="llama",
    GPTJ="gptj",
    FalconForCausalLM="falcon",
    RWForCausalLM="falcon",
    baichuan="baichuan",
    MPT="mpt",
    Bloom="bloom",
    ChatGLM="chatglm",
    QWen="qwen",
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_tokenizer(ckpt_path, max_seq_len=MAX_SEQ_LEN, model_type=None):
|
||||||
|
print(f"Initializing tokenizer from {ckpt_path}")
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
ckpt_path,
|
||||||
|
model_max_length=max_seq_len,
|
||||||
|
padding_side="left",
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
if model_type and model_type == "qwen":
|
||||||
|
# qwen use token id 151643 as pad and eos tokens
|
||||||
|
tokenizer.pad_token = tokenizer.convert_ids_to_tokens(151643)
|
||||||
|
tokenizer.eos_token = tokenizer.convert_ids_to_tokens(151643)
|
||||||
|
|
||||||
|
# can't set attribute 'pad_token' for "<unk>"
|
||||||
|
if tokenizer.pad_token != "<unk>":
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
if tokenizer.pad_token is None:
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
assert (tokenizer.pad_token
|
||||||
|
is not None), f"Pad token for {model_type} cannot be set!"
|
||||||
|
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def get_model(ckpt_path, dtype="fp16", device="cuda"):
|
||||||
|
print(f"Initializing model from {ckpt_path}")
|
||||||
|
if dtype == "bf16" or dtype == "bfloat16":
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
elif dtype == "fp16" or dtype == "float16":
|
||||||
|
dtype = torch.float16
|
||||||
|
elif dtype == "fp32" or dtype == "float32":
|
||||||
|
dtype = torch.float32
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Unknown dtype {dtype}")
|
||||||
|
|
||||||
|
# model_kwargs = {"torch_dtype": dtype}
|
||||||
|
model_kwargs = {"torch_dtype": "auto"}
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(ckpt_path,
|
||||||
|
device_map="auto",
|
||||||
|
**model_kwargs,
|
||||||
|
trust_remote_code=True)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
model_dtype = next(model.parameters()).dtype
|
||||||
|
if dtype != model_dtype:
|
||||||
|
print("[TensorRT-LLM][WARNING] The manually set model data type is "
|
||||||
|
f"{dtype}, but the data type of the HuggingFace model is "
|
||||||
|
f"{model_dtype}.")
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_type(model):
|
||||||
|
for k, v in MODEL_NAME_PATTERN_MAP.items():
|
||||||
|
if k.lower() in type(model).__name__.lower():
|
||||||
|
return v
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_calib_dataloader(data="cnn_dailymail",
|
||||||
|
tokenizer=None,
|
||||||
|
batch_size=1,
|
||||||
|
calib_size=512,
|
||||||
|
block_size=512,
|
||||||
|
device=None):
|
||||||
|
print("Loading calibration dataset")
|
||||||
|
if data == "pileval":
|
||||||
|
dataset = load_dataset(
|
||||||
|
"json",
|
||||||
|
data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
|
||||||
|
split="train")
|
||||||
|
dataset = dataset["text"][:calib_size]
|
||||||
|
elif data == "cnn_dailymail":
|
||||||
|
dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train")
|
||||||
|
dataset = dataset["article"][:calib_size]
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
batch_encoded = tokenizer.batch_encode_plus(dataset,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding="max_length",
|
||||||
|
truncation=True,
|
||||||
|
max_length=block_size)
|
||||||
|
if device:
|
||||||
|
batch_encoded = batch_encoded.to(device)
|
||||||
|
batch_encoded = batch_encoded["input_ids"]
|
||||||
|
|
||||||
|
calib_dataloader = DataLoader(batch_encoded,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=False)
|
||||||
|
|
||||||
|
return calib_dataloader
|
||||||
|
|
||||||
|
|
||||||
|
def quantize_model(model, quant_cfg, calib_dataloader=None):
|
||||||
|
|
||||||
|
def calibrate_loop():
|
||||||
|
if calib_dataloader is None:
|
||||||
|
return
|
||||||
|
"""Adjusts weights and scaling factors based on selected algorithms."""
|
||||||
|
for idx, data in enumerate(calib_dataloader):
|
||||||
|
print(f"Calibrating batch {idx}")
|
||||||
|
model(data)
|
||||||
|
|
||||||
|
print("Starting quantization...")
|
||||||
|
start_time = time.time()
|
||||||
|
atq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
|
||||||
|
end_time = time.time()
|
||||||
|
print("Quantization done. Total time used: {:.2f} s.".format(end_time -
|
||||||
|
start_time))
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
raise EnvironmentError("GPU is required for inference.")
|
||||||
|
|
||||||
|
random.seed(RAND_SEED)
|
||||||
|
np.random.seed(RAND_SEED)
|
||||||
|
|
||||||
|
model = get_model(args.model_dir, args.dtype, args.device)
|
||||||
|
model_type = get_model_type(model)
|
||||||
|
tokenizer = get_tokenizer(args.model_dir, model_type=model_type)
|
||||||
|
|
||||||
|
if args.qformat in ["full_prec", "int8_wo", "int4_wo"
|
||||||
|
] and args.kv_cache_dtype is None:
|
||||||
|
print(f"No quantization applied, export {args.dtype} model")
|
||||||
|
else:
|
||||||
|
if "awq" in args.qformat:
|
||||||
|
if args.calib_size > 32:
|
||||||
|
print("AWQ calibration could take longer with calib_size = "
|
||||||
|
f"{args.calib_size}, Using calib_size=32 instead")
|
||||||
|
args.calib_size = 32
|
||||||
|
print("\nAWQ calibration could take longer than other calibration "
|
||||||
|
"methods. Please increase the batch size to speed up the "
|
||||||
|
"calibration process. Batch size can be set by adding the "
|
||||||
|
"argument --batch_size <batch_size> to the command line.\n")
|
||||||
|
|
||||||
|
calib_dataloader = get_calib_dataloader(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
calib_size=args.calib_size,
|
||||||
|
device=args.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.qformat in QUANT_CFG_CHOICES:
|
||||||
|
quant_cfg = QUANT_CFG_CHOICES[args.qformat]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported quantization format: {args.qformat}")
|
||||||
|
|
||||||
|
if "awq" in args.qformat:
|
||||||
|
quant_cfg = copy.deepcopy(QUANT_CFG_CHOICES[args.qformat])
|
||||||
|
weight_quantizer = quant_cfg["quant_cfg"][
|
||||||
|
"*weight_quantizer"] # type: ignore
|
||||||
|
if isinstance(weight_quantizer, list):
|
||||||
|
weight_quantizer = weight_quantizer[0]
|
||||||
|
weight_quantizer["block_sizes"][-1] = args.awq_block_size
|
||||||
|
|
||||||
|
if args.kv_cache_dtype is not None:
|
||||||
|
if args.kv_cache_dtype == "fp8":
|
||||||
|
for value in KV_CACHE_CFG.values():
|
||||||
|
value.update({"num_bits": (4, 3)}) # type: ignore
|
||||||
|
quant_cfg["quant_cfg"].update(KV_CACHE_CFG) # type: ignore
|
||||||
|
|
||||||
|
print(quant_cfg)
|
||||||
|
|
||||||
|
model = quantize_model(model, quant_cfg, calib_dataloader)
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
if model_type is None:
|
||||||
|
print(f"Unknown model type {type(model).__name__}. Continue "
|
||||||
|
"exporting...")
|
||||||
|
model_type = f"unknown:{type(model).__name__}"
|
||||||
|
|
||||||
|
export_path = args.output_dir
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
if args.qformat == "int4_awq" and model_type == "qwen":
|
||||||
|
torch.save(model.state_dict(), export_path)
|
||||||
|
else:
|
||||||
|
export_npz = (model_type not in [
|
||||||
|
'gptj', 'falcon', 'chatglm', 'mpt', 'llama', 'baichuan'
|
||||||
|
])
|
||||||
|
|
||||||
|
# export safetensors
|
||||||
|
export_model_config(
|
||||||
|
model,
|
||||||
|
model_type,
|
||||||
|
getattr(torch, args.dtype),
|
||||||
|
export_dir=export_path,
|
||||||
|
inference_tensor_parallel=args.tp_size,
|
||||||
|
inference_pipeline_parallel=args.pp_size,
|
||||||
|
# export_tensorrt_llm_config=(not export_npz),
|
||||||
|
export_tensorrt_llm_config=False,
|
||||||
|
export_npz=export_npz)
|
||||||
|
|
||||||
|
# Workaround for wo quantization
|
||||||
|
if args.qformat in ["int8_wo", "int4_wo", "full_prec"]:
|
||||||
|
with open(f"{export_path}/config.json", 'r') as f:
|
||||||
|
tensorrt_llm_config = json.load(f)
|
||||||
|
if args.qformat == "int8_wo":
|
||||||
|
tensorrt_llm_config["quantization"]["quant_algo"] = 'W8A16'
|
||||||
|
elif args.qformat == "int4_wo":
|
||||||
|
tensorrt_llm_config["quantization"]["quant_algo"] = 'W4A16'
|
||||||
|
else:
|
||||||
|
tensorrt_llm_config["quantization"]["quant_algo"] = None
|
||||||
|
with open(f"{export_path}/config.json", "w") as f:
|
||||||
|
json.dump(tensorrt_llm_config, f, indent=4)
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
print("Quantized model exported to {} \nTotal time used {:.2f} s.".
|
||||||
|
format(export_path, end_time - start_time))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument("--model_dir",
|
||||||
|
help="Specify where the HuggingFace model is",
|
||||||
|
required=True)
|
||||||
|
parser.add_argument("--device", default="cuda")
|
||||||
|
parser.add_argument("--dtype", help="Model data type.", default="float16")
|
||||||
|
parser.add_argument(
|
||||||
|
"--qformat",
|
||||||
|
help="Quantization format.",
|
||||||
|
default="full_prec",
|
||||||
|
choices=[
|
||||||
|
"fp8", "int8_sq", "int4_awq", "w4a8_awq", "int8_wo", "int4_wo",
|
||||||
|
"full_prec"
|
||||||
|
],
|
||||||
|
)
|
||||||
|
parser.add_argument("--batch_size",
|
||||||
|
help="Batch size for calibration.",
|
||||||
|
type=int,
|
||||||
|
default=1)
|
||||||
|
parser.add_argument("--calib_size",
|
||||||
|
help="Number of samples for calibration.",
|
||||||
|
type=int,
|
||||||
|
default=512)
|
||||||
|
parser.add_argument("--output_dir", default="exported_model")
|
||||||
|
parser.add_argument("--tp_size", type=int, default=1)
|
||||||
|
parser.add_argument("--pp_size", type=int, default=1)
|
||||||
|
parser.add_argument("--awq_block_size", type=int, default=128)
|
||||||
|
parser.add_argument("--kv_cache_dtype",
|
||||||
|
help="KV Cache dtype.",
|
||||||
|
default=None,
|
||||||
|
choices=["int8", "fp8", None])
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
main(args)
|
||||||
282
examples/tensorize_vllm_model.py
Normal file
282
examples/tensorize_vllm_model.py
Normal file
@@ -0,0 +1,282 @@
|
|||||||
|
import argparse
|
||||||
|
import dataclasses
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from functools import partial
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer,
|
||||||
|
TensorSerializer, stream_io)
|
||||||
|
from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
|
||||||
|
from transformers import AutoConfig, PretrainedConfig
|
||||||
|
|
||||||
|
from vllm.distributed import initialize_model_parallel
|
||||||
|
from vllm.engine.arg_utils import EngineArgs
|
||||||
|
from vllm.engine.llm_engine import LLMEngine
|
||||||
|
from vllm.model_executor.model_loader.tensorizer import TensorizerArgs
|
||||||
|
from vllm.model_executor.models import ModelRegistry
|
||||||
|
|
||||||
|
# yapf conflicts with isort for this docstring
|
||||||
|
# yapf: disable
|
||||||
|
"""
|
||||||
|
tensorize_vllm_model.py is a script that can be used to serialize and
|
||||||
|
deserialize vLLM models. These models can be loaded using tensorizer
|
||||||
|
to the GPU extremely quickly over an HTTP/HTTPS endpoint, an S3 endpoint,
|
||||||
|
or locally. Tensor encryption and decryption is also supported, although
|
||||||
|
libsodium must be installed to use it. Install vllm with tensorizer support
|
||||||
|
using `pip install vllm[tensorizer]`.
|
||||||
|
|
||||||
|
To serialize a model, install vLLM from source, then run something
|
||||||
|
like this from the root level of this repository:
|
||||||
|
|
||||||
|
python -m examples.tensorize_vllm_model \
|
||||||
|
--model EleutherAI/gpt-j-6B \
|
||||||
|
--dtype float16 \
|
||||||
|
serialize \
|
||||||
|
--serialized-directory s3://my-bucket/ \
|
||||||
|
--suffix vllm
|
||||||
|
|
||||||
|
Which downloads the model from HuggingFace, loads it into vLLM, serializes it,
|
||||||
|
and saves it to your S3 bucket. A local directory can also be used. This
|
||||||
|
assumes your S3 credentials are specified as environment variables
|
||||||
|
in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`.
|
||||||
|
To provide S3 credentials directly, you can provide `--s3-access-key-id` and
|
||||||
|
`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this
|
||||||
|
script.
|
||||||
|
|
||||||
|
You can also encrypt the model weights with a randomly-generated key by
|
||||||
|
providing a `--keyfile` argument.
|
||||||
|
|
||||||
|
To deserialize a model, you can run something like this from the root
|
||||||
|
level of this repository:
|
||||||
|
|
||||||
|
python -m examples.tensorize_vllm_model \
|
||||||
|
--model EleutherAI/gpt-j-6B \
|
||||||
|
--dtype float16 \
|
||||||
|
deserialize \
|
||||||
|
--path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors
|
||||||
|
|
||||||
|
Which downloads the model tensors from your S3 bucket and deserializes them.
|
||||||
|
|
||||||
|
You can also provide a `--keyfile` argument to decrypt the model weights if
|
||||||
|
they were serialized with encryption.
|
||||||
|
|
||||||
|
For more information on the available arguments for serializing, run
|
||||||
|
`python -m examples.tensorize_vllm_model serialize --help`.
|
||||||
|
|
||||||
|
Or for deserializing:
|
||||||
|
|
||||||
|
`python -m examples.tensorize_vllm_model deserialize --help`.
|
||||||
|
|
||||||
|
Once a model is serialized, it can be used to load the model when running the
|
||||||
|
OpenAI inference client at `vllm/entrypoints/openai/api_server.py` by providing
|
||||||
|
the `--tensorizer-uri` CLI argument that is functionally the same as the
|
||||||
|
`--path-to-tensors` argument in this script, along with `--vllm-tensorized`, to
|
||||||
|
signify that the model to be deserialized is a vLLM model, rather than a
|
||||||
|
HuggingFace `PreTrainedModel`, which can also be deserialized using tensorizer
|
||||||
|
in the same inference server, albeit without the speed optimizations. To
|
||||||
|
deserialize an encrypted file, the `--encryption-keyfile` argument can be used
|
||||||
|
to provide the path to the keyfile used to encrypt the model weights. For
|
||||||
|
information on all the arguments that can be used to configure tensorizer's
|
||||||
|
deserialization, check out the tensorizer options argument group in the
|
||||||
|
`vllm/entrypoints/openai/api_server.py` script with `--help`.
|
||||||
|
|
||||||
|
Tensorizer can also be invoked with the `LLM` class directly to load models:
|
||||||
|
|
||||||
|
llm = LLM(model="facebook/opt-125m",
|
||||||
|
load_format="tensorizer",
|
||||||
|
tensorizer_uri=path_to_opt_tensors,
|
||||||
|
num_readers=3,
|
||||||
|
vllm_tensorized=True)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="An example script that can be used to serialize and "
|
||||||
|
"deserialize vLLM models. These models "
|
||||||
|
"can be loaded using tensorizer directly to the GPU "
|
||||||
|
"extremely quickly. Tensor encryption and decryption is "
|
||||||
|
"also supported, although libsodium must be installed to "
|
||||||
|
"use it.")
|
||||||
|
parser = EngineArgs.add_cli_args(parser)
|
||||||
|
subparsers = parser.add_subparsers(dest='command')
|
||||||
|
|
||||||
|
serialize_parser = subparsers.add_parser(
|
||||||
|
'serialize', help="Serialize a model to `--serialized-directory`")
|
||||||
|
|
||||||
|
serialize_parser.add_argument(
|
||||||
|
"--suffix",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
help=(
|
||||||
|
"The suffix to append to the serialized model directory, which is "
|
||||||
|
"used to construct the location of the serialized model tensors, "
|
||||||
|
"e.g. if `--serialized-directory` is `s3://my-bucket/` and "
|
||||||
|
"`--suffix` is `v1`, the serialized model tensors will be "
|
||||||
|
"saved to "
|
||||||
|
"`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. "
|
||||||
|
"If none is provided, a random UUID will be used."))
|
||||||
|
serialize_parser.add_argument(
|
||||||
|
"--serialized-directory",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The directory to serialize the model to. "
|
||||||
|
"This can be a local directory or S3 URI. The path to where the "
|
||||||
|
"tensors are saved is a combination of the supplied `dir` and model "
|
||||||
|
"reference ID. For instance, if `dir` is the serialized directory, "
|
||||||
|
"and the model HuggingFace ID is `EleutherAI/gpt-j-6B`, tensors will "
|
||||||
|
"be saved to `dir/vllm/EleutherAI/gpt-j-6B/suffix/model.tensors`, "
|
||||||
|
"where `suffix` is given by `--suffix` or a random UUID if not "
|
||||||
|
"provided.")
|
||||||
|
|
||||||
|
serialize_parser.add_argument(
|
||||||
|
"--keyfile",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
help=("Encrypt the model weights with a randomly-generated binary key,"
|
||||||
|
" and save the key at this path"))
|
||||||
|
|
||||||
|
deserialize_parser = subparsers.add_parser(
|
||||||
|
'deserialize',
|
||||||
|
help=("Deserialize a model from `--path-to-tensors`"
|
||||||
|
" to verify it can be loaded and used."))
|
||||||
|
|
||||||
|
deserialize_parser.add_argument(
|
||||||
|
"--path-to-tensors",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The local path or S3 URI to the model tensors to deserialize. ")
|
||||||
|
|
||||||
|
deserialize_parser.add_argument(
|
||||||
|
"--keyfile",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
help=("Path to a binary key to use to decrypt the model weights,"
|
||||||
|
" if the model was serialized with encryption"))
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def make_model_contiguous(model):
|
||||||
|
# Ensure tensors are saved in memory contiguously
|
||||||
|
for param in model.parameters():
|
||||||
|
param.data = param.data.contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
||||||
|
architectures = getattr(config, "architectures", [])
|
||||||
|
for arch in architectures:
|
||||||
|
model_cls = ModelRegistry.load_model_cls(arch)
|
||||||
|
if model_cls is not None:
|
||||||
|
return model_cls
|
||||||
|
raise ValueError(
|
||||||
|
f"Model architectures {architectures} are not supported for now. "
|
||||||
|
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
|
||||||
|
|
||||||
|
|
||||||
|
def serialize():
|
||||||
|
|
||||||
|
eng_args_dict = {f.name: getattr(args, f.name) for f in
|
||||||
|
dataclasses.fields(EngineArgs)}
|
||||||
|
engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
|
||||||
|
engine = LLMEngine.from_engine_args(engine_args)
|
||||||
|
|
||||||
|
model = (engine.model_executor.driver_worker.
|
||||||
|
model_runner.model)
|
||||||
|
|
||||||
|
encryption_params = EncryptionParams.random() if keyfile else None
|
||||||
|
if keyfile:
|
||||||
|
with _write_stream(keyfile) as stream:
|
||||||
|
stream.write(encryption_params.key)
|
||||||
|
|
||||||
|
with _write_stream(model_path) as stream:
|
||||||
|
serializer = TensorSerializer(stream, encryption=encryption_params)
|
||||||
|
serializer.write_module(model)
|
||||||
|
serializer.close()
|
||||||
|
|
||||||
|
print("Serialization complete. Model tensors saved to", model_path)
|
||||||
|
if keyfile:
|
||||||
|
print("Key saved to", keyfile)
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize():
|
||||||
|
config = AutoConfig.from_pretrained(model_ref)
|
||||||
|
|
||||||
|
with no_init_or_tensor():
|
||||||
|
model_class = _get_vllm_model_architecture(config)
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
before_mem = get_mem_usage()
|
||||||
|
start = time.time()
|
||||||
|
|
||||||
|
if keyfile:
|
||||||
|
with _read_stream(keyfile) as stream:
|
||||||
|
key = stream.read()
|
||||||
|
decryption_params = DecryptionParams.from_key(key)
|
||||||
|
tensorizer_args.deserializer_params['encryption'] = \
|
||||||
|
decryption_params
|
||||||
|
|
||||||
|
with (_read_stream(model_path)) as stream, TensorDeserializer(
|
||||||
|
stream, **tensorizer_args.deserializer_params) as deserializer:
|
||||||
|
deserializer.load_into_module(model)
|
||||||
|
end = time.time()
|
||||||
|
|
||||||
|
# Brag about how fast we are.
|
||||||
|
total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
|
||||||
|
duration = end - start
|
||||||
|
per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
|
||||||
|
after_mem = get_mem_usage()
|
||||||
|
print(
|
||||||
|
f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s"
|
||||||
|
)
|
||||||
|
print(f"Memory usage before: {before_mem}")
|
||||||
|
print(f"Memory usage after: {after_mem}")
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID")
|
||||||
|
or None)
|
||||||
|
s3_secret_access_key = (args.s3_secret_access_key
|
||||||
|
or os.environ.get("S3_SECRET_ACCESS_KEY") or None)
|
||||||
|
|
||||||
|
s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None)
|
||||||
|
|
||||||
|
_read_stream, _write_stream = (partial(
|
||||||
|
stream_io.open_stream,
|
||||||
|
mode=mode,
|
||||||
|
s3_access_key_id=s3_access_key_id,
|
||||||
|
s3_secret_access_key=s3_secret_access_key,
|
||||||
|
s3_endpoint=s3_endpoint,
|
||||||
|
) for mode in ("rb", "wb+"))
|
||||||
|
|
||||||
|
model_ref = args.model
|
||||||
|
|
||||||
|
model_name = model_ref.split("/")[1]
|
||||||
|
|
||||||
|
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
||||||
|
os.environ["MASTER_PORT"] = "8080"
|
||||||
|
|
||||||
|
torch.distributed.init_process_group(world_size=1, rank=0)
|
||||||
|
initialize_model_parallel()
|
||||||
|
|
||||||
|
keyfile = args.keyfile if args.keyfile else None
|
||||||
|
|
||||||
|
if args.command == "serialize":
|
||||||
|
input_dir = args.serialized_directory.rstrip('/')
|
||||||
|
suffix = args.suffix if args.suffix else uuid.uuid4().hex
|
||||||
|
base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
|
||||||
|
model_path = f"{base_path}/model.tensors"
|
||||||
|
serialize()
|
||||||
|
elif args.command == "deserialize":
|
||||||
|
tensorizer_args = TensorizerArgs.from_cli_args(args)
|
||||||
|
model_path = args.path_to_tensors
|
||||||
|
deserialize()
|
||||||
|
else:
|
||||||
|
raise ValueError("Either serialize or deserialize must be specified.")
|
||||||
20
format.sh
20
format.sh
@@ -93,9 +93,21 @@ fi
|
|||||||
echo 'vLLM yapf: Done'
|
echo 'vLLM yapf: Done'
|
||||||
|
|
||||||
# Run mypy
|
# Run mypy
|
||||||
# TODO(zhuohan): Enable mypy
|
echo 'vLLM mypy:'
|
||||||
# echo 'vLLM mypy:'
|
mypy vllm/attention --config-file pyproject.toml
|
||||||
# mypy
|
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
|
||||||
|
mypy vllm/distributed --config-file pyproject.toml
|
||||||
|
mypy vllm/entrypoints --config-file pyproject.toml
|
||||||
|
mypy vllm/executor --config-file pyproject.toml
|
||||||
|
mypy vllm/usage --config-file pyproject.toml
|
||||||
|
mypy vllm/*.py --config-file pyproject.toml
|
||||||
|
mypy vllm/transformers_utils --config-file pyproject.toml
|
||||||
|
mypy vllm/engine --config-file pyproject.toml
|
||||||
|
mypy vllm/worker --config-file pyproject.toml
|
||||||
|
mypy vllm/spec_decode --config-file pyproject.toml
|
||||||
|
mypy vllm/model_executor/*.py --config-file pyproject.toml
|
||||||
|
# mypy vllm/lora/*.py --config-file pyproject.toml
|
||||||
|
|
||||||
|
|
||||||
CODESPELL_EXCLUDES=(
|
CODESPELL_EXCLUDES=(
|
||||||
'--skip' '*docs/source/_build/**'
|
'--skip' '*docs/source/_build/**'
|
||||||
@@ -228,5 +240,3 @@ if ! git diff --quiet &>/dev/null; then
|
|||||||
|
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,33 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
set -e
|
|
||||||
|
|
||||||
XFORMERS_VERSION="0.0.23"
|
|
||||||
|
|
||||||
export XFORMERS_INSTALLED_VERSION=$(python -c 'import xformers; print(xformers.__version__)')
|
|
||||||
|
|
||||||
if [ "$XFORMERS_INSTALLED_VERSION" != "$XFORMERS_VERSION" ]; then
|
|
||||||
echo "ERROR: xformers version must be ${XFORMERS_VERSION}. ${XFORMERS_INSTALLED_VERSION} is installed"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
export XFORMERS_FMHA_FLASH_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.flash.__file__)')
|
|
||||||
export XFORMERS_FMHA_COMMON_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.common.__file__)')
|
|
||||||
|
|
||||||
echo "XFORMERS_FMHA_FLASH_PATH = ${XFORMERS_FMHA_FLASH_PATH}"
|
|
||||||
echo "XFORMERS_FMHA_COMMON_PATH = ${XFORMERS_FMHA_COMMON_PATH}"
|
|
||||||
|
|
||||||
if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-${XFORMERS_VERSION}.rocm.patch"; then
|
|
||||||
echo "Applying patch to ${XFORMERS_FMHA_FLASH_PATH}"
|
|
||||||
patch -p0 $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-${XFORMERS_VERSION}.rocm.patch"
|
|
||||||
echo "Successfully patch ${XFORMERS_FMHA_FLASH_PATH}"
|
|
||||||
else
|
|
||||||
echo "${XFORMERS_FMHA_FLASH_PATH} was patched before"
|
|
||||||
fi
|
|
||||||
|
|
||||||
if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-${XFORMERS_VERSION}.rocm.patch"; then
|
|
||||||
echo "Applying patch to ${XFORMERS_FMHA_COMMON_PATH}"
|
|
||||||
patch -p0 $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-${XFORMERS_VERSION}.rocm.patch"
|
|
||||||
echo "Successfully patch ${XFORMERS_FMHA_COMMON_PATH}"
|
|
||||||
else
|
|
||||||
echo "${XFORMERS_FMHA_COMMON_PATH} was patched before"
|
|
||||||
fi
|
|
||||||
@@ -5,7 +5,7 @@ requires = [
|
|||||||
"ninja",
|
"ninja",
|
||||||
"packaging",
|
"packaging",
|
||||||
"setuptools >= 49.4.0",
|
"setuptools >= 49.4.0",
|
||||||
"torch == 2.1.2",
|
"torch == 2.2.1",
|
||||||
"wheel",
|
"wheel",
|
||||||
]
|
]
|
||||||
build-backend = "setuptools.build_meta"
|
build-backend = "setuptools.build_meta"
|
||||||
@@ -13,6 +13,10 @@ build-backend = "setuptools.build_meta"
|
|||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
# Allow lines to be as long as 80.
|
# Allow lines to be as long as 80.
|
||||||
line-length = 80
|
line-length = 80
|
||||||
|
exclude = [
|
||||||
|
# External file, leaving license intact
|
||||||
|
"examples/fp8/quantizer/quantize.py"
|
||||||
|
]
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = [
|
select = [
|
||||||
@@ -42,11 +46,16 @@ ignore = [
|
|||||||
python_version = "3.8"
|
python_version = "3.8"
|
||||||
|
|
||||||
ignore_missing_imports = true
|
ignore_missing_imports = true
|
||||||
|
check_untyped_defs = true
|
||||||
|
follow_imports = "skip"
|
||||||
|
|
||||||
files = "vllm"
|
files = "vllm"
|
||||||
# TODO(woosuk): Include the code from Megatron and HuggingFace.
|
# TODO(woosuk): Include the code from Megatron and HuggingFace.
|
||||||
exclude = "vllm/model_executor/parallel_utils/|vllm/model_executor/models/"
|
exclude = [
|
||||||
|
"vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
|
||||||
|
# Ignore triton kernels in ops.
|
||||||
|
'vllm/attention/ops/.*\.py$'
|
||||||
|
]
|
||||||
|
|
||||||
[tool.codespell]
|
[tool.codespell]
|
||||||
ignore-words-list = "dout, te, indicies"
|
ignore-words-list = "dout, te, indicies"
|
||||||
|
|||||||
@@ -3,5 +3,5 @@ cmake>=3.21
|
|||||||
ninja
|
ninja
|
||||||
packaging
|
packaging
|
||||||
setuptools>=49.4.0
|
setuptools>=49.4.0
|
||||||
torch==2.1.2
|
torch==2.2.1
|
||||||
wheel
|
wheel
|
||||||
|
|||||||
18
requirements-common.txt
Normal file
18
requirements-common.txt
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
cmake >= 3.21
|
||||||
|
ninja # For faster builds.
|
||||||
|
psutil
|
||||||
|
sentencepiece # Required for LLaMA tokenizer.
|
||||||
|
numpy
|
||||||
|
requests
|
||||||
|
py-cpuinfo
|
||||||
|
transformers >= 4.40.0 # Required for StarCoder2 & Llava, Llama 3.
|
||||||
|
tokenizers >= 0.19.1 # Required for Llama 3.
|
||||||
|
fastapi
|
||||||
|
uvicorn[standard]
|
||||||
|
pydantic >= 2.0 # Required for OpenAI server.
|
||||||
|
prometheus_client >= 0.18.0
|
||||||
|
tiktoken == 0.6.0 # Required for DBRX tokenizer
|
||||||
|
lm-format-enforcer == 0.9.8
|
||||||
|
outlines == 0.0.34 # Requires torch >= 2.1.0
|
||||||
|
typing_extensions
|
||||||
|
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
|
||||||
6
requirements-cpu.txt
Normal file
6
requirements-cpu.txt
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
# Common dependencies
|
||||||
|
-r requirements-common.txt
|
||||||
|
|
||||||
|
# Dependencies for x86_64 CPUs
|
||||||
|
torch == 2.2.1+cpu
|
||||||
|
triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.
|
||||||
9
requirements-cuda.txt
Normal file
9
requirements-cuda.txt
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
# Common dependencies
|
||||||
|
-r requirements-common.txt
|
||||||
|
|
||||||
|
# Dependencies for NVIDIA GPUs
|
||||||
|
ray >= 2.9
|
||||||
|
nvidia-ml-py # for pynvml package
|
||||||
|
vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library
|
||||||
|
torch == 2.2.1
|
||||||
|
xformers == 0.0.25 # Requires PyTorch 2.2.1
|
||||||
@@ -7,13 +7,14 @@ codespell==2.2.6
|
|||||||
isort==5.13.2
|
isort==5.13.2
|
||||||
|
|
||||||
# type checking
|
# type checking
|
||||||
mypy==0.991
|
mypy==1.9.0
|
||||||
types-PyYAML
|
types-PyYAML
|
||||||
types-requests
|
types-requests
|
||||||
types-setuptools
|
types-setuptools
|
||||||
|
|
||||||
# testing
|
# testing
|
||||||
pytest
|
pytest
|
||||||
|
tensorizer==2.9.0a0
|
||||||
pytest-forked
|
pytest-forked
|
||||||
pytest-asyncio
|
pytest-asyncio
|
||||||
pytest-rerunfailures
|
pytest-rerunfailures
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user