Compare commits
396 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
82091b864a | ||
|
|
c0c2335ce0 | ||
|
|
90fbf12540 | ||
|
|
49d849b3ab | ||
|
|
27ca23dc00 | ||
|
|
54d3544784 | ||
|
|
703e42ee4b | ||
|
|
29a8d6a554 | ||
|
|
2c08ff23c0 | ||
|
|
bfdcfa6a05 | ||
|
|
9289e577ec | ||
|
|
a6d471c759 | ||
|
|
01a5d18a53 | ||
|
|
929b4f2973 | ||
|
|
3b7178cfa4 | ||
|
|
e46fa5d52e | ||
|
|
a8683102cc | ||
|
|
71bcaf99e2 | ||
|
|
8b430d7dea | ||
|
|
e0ade06d63 | ||
|
|
4bd18ec0c7 | ||
|
|
2410e320b3 | ||
|
|
48a8f4a7fd | ||
|
|
4dd6416faf | ||
|
|
c1c0d00b88 | ||
|
|
d9f726c4d0 | ||
|
|
d6e4a130b0 | ||
|
|
cfc15a1031 | ||
|
|
70f3e8e3a1 | ||
|
|
ef978fe411 | ||
|
|
f7c1234990 | ||
|
|
57f044945f | ||
|
|
4caf7044e0 | ||
|
|
6f32cddf1c | ||
|
|
c530e2cfe3 | ||
|
|
fd5dcc5c81 | ||
|
|
93dc5a2870 | ||
|
|
95529e3253 | ||
|
|
344020c926 | ||
|
|
5574081c49 | ||
|
|
d7f396486e | ||
|
|
8fbd84bf78 | ||
|
|
7d2dcce175 | ||
|
|
dc903e70ac | ||
|
|
a9c8212895 | ||
|
|
c20ecb6a51 | ||
|
|
5253edaacb | ||
|
|
017d9f1515 | ||
|
|
181b27d881 | ||
|
|
63e2a6419d | ||
|
|
264017a2bf | ||
|
|
e433c115bc | ||
|
|
86fd8bb0ac | ||
|
|
ab3a5a8259 | ||
|
|
a61f0521b8 | ||
|
|
537c9755a7 | ||
|
|
786b7f18a5 | ||
|
|
8f36444c4f | ||
|
|
185b2c29e2 | ||
|
|
5f08050d8d | ||
|
|
64da65b322 | ||
|
|
5255d99dc5 | ||
|
|
4f2ad11135 | ||
|
|
d7afab6d3a | ||
|
|
31348dff03 | ||
|
|
25e86b6a61 | ||
|
|
4efbac6d35 | ||
|
|
87069ccf68 | ||
|
|
7e45107f51 | ||
|
|
0c48b37c31 | ||
|
|
7eacffd951 | ||
|
|
2a543d6efe | ||
|
|
317b29de0f | ||
|
|
a463c333dd | ||
|
|
ea356004d4 | ||
|
|
5c976a7e1a | ||
|
|
f964493274 | ||
|
|
a4211a4dc3 | ||
|
|
563836496a | ||
|
|
4ca2c358b1 | ||
|
|
0580aab02f | ||
|
|
3711811b1d | ||
|
|
65b89d16ee | ||
|
|
931746bc6d | ||
|
|
c81dddb45c | ||
|
|
fe6d09ae61 | ||
|
|
ed70c70ea3 | ||
|
|
f0d4e14557 | ||
|
|
2ccee3def6 | ||
|
|
b92adec8e8 | ||
|
|
56f738ae9b | ||
|
|
72d3a30c63 | ||
|
|
c9b45adeeb | ||
|
|
5a6c81b051 | ||
|
|
51cd22ce56 | ||
|
|
5ed704ec8c | ||
|
|
4abf6336ec | ||
|
|
0e163fce18 | ||
|
|
96b6f475dd | ||
|
|
c410f5d020 | ||
|
|
bb8c697ee0 | ||
|
|
b9e96b17de | ||
|
|
923797fea4 | ||
|
|
cd9e60c76c | ||
|
|
93b38bea5d | ||
|
|
d0d93b92b1 | ||
|
|
89efcf1ce5 | ||
|
|
c664b0e683 | ||
|
|
d69ff0cbbb | ||
|
|
1af090b57d | ||
|
|
3dad944485 | ||
|
|
105a40f53a | ||
|
|
bbe9bd9684 | ||
|
|
4f65af0e25 | ||
|
|
d79ced3292 | ||
|
|
ab40644669 | ||
|
|
5d60def02c | ||
|
|
ea8489fce2 | ||
|
|
1b20639a43 | ||
|
|
b72af8f1ed | ||
|
|
9090bf02e7 | ||
|
|
7d648418b8 | ||
|
|
89be30fa7d | ||
|
|
f8ecb84c02 | ||
|
|
5f036d2bcc | ||
|
|
380170038e | ||
|
|
220a47627b | ||
|
|
beb89f68b4 | ||
|
|
390b495ff3 | ||
|
|
3a0e1fc070 | ||
|
|
6b7de1a030 | ||
|
|
5265631d15 | ||
|
|
2832e7b9f9 | ||
|
|
3a7dd7e367 | ||
|
|
223c19224b | ||
|
|
f1f6cc10c7 | ||
|
|
3209b49033 | ||
|
|
1e4277d2d1 | ||
|
|
9b945daaf1 | ||
|
|
9c1352eb57 | ||
|
|
7a0b011dd5 | ||
|
|
63e835cbcc | ||
|
|
94b5edeb53 | ||
|
|
ab7e6006d6 | ||
|
|
18bfcdd05c | ||
|
|
71d63ed72e | ||
|
|
d75c40734a | ||
|
|
5b23c3f26f | ||
|
|
00efdc84ba | ||
|
|
91a61da9b1 | ||
|
|
ef9b636e2d | ||
|
|
2709c0009a | ||
|
|
dd7e8f5f64 | ||
|
|
d2a68364c4 | ||
|
|
7e1081139d | ||
|
|
18473cf498 | ||
|
|
4df417d059 | ||
|
|
5d80a9178b | ||
|
|
8a25d3a71a | ||
|
|
d10f8e1d43 | ||
|
|
14cc317ba4 | ||
|
|
e1957c6ebd | ||
|
|
8cd5a992bf | ||
|
|
947f0b23cc | ||
|
|
f780504d12 | ||
|
|
bfc072addf | ||
|
|
2a18da257c | ||
|
|
6e01e8c1c8 | ||
|
|
9f659bf07f | ||
|
|
35c4bc20d9 | ||
|
|
218dc2ccda | ||
|
|
827cbcd37c | ||
|
|
cb7a1c1cbf | ||
|
|
7878958c0d | ||
|
|
ce036244c9 | ||
|
|
48cf1e413c | ||
|
|
97460585d9 | ||
|
|
f745847ef7 | ||
|
|
6549aef245 | ||
|
|
50376faa7b | ||
|
|
4b61c6b669 | ||
|
|
79d64c4954 | ||
|
|
74cd5abdd1 | ||
|
|
28c3f12104 | ||
|
|
c884819135 | ||
|
|
05921a9a7a | ||
|
|
d0215a58e7 | ||
|
|
937e7b7d7c | ||
|
|
aee8ef661a | ||
|
|
2e0b6e7757 | ||
|
|
941767127c | ||
|
|
74d8d77626 | ||
|
|
fd4ea8ef5c | ||
|
|
1066cbd152 | ||
|
|
6ef00b03a2 | ||
|
|
9140561059 | ||
|
|
77af974b40 | ||
|
|
4934d49274 | ||
|
|
358c328d69 | ||
|
|
4aaafdd289 | ||
|
|
66b108d142 | ||
|
|
e0ff920001 | ||
|
|
face83c7ec | ||
|
|
1db83e31a2 | ||
|
|
a1b9cb2a34 | ||
|
|
3a4fd5ca59 | ||
|
|
c17daa9f89 | ||
|
|
bd29cf3d3a | ||
|
|
31bff69151 | ||
|
|
ba4f826738 | ||
|
|
de60a3fb93 | ||
|
|
21d5daa4ac | ||
|
|
290e015c6c | ||
|
|
1b7c791d60 | ||
|
|
bbe4466fd9 | ||
|
|
08133c4d1a | ||
|
|
76a7983b23 | ||
|
|
8041b7305e | ||
|
|
3ec8c25cd0 | ||
|
|
671af2b1c0 | ||
|
|
6f41f0e377 | ||
|
|
2c9b638065 | ||
|
|
a7347d9a6d | ||
|
|
f8c688d746 | ||
|
|
c9fadda543 | ||
|
|
30fb0956df | ||
|
|
3a765bd5e1 | ||
|
|
26c52a5ea6 | ||
|
|
c3372e87be | ||
|
|
b0a1d667b0 | ||
|
|
e1d5402238 | ||
|
|
3d1cfbfc74 | ||
|
|
37ca558103 | ||
|
|
eed74a558f | ||
|
|
2acd76f346 | ||
|
|
b81a6a6bb3 | ||
|
|
0fbfc4b81b | ||
|
|
c06170cc8e | ||
|
|
614856da25 | ||
|
|
05bdf4eaf3 | ||
|
|
6774bd50b0 | ||
|
|
31c1f3255e | ||
|
|
21d93c140d | ||
|
|
f1c8520146 | ||
|
|
096827c284 | ||
|
|
6565d9e33e | ||
|
|
f375ec8440 | ||
|
|
518369d78c | ||
|
|
30bad5c492 | ||
|
|
3fefe271ec | ||
|
|
6428f1d051 | ||
|
|
7e1b21daac | ||
|
|
cb3f30c600 | ||
|
|
f3e024bece | ||
|
|
31d2ab4aff | ||
|
|
eb17212858 | ||
|
|
4dd4b5c538 | ||
|
|
6120e5aaea | ||
|
|
2eaa81b236 | ||
|
|
81ce2a4b26 | ||
|
|
5dd80d3777 | ||
|
|
beeee69bc9 | ||
|
|
9bf28d0b69 | ||
|
|
c0ce15dfb2 | ||
|
|
b9bcdc7158 | ||
|
|
4ff0203987 | ||
|
|
b5f882cc98 | ||
|
|
2e8fc0d4c3 | ||
|
|
dacaf5a400 | ||
|
|
24cde76a15 | ||
|
|
1aa1361510 | ||
|
|
fe470ae5ad | ||
|
|
3a8c2381f7 | ||
|
|
c85b80c2b6 | ||
|
|
2b981012a6 | ||
|
|
6ccc0bfffb | ||
|
|
c8e7eb1eb3 | ||
|
|
24f60a54f4 | ||
|
|
42c02f5892 | ||
|
|
ebede26ebf | ||
|
|
d940ce497e | ||
|
|
05ff90b692 | ||
|
|
1d9b737e05 | ||
|
|
60dc62dc9e | ||
|
|
0f90effc66 | ||
|
|
464dd985e3 | ||
|
|
c07a442854 | ||
|
|
cd3aa153a4 | ||
|
|
9b294976a2 | ||
|
|
5313c2cb8b | ||
|
|
5f09cbdb63 | ||
|
|
4cefa9b49b | ||
|
|
f86bd6190a | ||
|
|
e5452ddfd6 | ||
|
|
d06980dfa7 | ||
|
|
66785cc05c | ||
|
|
05a38612b0 | ||
|
|
d27f4bae39 | ||
|
|
8d8c2f6ffe | ||
|
|
51d3cb951d | ||
|
|
e74b1736a1 | ||
|
|
f07c1ceaa5 | ||
|
|
63b2206ad0 | ||
|
|
27feead2f8 | ||
|
|
c782195662 | ||
|
|
0f621c2c7d | ||
|
|
a9e4574261 | ||
|
|
0229c386c5 | ||
|
|
a7b3e33078 | ||
|
|
e19a64c7ef | ||
|
|
1cb4ad8de9 | ||
|
|
6ed068a71a | ||
|
|
708e6c18b0 | ||
|
|
b943890484 | ||
|
|
a1125ad4df | ||
|
|
a8b150c595 | ||
|
|
665cbcec4b | ||
|
|
7c600440f7 | ||
|
|
e0c6f556e8 | ||
|
|
de23687d16 | ||
|
|
4cea74c73b | ||
|
|
a921d8be9d | ||
|
|
094f716bf2 | ||
|
|
7d761fe3c1 | ||
|
|
cf35d8f3d7 | ||
|
|
4bb6b67188 | ||
|
|
819b18e7ba | ||
|
|
19849db573 | ||
|
|
3d4ceb292c | ||
|
|
f5a37c6c6c | ||
|
|
32c927b53f | ||
|
|
5ffc0d13a2 | ||
|
|
112627e8b2 | ||
|
|
37c1e3c218 | ||
|
|
06e9ebebd5 | ||
|
|
c5f7740d89 | ||
|
|
be66d9b125 | ||
|
|
e1054247ba | ||
|
|
8d17774f92 | ||
|
|
e946260cf3 | ||
|
|
edb305584b | ||
|
|
bb00f66e19 | ||
|
|
e87557b069 | ||
|
|
dcc543a298 | ||
|
|
0fc280b06c | ||
|
|
20d0699d49 | ||
|
|
686f5e3210 | ||
|
|
415d109527 | ||
|
|
521b35f799 | ||
|
|
cb08cd0d75 | ||
|
|
2a2c135b41 | ||
|
|
65ea2ddf17 | ||
|
|
b514d3c496 | ||
|
|
7076fa1c9f | ||
|
|
660a7fcfa4 | ||
|
|
054072bee5 | ||
|
|
eb825c1e74 | ||
|
|
1b290ace4f | ||
|
|
0d578228ca | ||
|
|
aebfcb262a | ||
|
|
ab9e8488d5 | ||
|
|
fd58b73a40 | ||
|
|
8efe23f150 | ||
|
|
06458a0b42 | ||
|
|
1a2bbc9301 | ||
|
|
e7f579eb97 | ||
|
|
8516999495 | ||
|
|
9f669a9a7c | ||
|
|
555bdcc5a3 | ||
|
|
54ca1ba71d | ||
|
|
9738b84a08 | ||
|
|
1fe0990023 | ||
|
|
7e90a2d117 | ||
|
|
5687d584fe | ||
|
|
cf8849f2d6 | ||
|
|
e575df33b1 | ||
|
|
0ce8647dc5 | ||
|
|
9cabcb7645 | ||
|
|
7b895c5976 | ||
|
|
7013a80170 | ||
|
|
79a30912b8 | ||
|
|
2f3d36a8a1 | ||
|
|
ac8d36f3e5 | ||
|
|
15f5632365 | ||
|
|
aa9af07cac | ||
|
|
69be658bba | ||
|
|
beac8dd461 | ||
|
|
28b47d1e49 | ||
|
|
1f24755bf8 | ||
|
|
bf31d3606a | ||
|
|
d189170b6c | ||
|
|
f61dc8072f | ||
|
|
f8a1e39fae | ||
|
|
a132435204 | ||
|
|
9524867701 | ||
|
|
c1376e0f82 |
69
.buildkite/run-benchmarks.sh
Normal file
@@ -0,0 +1,69 @@
+# This script is run by buildkite to run the benchmarks and upload the results to buildkite
+
+set -ex
+set -o pipefail
+
+# cd into parent directory of this file
+cd "$(dirname "${BASH_SOURCE[0]}")/.."
+
+(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
+
+# run python-based benchmarks and upload the result to buildkite
+python3 benchmarks/benchmark_latency.py 2>&1 | tee benchmark_latency.txt
+bench_latency_exit_code=$?
+
+python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 2>&1 | tee benchmark_throughput.txt
+bench_throughput_exit_code=$?
+
+# run server-based benchmarks and upload the result to buildkite
+python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf &
+server_pid=$!
+wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
+
+# wait for server to start, timeout after 600 seconds
+timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
+python3 benchmarks/benchmark_serving.py \
+    --backend openai \
+    --dataset ./ShareGPT_V3_unfiltered_cleaned_split.json \
+    --model meta-llama/Llama-2-7b-chat-hf \
+    --num-prompts 20 \
+    --endpoint /v1/completions \
+    --tokenizer meta-llama/Llama-2-7b-chat-hf \
+    --save-result \
+    2>&1 | tee benchmark_serving.txt
+bench_serving_exit_code=$?
+kill $server_pid
+
+# write the results into a markdown file
+echo "### Latency Benchmarks" >> benchmark_results.md
+sed -n '1p' benchmark_latency.txt >> benchmark_results.md # first line
+echo "" >> benchmark_results.md
+sed -n '$p' benchmark_latency.txt >> benchmark_results.md # last line
+
+echo "### Throughput Benchmarks" >> benchmark_results.md
+sed -n '1p' benchmark_throughput.txt >> benchmark_results.md # first line
+echo "" >> benchmark_results.md
+sed -n '$p' benchmark_throughput.txt >> benchmark_results.md # last line
+
+echo "### Serving Benchmarks" >> benchmark_results.md
+sed -n '1p' benchmark_serving.txt >> benchmark_results.md # first line
+echo "" >> benchmark_results.md
+tail -n 13 benchmark_serving.txt >> benchmark_results.md # last 13 lines
+
+# upload the results to buildkite
+/workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md
+
+# exit with the exit code of the benchmarks
+if [ $bench_latency_exit_code -ne 0 ]; then
+    exit $bench_latency_exit_code
+fi
+
+if [ $bench_throughput_exit_code -ne 0 ]; then
+    exit $bench_throughput_exit_code
+fi
+
+if [ $bench_serving_exit_code -ne 0 ]; then
+    exit $bench_serving_exit_code
+fi
+
+/workspace/buildkite-agent artifact upload openai-*.json
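The script above runs each benchmark through `tee`, remembers its exit code, and later builds a Markdown summary from the first and last lines of each log. The following is a minimal Python sketch of those two steps, not part of the change set. The helper names `run_and_tee` and `summarize_log` are hypothetical and chosen only for illustration. The sketch reads the exit code from the child process itself rather than from a shell pipeline, so a failing benchmark does not stop the remaining steps; this appears to be what the saved exit-code variables in the script aim for.

```python
import subprocess
import sys


def run_and_tee(cmd, log_path):
    """Run cmd, echo its combined stdout/stderr, save it to log_path,
    and return the exit code without raising on failure."""
    with subprocess.Popen(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
    ) as proc, open(log_path, "w") as log:
        for line in proc.stdout:
            sys.stdout.write(line)  # same role as `tee`: show the line...
            log.write(line)         # ...and keep a copy in the log file
    # Popen's context manager has waited for the child by this point.
    return proc.returncode


def summarize_log(text, tail=1):
    """Return the first line, a blank line, and the last `tail` lines,
    mirroring `sed -n '1p'`, `echo ""`, and `sed -n '$p'` / `tail -n N`."""
    lines = text.splitlines()
    if not lines:
        return ""
    return "\n".join([lines[0], "", *lines[-tail:]])
```

A caller could collect the codes from several `run_and_tee` calls, write the summaries, and only then exit with the first non-zero code, which keeps the annotation step reachable even when a benchmark fails.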
69
.buildkite/test-pipeline.yaml
Normal file
@@ -0,0 +1,69 @@
+# In this file, you can add more tests to run either by adding a new step or
+# adding a new command to an existing step. See different options here for examples.
+# This script will be feed into Jinja template in `test-template.j2` to generate
+# the final pipeline yaml file.
+
+steps:
+- label: Regression Test
+  command: pytest -v -s test_regression.py
+  working_dir: "/vllm-workspace/tests" # optional
+
+- label: AsyncEngine Test
+  command: pytest -v -s async_engine
+
+- label: Basic Correctness Test
+  command: pytest -v -s --forked basic_correctness
+
+- label: Distributed Comm Ops Test
+  command: pytest -v -s --forked test_comm_ops.py
+  working_dir: "/vllm-workspace/tests/distributed"
+  num_gpus: 2 # only support 1 or 2 for now.
+
+- label: Distributed Correctness Test
+  command: pytest -v -s --forked test_basic_distributed_correctness.py
+  working_dir: "/vllm-workspace/tests/distributed"
+  num_gpus: 2 # only support 1 or 2 for now.
+
+- label: Engine Test
+  command: pytest -v -s engine
+
+- label: Entrypoints Test
+  command: pytest -v -s entrypoints
+
+- label: Kernels Test
+  command: pytest -v -s kernels
+  soft_fail: true
+
+- label: Models Test
+  commands:
+  - pytest -v -s models --forked
+  soft_fail: true
+
+- label: Prefix Caching Test
+  commands:
+  - pytest -v -s prefix_caching
+
+- label: Samplers Test
+  command: pytest -v -s samplers --forked
+
+- label: Worker Test
+  command: pytest -v -s worker
+
+- label: LoRA Test
+  command: pytest -v -s lora --forked
+
+- label: Metrics Test
+  command: pytest -v -s metrics
+
+- label: Benchmarks
+  working_dir: "/vllm-workspace/.buildkite"
+  commands:
+  - pip install aiohttp
+  - bash run-benchmarks.sh
+
+- label: Documentation Build
+  working_dir: "/vllm-workspace/docs"
+  no_gpu: True
+  commands:
+  - pip install -r requirements-docs.txt
+  - SPHINXOPTS=\"-W\" make html
56
.buildkite/test-template.j2
Normal file
@@ -0,0 +1,56 @@
+{% set docker_image = "us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:$BUILDKITE_COMMIT" %}
+{% set default_num_gpu = 1 %}
+{% set default_working_dir = "/vllm-workspace/tests" %}
+
+steps:
+  - label: ":docker: build image"
+    commands:
+      - "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."
+      - "docker push {{ docker_image }}"
+    env:
+      DOCKER_BUILDKIT: "1"
+    retry:
+      automatic:
+        - exit_status: -1 # Agent was lost
+          limit: 5
+  - wait
+
+  {% for step in steps %}
+  - label: "{{ step.label }}"
+    agents:
+      queue: kubernetes
+    soft_fail: {{ step.soft_fail or false }}
+    retry:
+      automatic:
+        - exit_status: -1 # Agent was lost
+          limit: 5
+    plugins:
+      - kubernetes:
+          podSpec:
+            volumes:
+              - name: dshm
+                emptyDir:
+                  medium: Memory
+            containers:
+              - image: "{{ docker_image }}"
+                command: ["bash"]
+                args:
+                - '-c'
+                - "'cd {{ (step.working_dir or default_working_dir) | safe }} && {{ step.command or (step.commands | join(' && ')) | safe }}'"
+                {% if not step.no_gpu %}
+                resources:
+                  requests:
+                    nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}"
+                  limits:
+                    nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}"
+                {% endif %}
+                env:
+                  - name: HF_TOKEN
+                    valueFrom:
+                      secretKeyRef:
+                        name: hf-token-secret
+                        key: token
+                volumeMounts:
+                  - mountPath: /dev/shm
+                    name: dshm
+  {% endfor %}
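The template above turns each entry of `test-pipeline.yaml` into one container invocation: it falls back to a default working directory, accepts either a single `command` or a list of `commands` joined with `&&`, and requests GPUs unless `no_gpu` is set. A rough Python sketch of that defaulting logic follows. It does not use Jinja, the function names `container_command` and `gpu_request` are hypothetical, and the surrounding single quotes that the template adds around the command are left out for readability.

```python
# Defaults copied from the `{% set ... %}` lines of the template.
DEFAULT_WORKING_DIR = "/vllm-workspace/tests"
DEFAULT_NUM_GPU = 1


def container_command(step):
    """Build the `cd <dir> && <command>` string for one pipeline step."""
    working_dir = step.get("working_dir") or DEFAULT_WORKING_DIR
    # A step provides either `command` (string) or `commands` (list).
    command = step.get("command") or " && ".join(step.get("commands", []))
    return f"cd {working_dir} && {command}"


def gpu_request(step):
    """Return the GPU count as a string, or None when the step opts out."""
    if step.get("no_gpu"):
        return None
    return str(step.get("num_gpus") or DEFAULT_NUM_GPU)


benchmarks = {
    "label": "Benchmarks",
    "working_dir": "/vllm-workspace/.buildkite",
    "commands": ["pip install aiohttp", "bash run-benchmarks.sh"],
}
print(container_command(benchmarks))
```

With the `Benchmarks` step from the pipeline file, this prints `cd /vllm-workspace/.buildkite && pip install aiohttp && bash run-benchmarks.sh`; a step with neither `num_gpus` nor `no_gpu` would request one GPU.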
1
.dockerignore
Normal file
@@ -0,0 +1 @@
+vllm/*.so
4
.github/workflows/publish.yml
vendored
@@ -49,8 +49,8 @@ jobs:
       matrix:
           os: ['ubuntu-20.04']
           python-version: ['3.8', '3.9', '3.10', '3.11']
-          pytorch-version: ['2.0.1']
-          cuda-version: ['11.8'] # Github runner can't build anything older than 11.8
+          pytorch-version: ['2.1.2'] # Must be the most recent version that meets requirements.txt.
+          cuda-version: ['11.8', '12.1']
 
     steps:
       - name: Checkout
@@ -1,4 +1,4 @@
-name: pylint
+name: ruff
 
 on:
   # Trigger the workflow on push or pull request,
@@ -11,7 +11,7 @@ on:
       - main
 
 jobs:
-  pylint:
+  ruff:
     runs-on: ubuntu-latest
     strategy:
       matrix:
@@ -25,7 +25,10 @@ jobs:
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
-        pip install pylint==2.8.2
+        pip install ruff==0.1.5 codespell==2.2.6 tomli==2.0.1
-    - name: Analysing the code with pylint
+    - name: Analysing the code with ruff
       run: |
-        pylint vllm tests
+        ruff vllm tests
+    - name: Spelling check with codespell
+      run: |
+        codespell --toml pyproject.toml
5
.github/workflows/scripts/build.sh
vendored
@@ -11,5 +11,10 @@ LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH
 $python_executable -m pip install wheel packaging
 $python_executable -m pip install -r requirements.txt
 
+# Limit the number of parallel jobs to avoid OOM
+export MAX_JOBS=1
+# Make sure punica is built for the release (for LoRA)
+export VLLM_INSTALL_PUNICA_KERNELS=1
+
 # Build
 $python_executable setup.py bdist_wheel --dist-dir=dist
5
.github/workflows/scripts/cuda-install.sh
vendored
@@ -16,3 +16,8 @@ sudo apt clean
 # Test nvcc
 PATH=/usr/local/cuda-$1/bin:${PATH}
 nvcc --version
+
+# Log gcc, g++, c++ versions
+gcc --version
+g++ --version
+c++ --version
2
.github/workflows/yapf.yml
vendored
@@ -28,4 +28,4 @@ jobs:
         pip install toml==0.10.2
     - name: Running yapf
       run: |
-        yapf --diff --recursive vllm tests
+        yapf --diff --recursive .
7
.gitignore
vendored
@@ -177,3 +177,10 @@ _build/
 # vim swap files
 *.swo
 *.swp
+
+# hip files generated by PyTorch
+*.hip
+*_hip*
+
+# Benchmark dataset
+*.json
434
.pylintrc
@@ -1,434 +0,0 @@
-# This Pylint rcfile contains a best-effort configuration to uphold the
-# best-practices and style described in the Google Python style guide:
-# https://google.github.io/styleguide/pyguide.html
-#
-# Its canonical open-source location is:
-# https://google.github.io/styleguide/pylintrc
-
-[MASTER]
-
-# Files or directories to be skipped. They should be base names, not paths.
-ignore=docs
-
-# Files or directories matching the regex patterns are skipped. The regex
-# matches against base names, not paths.
-ignore-patterns=
-
-# Pickle collected data for later comparisons.
-persistent=no
-
-# List of plugins (as comma separated values of python modules names) to load,
-# usually to register additional checkers.
-load-plugins=
-
-# Use multiple processes to speed up Pylint.
-jobs=4
-
-# Allow loading of arbitrary C extensions. Extensions are imported into the
-# active Python interpreter and may run arbitrary code.
-unsafe-load-any-extension=no
-
-
-[MESSAGES CONTROL]
-
-# Only show warnings with the listed confidence levels. Leave empty to show
-# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
-confidence=
-
-# Enable the message, report, category or checker with the given id(s). You can
-# either give multiple identifier separated by comma (,) or put this option
-# multiple time (only on the command line, not in the configuration file where
-# it should appear only once). See also the "--disable" option for examples.
-#enable=
-
-# Disable the message, report, category or checker with the given id(s). You
-# can either give multiple identifiers separated by comma (,) or put this
-# option multiple times (only on the command line, not in the configuration
-# file where it should appear only once).You can also use "--disable=all" to
-# disable everything first and then reenable specific checks. For example, if
-# you want to run only the similarities checker, you can use "--disable=all
-# --enable=similarities". If you want to run only the classes checker, but have
-# no Warning level messages displayed, use"--disable=all --enable=classes
-# --disable=W"
-disable=abstract-method,
-        apply-builtin,
-        arguments-differ,
-        attribute-defined-outside-init,
-        backtick,
-        bad-option-value,
-        basestring-builtin,
-        buffer-builtin,
-        c-extension-no-member,
-        consider-using-enumerate,
-        cmp-builtin,
-        cmp-method,
-        coerce-builtin,
-        coerce-method,
-        delslice-method,
-        div-method,
-        duplicate-code,
-        eq-without-hash,
-        execfile-builtin,
-        file-builtin,
-        filter-builtin-not-iterating,
-        fixme,
-        getslice-method,
-        global-statement,
-        hex-method,
-        idiv-method,
-        implicit-str-concat-in-sequence,
-        import-error,
-        import-self,
-        import-star-module-level,
-        inconsistent-return-statements,
-        input-builtin,
-        intern-builtin,
-        invalid-str-codec,
-        locally-disabled,
-        logging-fstring-interpolation, # added by vLLM
-        logging-not-lazy, # added by vLLM
-        long-builtin,
-        long-suffix,
-        map-builtin-not-iterating,
-        misplaced-comparison-constant,
-        missing-class-docstring, # TODO (vLLM): enable
-        missing-function-docstring,
-        missing-module-docstring, # TODO (vLLM): enable
-        metaclass-assignment,
-        next-method-called,
-        next-method-defined,
-        no-absolute-import,
-        no-else-break,
-        no-else-continue,
-        no-else-raise,
-        no-else-return,
-        no-init, # added
-        no-member,
-        no-name-in-module,
-        no-self-use,
-        nonzero-method,
-        oct-method,
-        old-division,
-        old-ne-operator,
-        old-octal-literal,
-        old-raise-syntax,
-        parameter-unpacking,
-        print-statement,
-        raising-string,
-        range-builtin-not-iterating,
-        raw_input-builtin,
-        rdiv-method,
-        reduce-builtin,
-        relative-import,
-        reload-builtin,
-        round-builtin,
-        setslice-method,
-        signature-differs,
-        standarderror-builtin,
-        suppressed-message,
-        sys-max-int,
-        too-few-public-methods,
-        too-many-ancestors,
-        too-many-arguments,
-        too-many-boolean-expressions,
-        too-many-branches,
-        too-many-instance-attributes,
-        too-many-locals,
-        too-many-nested-blocks,
-        too-many-public-methods,
-        too-many-return-statements,
-        too-many-statements,
-        trailing-newlines,
-        unichr-builtin,
-        unicode-builtin,
-        unnecessary-pass,
-        unpacking-in-except,
-        unspecified-encoding,
-        useless-else-on-loop,
-        useless-object-inheritance,
-        useless-suppression,
-        using-cmp-argument,
-        wrong-import-order,
-        xrange-builtin,
-        zip-builtin-not-iterating,
-
-
-[REPORTS]
-
-# Set the output format. Available formats are text, parseable, colorized, msvs
-# (visual studio) and html. You can also give a reporter class, eg
-# mypackage.mymodule.MyReporterClass.
-output-format=text
-
-# Tells whether to display a full report or only the messages
-reports=no
-
-# Python expression which should return a note less than 10 (10 is the highest
-# note). You have access to the variables errors warning, statement which
-# respectively contain the number of errors / warnings messages and the total
-# number of statements analyzed. This is used by the global evaluation report
-# (RP0004).
-evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
-
-# Template used to display messages. This is a python new-style format string
-# used to format the message information. See doc for all details
-#msg-template=
-
-
-[BASIC]
-
-# Good variable names which should always be accepted, separated by a comma
-good-names=main,_
-
-# Bad variable names which should always be refused, separated by a comma
-bad-names=
-
-# Colon-delimited sets of names that determine each other's naming style when
-# the name regexes allow several styles.
-name-group=
-
-# Include a hint for the correct naming format with invalid-name
-include-naming-hint=no
-
-# List of decorators that produce properties, such as abc.abstractproperty. Add
-# to this list to register other decorators that produce valid properties.
-property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl
-
-# Regular expression matching correct function names
-function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
-
-# Regular expression matching correct variable names
-variable-rgx=^[a-z][a-z0-9_]*$
-
-# Regular expression matching correct constant names
-const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
-
-# Regular expression matching correct attribute names
-attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
-
-# Regular expression matching correct argument names
-argument-rgx=^[a-z][a-z0-9_]*$
-
-# Regular expression matching correct class attribute names
-class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
-
-# Regular expression matching correct inline iteration names
-inlinevar-rgx=^[a-z][a-z0-9_]*$
-
-# Regular expression matching correct class names
-class-rgx=^_?[A-Z][a-zA-Z0-9]*$
-
-# Regular expression matching correct module names
-module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$
-
-# Regular expression matching correct method names
-method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
-
-# Regular expression which should only match function or class names that do
-# not require a docstring.
-no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$
-
-# Minimum line length for functions/classes that require docstrings, shorter
-# ones are exempt.
-docstring-min-length=10
-
-
-[TYPECHECK]
-
-# List of decorators that produce context managers, such as
-# contextlib.contextmanager. Add to this list to register other decorators that
-# produce valid context managers.
-contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
-
-# Tells whether missing members accessed in mixin class should be ignored. A
-# mixin class is detected if its name ends with "mixin" (case insensitive).
-ignore-mixin-members=yes
-
-# List of module names for which member attributes should not be checked
-# (useful for modules/projects where namespaces are manipulated during runtime
-# and thus existing member attributes cannot be deduced by static analysis. It
-# supports qualified module names, as well as Unix pattern matching.
-ignored-modules=
-
-# List of class names for which member attributes should not be checked (useful
-# for classes with dynamically set attributes). This supports the use of
-# qualified names.
-ignored-classes=optparse.Values,thread._local,_thread._local
-
-# List of members which are set dynamically and missed by pylint inference
-# system, and so shouldn't trigger E1101 when accessed. Python regular
-# expressions are accepted.
-generated-members=
-
-
-[FORMAT]
-
-# Maximum number of characters on a single line.
-max-line-length=80
-
-# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
-# lines made too long by directives to pytype.
-
-# Regexp for a line that is allowed to be longer than the limit.
-ignore-long-lines=(?x)(
-  ^\s*(\#\ )?<?https?://\S+>?$|
-  ^\s*(from\s+\S+\s+)?import\s+.+$)
-
-# Allow the body of an if to be on the same line as the test if there is no
-# else.
-single-line-if-stmt=yes
-
-# Maximum number of lines in a module
-max-module-lines=99999
-
-# String used as indentation unit. The internal Google style guide mandates 2
-# spaces. Google's externaly-published style guide says 4, consistent with
-# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google
-# projects (like TensorFlow).
-indent-string=' '
-
-# Number of spaces of indent required inside a hanging or continued line.
-indent-after-paren=4
-
-# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
-expected-line-ending-format=
-
-
-[MISCELLANEOUS]
-
-# List of note tags to take in consideration, separated by a comma.
-notes=TODO
-
-
-[STRING]
-
-# This flag controls whether inconsistent-quotes generates a warning when the
-# character used as a quote delimiter is used inconsistently within a module.
-check-quote-consistency=yes
-
-
-[VARIABLES]
-
-# Tells whether we should check for unused import in __init__ files.
-init-import=no
-
-# A regular expression matching the name of dummy variables (i.e. expectedly
-# not used).
-dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
-
-# List of additional names supposed to be defined in builtins. Remember that
-# you should avoid to define new builtins when possible.
-additional-builtins=
-
-# List of strings which can identify a callback function by name. A callback
-# name must start or end with one of those strings.
-callbacks=cb_,_cb
-
-# List of qualified module names which can have objects that can redefine
-# builtins.
-redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools
-
-
-[LOGGING]
-
-# Logging modules to check that the string format arguments are in logging
-# function parameter format
-logging-modules=logging,absl.logging,tensorflow.io.logging
-
-
-[SIMILARITIES]
|
|
||||||
|
|
||||||
# Minimum lines number of a similarity.
|
|
||||||
min-similarity-lines=4
|
|
||||||
|
|
||||||
# Ignore comments when computing similarities.
|
|
||||||
ignore-comments=yes
|
|
||||||
|
|
||||||
# Ignore docstrings when computing similarities.
|
|
||||||
ignore-docstrings=yes
|
|
||||||
|
|
||||||
# Ignore imports when computing similarities.
|
|
||||||
ignore-imports=no
|
|
||||||
|
|
||||||
|
|
||||||
[SPELLING]
|
|
||||||
|
|
||||||
# Spelling dictionary name. Available dictionaries: none. To make it working
|
|
||||||
# install python-enchant package.
|
|
||||||
spelling-dict=
|
|
||||||
|
|
||||||
# List of comma separated words that should not be checked.
|
|
||||||
spelling-ignore-words=
|
|
||||||
|
|
||||||
# A path to a file that contains private dictionary; one word per line.
|
|
||||||
spelling-private-dict-file=
|
|
||||||
|
|
||||||
# Tells whether to store unknown words to indicated private dictionary in
|
|
||||||
# --spelling-private-dict-file option instead of raising a message.
|
|
||||||
spelling-store-unknown-words=no
|
|
||||||
|
|
||||||
|
|
||||||
[IMPORTS]
|
|
||||||
|
|
||||||
# Deprecated modules which should not be used, separated by a comma
|
|
||||||
deprecated-modules=regsub,
|
|
||||||
TERMIOS,
|
|
||||||
Bastion,
|
|
||||||
rexec,
|
|
||||||
sets
|
|
||||||
|
|
||||||
# Create a graph of every (i.e. internal and external) dependencies in the
|
|
||||||
# given file (report RP0402 must not be disabled)
|
|
||||||
import-graph=
|
|
||||||
|
|
||||||
# Create a graph of external dependencies in the given file (report RP0402 must
|
|
||||||
# not be disabled)
|
|
||||||
ext-import-graph=
|
|
||||||
|
|
||||||
# Create a graph of internal dependencies in the given file (report RP0402 must
|
|
||||||
# not be disabled)
|
|
||||||
int-import-graph=
|
|
||||||
|
|
||||||
# Force import order to recognize a module as part of the standard
|
|
||||||
# compatibility libraries.
|
|
||||||
known-standard-library=
|
|
||||||
|
|
||||||
# Force import order to recognize a module as part of a third party library.
|
|
||||||
known-third-party=enchant, absl
|
|
||||||
|
|
||||||
# Analyse import fallback blocks. This can be used to support both Python 2 and
|
|
||||||
# 3 compatible code, which means that the block might have code that exists
|
|
||||||
# only in one or another interpreter, leading to false positives when analysed.
|
|
||||||
analyse-fallback-blocks=no
|
|
||||||
|
|
||||||
|
|
||||||
[CLASSES]
|
|
||||||
|
|
||||||
# List of method names used to declare (i.e. assign) instance attributes.
|
|
||||||
defining-attr-methods=__init__,
|
|
||||||
__new__,
|
|
||||||
setUp
|
|
||||||
|
|
||||||
# List of member names, which should be excluded from the protected access
|
|
||||||
# warning.
|
|
||||||
exclude-protected=_asdict,
|
|
||||||
_fields,
|
|
||||||
_replace,
|
|
||||||
_source,
|
|
||||||
_make
|
|
||||||
|
|
||||||
# List of valid names for the first argument in a class method.
|
|
||||||
valid-classmethod-first-arg=cls,
|
|
||||||
class_
|
|
||||||
|
|
||||||
# List of valid names for the first argument in a metaclass class method.
|
|
||||||
valid-metaclass-classmethod-first-arg=mcs
|
|
||||||
|
|
||||||
|
|
||||||
[EXCEPTIONS]
|
|
||||||
|
|
||||||
# Exceptions that will emit a warning when being caught. Defaults to
|
|
||||||
# "Exception"
|
|
||||||
overgeneral-exceptions=StandardError,
|
|
||||||
Exception,
|
|
||||||
BaseException
|
|
||||||
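The two regular expressions in the removed configuration are easier to read with a few concrete inputs. The sketch below is illustrative only: it copies the `ignore-long-lines` and `dummy-variables-rgx` patterns from the listing above and applies them with Python's `re` module. The helper names `is_exempt_long_line` and `is_dummy_name` are hypothetical, and the choice of `search` for the first pattern and `match` for the second is an assumption about how pylint applies them that should be checked against the pylint version in use.

```python
import re

# Pattern copied from `ignore-long-lines`; (?x) enables verbose mode, so the
# line breaks and indentation inside the pattern are ignored.
IGNORE_LONG_LINES = re.compile(r"""(?x)(
  ^\s*(\#\ )?<?https?://\S+>?$|
  ^\s*(from\s+\S+\s+)?import\s+.+$)""")

# Pattern copied from `dummy-variables-rgx`.
DUMMY_VARIABLES = re.compile(r"^\*{0,2}(_$|unused_|dummy_)")


def is_exempt_long_line(line: str) -> bool:
    """Hypothetical helper: True if the line may exceed max-line-length."""
    return IGNORE_LONG_LINES.search(line) is not None


def is_dummy_name(name: str) -> bool:
    """Hypothetical helper: True if the name counts as an unused placeholder."""
    return DUMMY_VARIABLES.match(name) is not None
```

Under these assumptions, a line holding only a URL (optionally behind `# `) or an import statement is exempt from the length limit, while a URL embedded in an assignment is not.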
105
Dockerfile
Normal file
@@ -0,0 +1,105 @@
+# The vLLM Dockerfile is used to construct vLLM image that can be directly used
+# to run the OpenAI compatible server.
+
+#################### BASE BUILD IMAGE ####################
+FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev
+
+RUN apt-get update -y \
+    && apt-get install -y python3-pip git
+
+# Workaround for https://github.com/openai/triton/issues/2507 and
+# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
+# this won't be needed for future versions of this docker image
+# or future versions of triton.
+RUN ldconfig /usr/local/cuda-12.1/compat/
+
+WORKDIR /workspace
+
+# install build and runtime dependencies
+COPY requirements.txt requirements.txt
+RUN --mount=type=cache,target=/root/.cache/pip \
+    pip install -r requirements.txt
+
+# install development dependencies
+COPY requirements-dev.txt requirements-dev.txt
+RUN --mount=type=cache,target=/root/.cache/pip \
+    pip install -r requirements-dev.txt
+#################### BASE BUILD IMAGE ####################
+
+
+#################### EXTENSION BUILD IMAGE ####################
+FROM dev AS build
+
+# install build dependencies
+COPY requirements-build.txt requirements-build.txt
+RUN --mount=type=cache,target=/root/.cache/pip \
+    pip install -r requirements-build.txt
+
+# copy input files
+COPY csrc csrc
+COPY setup.py setup.py
+COPY requirements.txt requirements.txt
+COPY pyproject.toml pyproject.toml
+COPY vllm/__init__.py vllm/__init__.py
+
+# cuda arch list used by torch
+ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
+ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
+# max jobs used by Ninja to build extensions
+ARG max_jobs=2
+ENV MAX_JOBS=${max_jobs}
+# number of threads used by nvcc
+ARG nvcc_threads=8
+ENV NVCC_THREADS=$nvcc_threads
+# make sure punica kernels are built (for LoRA)
+ENV VLLM_INSTALL_PUNICA_KERNELS=1
+
+RUN python3 setup.py build_ext --inplace
+#################### EXTENSION BUILD IMAGE ####################
+
+
+#################### TEST IMAGE ####################
+# image to run unit testing suite
+FROM dev AS test
+
+# copy pytorch extensions separately to avoid having to rebuild
+# when python code changes
+WORKDIR /vllm-workspace
+# ADD is used to preserve directory structure
+ADD . /vllm-workspace/
+COPY --from=build /workspace/vllm/*.so /vllm-workspace/vllm/
+# ignore build dependencies installation because we are using pre-compiled extensions
+RUN rm pyproject.toml
+RUN --mount=type=cache,target=/root/.cache/pip VLLM_USE_PRECOMPILED=1 pip install . --verbose
+#################### TEST IMAGE ####################
+
+
+#################### RUNTIME BASE IMAGE ####################
+# We used base cuda image because pytorch installs its own cuda libraries.
+# However cupy depends on cuda libraries so we had to switch to the runtime image
+# In the future it would be nice to get a container with pytorch and cuda without duplicating cuda
+FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04 AS vllm-base
+
+# libnccl required for ray
+RUN apt-get update -y \
+    && apt-get install -y python3-pip
+
+WORKDIR /workspace
+COPY requirements.txt requirements.txt
+RUN --mount=type=cache,target=/root/.cache/pip \
+    pip install -r requirements.txt
+#################### RUNTIME BASE IMAGE ####################
+
+
+#################### OPENAI API SERVER ####################
+# openai api server alternative
+FROM vllm-base AS vllm-openai
+# install additional dependencies for openai api server
+RUN --mount=type=cache,target=/root/.cache/pip \
+    pip install accelerate
+
+COPY --from=build /workspace/vllm/*.so /workspace/vllm/
+COPY vllm vllm
+
+ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
+#################### OPENAI API SERVER ####################
95
Dockerfile.rocm
Normal file
@@ -0,0 +1,95 @@
+# default base image
+ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
+
+FROM $BASE_IMAGE
+
+ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
+
+RUN echo "Base image is $BASE_IMAGE"
+
+# BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1"
+# BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
+
+
+ARG FA_GFX_ARCHS="gfx90a;gfx942"
+RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
+
+ARG FA_BRANCH="3d2b6f5"
+RUN echo "FA_BRANCH is $FA_BRANCH"
+
+# whether to build flash-attention
+# if 0, will not build flash attention
+# this is useful for gfx target where flash-attention is not supported
+# In that case, we need to use the python reference attention implementation in vllm
+ARG BUILD_FA="1"
+
+# Install Python and pip
+RUN apt-get update && apt-get install python3 python3-pip -y
+
+# Install some basic utilities
+RUN apt-get update && apt-get install -y \
+    curl \
+    ca-certificates \
+    sudo \
+    git \
+    bzip2 \
+    libx11-6 \
+    build-essential \
+    wget \
+    unzip \
+    nvidia-cuda-toolkit \
+    tmux \
+    && rm -rf /var/lib/apt/lists/*
+
+### Mount Point ###
+# When launching the container, mount the code directory to /app
+ARG APP_MOUNT=/app
+VOLUME [ ${APP_MOUNT} ]
+WORKDIR ${APP_MOUNT}
+
+RUN python3 -m pip install --upgrade pip
+RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
+
+ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
+ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
+ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
+ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
+
+# Install ROCm flash-attention
+RUN if [ "$BUILD_FA" = "1" ]; then \
+    mkdir libs \
+    && cd libs \
+    && git clone https://github.com/ROCm/flash-attention.git \
+    && cd flash-attention \
+    && git checkout ${FA_BRANCH} \
+    && git submodule update --init \
+    && export GPU_ARCHS=${FA_GFX_ARCHS} \
+    && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \
+        patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
+    && python3 setup.py install \
+    && cd ..; \
+    fi
+
+COPY ./ /app/vllm
+
+RUN python3 -m pip install --upgrade pip
+RUN python3 -m pip install xformers==0.0.23 --no-deps
+
+# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
+# Manually removed it so that later steps of numpy upgrade can continue
+RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
+    rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi
+
+RUN cd /app \
+    && cd vllm \
+    && pip install -U -r requirements-rocm.txt \
+    && if [ "$BUILD_FA" = "1" ]; then \
+        bash patch_xformers.rocm.sh; fi \
+    && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch \
+    && python3 setup.py install \
+    && cd ..
+
+RUN python3 -m pip install --upgrade pip
+RUN python3 -m pip install --no-cache-dir ray[all]
+
+CMD ["/bin/bash"]
26
README.md
@@ -10,13 +10,16 @@ Easy, fast, and cheap LLM serving for everyone
 </h3>
 
 <p align="center">
-| <a href="https://vllm.readthedocs.io/en/latest/"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://discord.gg/jz7wjKhh6g"><b>Discord</b></a> |
+| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://discord.gg/jz7wjKhh6g"><b>Discord</b></a> |
 
 </p>
 
 ---
 
 *Latest News* 🔥
+- [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing).
+- [2024/01] Added ROCm 6.0 support to vLLM.
+- [2023/12] Added ROCm 5.7 support to vLLM.
 - [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
 - [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there.
 - [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv!
@@ -26,7 +29,7 @@ Easy, fast, and cheap LLM serving for everyone
 - [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).
 
 ---
-
+## About
 vLLM is a fast and easy-to-use library for LLM inference and serving.
 
 vLLM is fast with:
@@ -34,6 +37,8 @@ vLLM is fast with:
 - State-of-the-art serving throughput
 - Efficient management of attention key and value memory with **PagedAttention**
 - Continuous batching of incoming requests
+- Fast model execution with CUDA/HIP graph
+- Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629), FP8 KV Cache
 - Optimized CUDA kernels
 
 vLLM is flexible and easy to use with:
@@ -43,23 +48,38 @@ vLLM is flexible and easy to use with:
 - Tensor parallelism support for distributed inference
 - Streaming outputs
 - OpenAI-compatible API server
+- Support NVIDIA GPUs and AMD GPUs
+- (Experimental) Prefix caching support
+- (Experimental) Multi-LoRA support
 
 vLLM seamlessly supports many Hugging Face models, including the following architectures:
 
 - Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
-- Baichuan (`baichuan-inc/Baichuan-7B`, `baichuan-inc/Baichuan-13B-Chat`, etc.)
+- Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.)
 - BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
+- ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.)
+- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.)
 - Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
+- Gemma (`google/gemma-2b`, `google/gemma-7b`, etc.)
 - GPT-2 (`gpt2`, `gpt2-xl`, etc.)
 - GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
 - GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
 - GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
 - InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
+- InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.)
 - LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
 - Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
+- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.)
 - MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
+- OLMo (`allenai/OLMo-1B`, `allenai/OLMo-7B`, etc.)
 - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
+- Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.)
+- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
 - Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
+- Qwen2 (`Qwen/Qwen2-7B-beta`, `Qwen/Qwen-7B-Chat-beta`, etc.)
+- StableLM (`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.)
+- Starcoder2 (`bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc.)
+- Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.)
 
 Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
 
284
benchmarks/backend_request_func.py
Normal file
@@ -0,0 +1,284 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from tqdm.asyncio import tqdm
|
||||||
|
|
||||||
|
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RequestFuncInput:
|
||||||
|
prompt: str
|
||||||
|
api_url: str
|
||||||
|
prompt_len: int
|
||||||
|
output_len: int
|
||||||
|
model: str
|
||||||
|
best_of: int = 1
|
||||||
|
use_beam_search: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RequestFuncOutput:
|
||||||
|
generated_text: str = ""
|
||||||
|
success: bool = False
|
||||||
|
latency: float = 0
|
||||||
|
ttft: float = 0
|
||||||
|
prompt_len: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
async def async_request_tgi(
|
||||||
|
request_func_input: RequestFuncInput,
|
||||||
|
pbar: Optional[tqdm] = None,
|
||||||
|
) -> RequestFuncOutput:
|
||||||
|
api_url = request_func_input.api_url
|
||||||
|
assert api_url.endswith("generate_stream")
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||||
|
assert not request_func_input.use_beam_search
|
||||||
|
params = {
|
||||||
|
"best_of": request_func_input.best_of,
|
||||||
|
"max_new_tokens": request_func_input.output_len,
|
||||||
|
"do_sample": True,
|
||||||
|
"temperature": 0.01, # TGI does not accept 0.0 temperature.
|
||||||
|
"top_p": 0.99, # TGI does not accept 1.0 top_p.
|
||||||
|
}
|
||||||
|
payload = {
|
||||||
|
"inputs": request_func_input.prompt,
|
||||||
|
"parameters": params,
|
||||||
|
}
|
||||||
|
output = RequestFuncOutput()
|
||||||
|
output.prompt_len = request_func_input.prompt_len
|
||||||
|
|
||||||
|
ttft = 0
|
||||||
|
st = time.perf_counter()
|
||||||
|
try:
|
||||||
|
async with session.post(url=api_url, json=payload) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
async for data in response.content.iter_any():
|
||||||
|
if ttft == 0:
|
||||||
|
ttft = time.perf_counter() - st
|
||||||
|
output.ttft = ttft
|
||||||
|
output.latency = time.perf_counter() - st
|
||||||
|
|
||||||
|
body = data.decode("utf-8").lstrip("data:")
|
||||||
|
output.generated_text = json.loads(body)["generated_text"]
|
||||||
|
output.success = True
|
||||||
|
else:
|
||||||
|
output.success = False
|
||||||
|
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
|
||||||
|
output.success = False
|
||||||
|
|
||||||
|
if pbar:
|
||||||
|
pbar.update(1)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
async def async_request_vllm(
|
||||||
|
request_func_input: RequestFuncInput,
|
||||||
|
pbar: Optional[tqdm] = None,
|
||||||
|
) -> RequestFuncOutput:
|
||||||
|
api_url = request_func_input.api_url
|
||||||
|
assert api_url.endswith("generate")
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||||
|
payload = {
|
||||||
|
"prompt": request_func_input.prompt,
|
||||||
|
"n": 1,
|
||||||
|
"best_of": request_func_input.best_of,
|
||||||
|
"use_beam_search": request_func_input.use_beam_search,
|
||||||
|
"temperature": 0.0 if request_func_input.use_beam_search else 1.0,
|
||||||
|
"top_p": 1.0,
|
||||||
|
"max_tokens": request_func_input.output_len,
|
||||||
|
"ignore_eos": True,
|
||||||
|
"stream": True,
|
||||||
|
}
|
||||||
|
output = RequestFuncOutput()
|
||||||
|
output.prompt_len = request_func_input.prompt_len
|
||||||
|
|
||||||
|
ttft = 0
|
||||||
|
st = time.perf_counter()
|
||||||
|
try:
|
||||||
|
async with session.post(url=api_url, json=payload) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
async for data in response.content.iter_any():
|
||||||
|
if ttft == 0:
|
||||||
|
ttft = time.perf_counter() - st
|
||||||
|
output.ttft = ttft
|
||||||
|
output.latency = time.perf_counter() - st
|
||||||
|
|
||||||
|
# When streaming, '\0' is appended to the end of the response.
|
||||||
|
body = data.decode("utf-8").strip("\0")
|
||||||
|
output.generated_text = json.loads(
|
||||||
|
body)["text"][0][len(request_func_input.prompt):]
|
||||||
|
output.success = True
|
||||||
|
|
||||||
|
else:
|
||||||
|
output.success = False
|
||||||
|
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
|
||||||
|
output.success = False
|
||||||
|
|
||||||
|
if pbar:
|
||||||
|
pbar.update(1)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
async def async_request_trt_llm(
|
||||||
|
request_func_input: RequestFuncInput,
|
||||||
|
pbar: Optional[tqdm] = None,
|
||||||
|
) -> RequestFuncOutput:
|
||||||
|
api_url = request_func_input.api_url
|
||||||
|
assert api_url.endswith("generate_stream")
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||||
|
assert not request_func_input.use_beam_search
|
||||||
|
assert request_func_input.best_of == 1
|
||||||
|
payload = {
|
||||||
|
"accumulate_tokens": True,
|
||||||
|
"text_input": request_func_input.prompt,
|
||||||
|
"temperature": 0.0,
|
||||||
|
"top_p": 1.0,
|
||||||
|
"max_tokens": request_func_input.output_len,
|
||||||
|
"stream": True,
|
||||||
|
}
|
||||||
|
output = RequestFuncOutput()
|
||||||
|
output.prompt_len = request_func_input.prompt_len
|
||||||
|
ttft = 0
|
||||||
|
|
||||||
|
st = time.perf_counter()
|
||||||
|
try:
|
||||||
|
async with session.post(url=api_url, json=payload) as resp:
|
||||||
|
if resp.status == 200:
|
||||||
|
async for data in resp.content.iter_any():
|
||||||
|
if ttft == 0:
|
||||||
|
ttft = time.perf_counter() - st
|
||||||
|
output.ttft = ttft
|
||||||
|
output.latency = time.perf_counter() - st
|
||||||
|
|
||||||
|
body = data.decode("utf-8").lstrip("data:")
|
||||||
|
output.generated_text = json.loads(body)["text_output"]
|
||||||
|
output.success = True
|
||||||
|
|
||||||
|
else:
|
||||||
|
output.success = False
|
||||||
|
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
|
||||||
|
output.success = False
|
||||||
|
|
||||||
|
if pbar:
|
||||||
|
pbar.update(1)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
async def async_request_deepspeed_mii(
|
||||||
|
request_func_input: RequestFuncInput,
|
||||||
|
pbar: Optional[tqdm] = None,
|
||||||
|
) -> RequestFuncOutput:
|
||||||
|
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||||
|
assert request_func_input.best_of == 1
|
||||||
|
assert not request_func_input.use_beam_search
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"prompts": request_func_input.prompt,
|
||||||
|
"max_new_tokens": request_func_input.output_len,
|
||||||
|
"ignore_eos": True,
|
||||||
|
"do_sample": True,
|
||||||
|
"temperature":
|
||||||
|
0.01, # deepspeed-mii does not accept 0.0 temperature.
|
||||||
|
"top_p": 1.0,
|
||||||
|
}
|
||||||
|
output = RequestFuncOutput()
|
||||||
|
output.prompt_len = request_func_input.prompt_len
|
||||||
|
|
||||||
|
# DeepSpeed-MII doesn't support streaming as of Jan 28 2024, will use 0 as placeholder.
|
||||||
|
# https://github.com/microsoft/DeepSpeed-MII/pull/311
|
||||||
|
output.ttft = 0
|
||||||
|
|
||||||
|
st = time.perf_counter()
|
||||||
|
try:
|
||||||
|
async with session.post(url=request_func_input.api_url,
|
||||||
|
json=payload) as resp:
|
||||||
|
if resp.status == 200:
|
||||||
|
parsed_resp = await resp.json()
|
||||||
|
output.latency = time.perf_counter() - st
|
||||||
|
output.generated_text = parsed_resp[0]["generated_text"]
|
||||||
|
output.success = True
|
||||||
|
else:
|
||||||
|
output.success = False
|
||||||
|
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
|
||||||
|
output.success = False
|
||||||
|
|
||||||
|
if pbar:
|
||||||
|
pbar.update(1)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    assert api_url.endswith("v1/completions")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        payload = {
            "model": request_func_input.model,
            "prompt": request_func_input.prompt,
            "temperature": 0.0,
            "best_of": request_func_input.best_of,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0
        st = time.perf_counter()
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk in response.content:
                        if ttft == 0:
                            ttft = time.perf_counter() - st
                            output.ttft = ttft

                        chunk = chunk.strip()
                        if not chunk:
                            continue

                        chunk = chunk.decode("utf-8").lstrip("data: ")
                        if chunk == "[DONE]":
                            latency = time.perf_counter() - st
                        else:
                            body = json.loads(chunk)
                            generated_text += body["choices"][0]["text"]

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                else:
                    output.success = False
        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
            output.success = False

        if pbar:
            pbar.update(1)
        return output


ASYNC_REQUEST_FUNCS = {
    "tgi": async_request_tgi,
    "vllm": async_request_vllm,
    "deepspeed-mii": async_request_deepspeed_mii,
    "openai": async_request_openai_completions,
    "tensorrt-llm": async_request_trt_llm,
}
|
||||||
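The streaming loop in `async_request_openai_completions` above strips the SSE `data: ` prefix with `str.lstrip`, which removes any leading run of the characters `d`, `a`, `t`, `:` and space rather than the literal prefix. That is harmless for payloads that begin with `{` or `[`, but a prefix-exact parser may make the intent clearer. The sketch below is only an illustration: the helper name `parse_sse_line` is hypothetical and not part of the benchmark code, and it assumes one `data: ...` event per line.

```python
import json
from typing import Optional


def parse_sse_line(raw: bytes) -> Optional[str]:
    """Return the text carried by one SSE line, or None for blank/[DONE] lines.

    Hypothetical helper for illustration; assumes one `data: ...` event per line.
    """
    line = raw.strip().decode("utf-8")
    if not line:
        return None
    prefix = "data: "
    # Slice off the literal prefix instead of using lstrip(), which strips a
    # set of characters and could eat the start of the payload.
    if line.startswith(prefix):
        line = line[len(prefix):]
    if line == "[DONE]":
        return None
    body = json.loads(line)
    return body["choices"][0]["text"]


# lstrip() treats its argument as a character set:
print("data: data".lstrip("data: "))  # prints an empty string
print(parse_sse_line(b'data: {"choices": [{"text": "data"}]}\n'))  # data
```

Slicing after `startswith` also works on Python versions that predate `str.removeprefix`.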
@@ -1,6 +1,8 @@
|
|||||||
"""Benchmark the latency of processing a single batch of requests."""
|
"""Benchmark the latency of processing a single batch of requests."""
|
||||||
import argparse
|
import argparse
|
||||||
import time
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -12,7 +14,6 @@ from vllm import LLM, SamplingParams
|
|||||||
def main(args: argparse.Namespace):
|
def main(args: argparse.Namespace):
|
||||||
print(args)
|
print(args)
|
||||||
|
|
||||||
# Process all the requests in a single batch if possible.
|
|
||||||
# NOTE(woosuk): If the request cannot be processed in a single batch,
|
# NOTE(woosuk): If the request cannot be processed in a single batch,
|
||||||
# the engine will automatically process the request in multiple batches.
|
# the engine will automatically process the request in multiple batches.
|
||||||
llm = LLM(
|
llm = LLM(
|
||||||
@@ -20,10 +21,11 @@ def main(args: argparse.Namespace):
|
|||||||
tokenizer=args.tokenizer,
|
tokenizer=args.tokenizer,
|
||||||
quantization=args.quantization,
|
quantization=args.quantization,
|
||||||
tensor_parallel_size=args.tensor_parallel_size,
|
tensor_parallel_size=args.tensor_parallel_size,
|
||||||
max_num_seqs=args.batch_size,
|
|
||||||
max_num_batched_tokens=args.batch_size * args.input_len,
|
|
||||||
trust_remote_code=args.trust_remote_code,
|
trust_remote_code=args.trust_remote_code,
|
||||||
dtype=args.dtype,
|
dtype=args.dtype,
|
||||||
|
enforce_eager=args.enforce_eager,
|
||||||
|
kv_cache_dtype=args.kv_cache_dtype,
|
||||||
|
device=args.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
sampling_params = SamplingParams(
|
sampling_params = SamplingParams(
|
||||||
@@ -35,30 +37,50 @@ def main(args: argparse.Namespace):
|
|||||||
max_tokens=args.output_len,
|
max_tokens=args.output_len,
|
||||||
)
|
)
|
||||||
print(sampling_params)
|
print(sampling_params)
|
||||||
dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size
|
dummy_prompt_token_ids = np.random.randint(10000,
|
||||||
|
size=(args.batch_size,
|
||||||
|
args.input_len))
|
||||||
|
dummy_prompt_token_ids = dummy_prompt_token_ids.tolist()
|
||||||
|
|
||||||
def run_to_completion(profile: bool = False):
|
def run_to_completion(profile_dir: Optional[str] = None):
|
||||||
if profile:
|
if profile_dir:
|
||||||
torch.cuda.cudart().cudaProfilerStart()
|
with torch.profiler.profile(
|
||||||
start_time = time.perf_counter()
|
activities=[
|
||||||
|
torch.profiler.ProfilerActivity.CPU,
|
||||||
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
|
torch.profiler.ProfilerActivity.CUDA,
|
||||||
sampling_params=sampling_params,
|
],
|
||||||
use_tqdm=False)
|
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||||
|
str(profile_dir))) as p:
|
||||||
end_time = time.perf_counter()
|
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
|
||||||
latency = end_time - start_time
|
sampling_params=sampling_params,
|
||||||
if profile:
|
use_tqdm=False)
|
||||||
torch.cuda.cudart().cudaProfilerStop()
|
print(p.key_averages())
|
||||||
return latency
|
else:
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
use_tqdm=False)
|
||||||
|
end_time = time.perf_counter()
|
||||||
|
latency = end_time - start_time
|
||||||
|
return latency
|
||||||
|
|
||||||
print("Warming up...")
|
print("Warming up...")
|
||||||
run_to_completion(profile=False)
|
run_to_completion(profile_dir=None)
|
||||||
|
|
||||||
|
if args.profile:
|
||||||
|
profile_dir = args.profile_result_dir
|
||||||
|
if not profile_dir:
|
||||||
|
profile_dir = Path(
|
||||||
|
"."
|
||||||
|
) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
|
||||||
|
print(f"Profiling (results will be saved to '{profile_dir}')...")
|
||||||
|
run_to_completion(profile_dir=profile_dir)
|
||||||
|
return
|
||||||
|
|
||||||
# Benchmark.
|
# Benchmark.
|
||||||
latencies = []
|
latencies = []
|
||||||
for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
|
for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
|
||||||
latencies.append(run_to_completion(profile=False))
|
latencies.append(run_to_completion(profile_dir=None))
|
||||||
print(f'Avg latency: {np.mean(latencies)} seconds')
|
print(f'Avg latency: {np.mean(latencies)} seconds')
|
||||||
|
|
||||||
|
|
||||||
@@ -70,7 +92,7 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('--tokenizer', type=str, default=None)
|
parser.add_argument('--tokenizer', type=str, default=None)
|
||||||
parser.add_argument('--quantization',
|
parser.add_argument('--quantization',
|
||||||
'-q',
|
'-q',
|
||||||
choices=['awq', None],
|
choices=['awq', 'gptq', 'squeezellm', None],
|
||||||
default=None)
|
default=None)
|
||||||
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
|
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
|
||||||
parser.add_argument('--input-len', type=int, default=32)
|
parser.add_argument('--input-len', type=int, default=32)
|
||||||
@@ -97,5 +119,31 @@ if __name__ == '__main__':
|
|||||||
'The "auto" option will use FP16 precision '
|
'The "auto" option will use FP16 precision '
|
||||||
'for FP32 and FP16 models, and BF16 precision '
|
'for FP32 and FP16 models, and BF16 precision '
|
||||||
'for BF16 models.')
|
'for BF16 models.')
|
||||||
|
parser.add_argument('--enforce-eager',
|
||||||
|
action='store_true',
|
||||||
|
help='enforce eager mode and disable CUDA graph')
|
||||||
|
parser.add_argument(
|
||||||
|
"--kv-cache-dtype",
|
||||||
|
type=str,
|
||||||
|
choices=['auto', 'fp8_e5m2'],
|
||||||
|
default='auto',
|
||||||
|
help=
|
||||||
|
'Data type for kv cache storage. If "auto", will use model data type.')
|
||||||
|
parser.add_argument(
|
||||||
|
'--profile',
|
||||||
|
action='store_true',
|
||||||
|
help='profile the generation process of a single batch')
|
||||||
|
parser.add_argument(
|
||||||
|
'--profile-result-dir',
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help=('path to save the pytorch profiler output. Can be visualized '
|
||||||
|
'with ui.perfetto.dev or Tensorboard.'))
|
||||||
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
type=str,
|
||||||
|
default="cuda",
|
||||||
|
choices=["cuda"],
|
||||||
|
help='device type for vLLM execution, supporting CUDA only currently.')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
|||||||
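The latency benchmark diff above replaces the all-zero dummy prompt (`[[0] * input_len] * batch_size`) with random token IDs. A small standard-library sketch of the difference follows. It uses `random.randrange` as a stand-in for `np.random.randint`, and the helper name `make_dummy_prompts` and the sizes are assumptions made for illustration only.

```python
import random
from typing import List


def make_dummy_prompts(batch_size: int, input_len: int,
                       vocab_limit: int = 10000) -> List[List[int]]:
    """Build batch_size independent prompts of input_len random token IDs.

    Hypothetical helper; IDs are drawn from [0, vocab_limit).
    """
    return [[random.randrange(vocab_limit) for _ in range(input_len)]
            for _ in range(batch_size)]


random.seed(0)
old_style = [[0] * 4] * 2  # every row is the same list object, all zeros
new_style = make_dummy_prompts(batch_size=2, input_len=4)

print(old_style[0] is old_style[1])  # True: rows alias one list
print(new_style[0] is new_style[1])  # False: rows are independent lists
```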
@@ -7,7 +7,7 @@ On the server side, run one of the following commands:
|
|||||||
--disable-log-requests
|
--disable-log-requests
|
||||||
|
|
||||||
(TGI backend)
|
(TGI backend)
|
||||||
./launch_hf_server.sh <your_model>
|
./launch_tgi_server.sh <your_model> <max_batch_total_tokens>
|
||||||
|
|
||||||
On the client side, run:
|
On the client side, run:
|
||||||
python benchmarks/benchmark_serving.py \
|
python benchmarks/benchmark_serving.py \
|
||||||
@@ -20,15 +20,36 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
from typing import AsyncGenerator, List, Tuple
|
from typing import AsyncGenerator, List, Tuple
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from tqdm.asyncio import tqdm
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
|
|
||||||
# (prompt len, output len, latency)
|
from backend_request_func import (
|
||||||
REQUEST_LATENCY: List[Tuple[int, int, float]] = []
|
ASYNC_REQUEST_FUNCS,
|
||||||
|
RequestFuncInput,
|
||||||
|
RequestFuncOutput,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class BenchmarkMetrics:
    completed: int
    total_input: int
    total_output: int
    request_throughput: float
    input_throughput: float
    output_throughput: float
    mean_ttft_ms: float
    median_ttft_ms: float
    p99_ttft_ms: float
    mean_tpot_ms: float
    median_tpot_ms: float
    p99_tpot_ms: float
|
||||||
|
|
||||||
|
|
||||||
def sample_requests(
|
def sample_requests(
|
||||||
@@ -40,15 +61,15 @@ def sample_requests(
|
|||||||
with open(dataset_path) as f:
|
with open(dataset_path) as f:
|
||||||
dataset = json.load(f)
|
dataset = json.load(f)
|
||||||
# Filter out the conversations with less than 2 turns.
|
# Filter out the conversations with less than 2 turns.
|
||||||
dataset = [
|
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||||
data for data in dataset
|
|
||||||
if len(data["conversations"]) >= 2
|
|
||||||
]
|
|
||||||
# Only keep the first two turns of each conversation.
|
# Only keep the first two turns of each conversation.
|
||||||
dataset = [
|
dataset = [(data["conversations"][0]["value"],
|
||||||
(data["conversations"][0]["value"], data["conversations"][1]["value"])
|
data["conversations"][1]["value"]) for data in dataset]
|
||||||
for data in dataset
|
|
||||||
]
|
# some of these will be filtered out, so sample more than we need
|
||||||
|
sampled_indices = random.sample(range(len(dataset)),
|
||||||
|
int(num_requests * 1.2))
|
||||||
|
dataset = [dataset[i] for i in sampled_indices]
|
||||||
|
|
||||||
# Tokenize the prompts and completions.
|
# Tokenize the prompts and completions.
|
||||||
prompts = [prompt for prompt, _ in dataset]
|
prompts = [prompt for prompt, _ in dataset]
|
||||||
@@ -96,79 +117,125 @@ async def get_request(
|
|||||||
await asyncio.sleep(interval)
|
await asyncio.sleep(interval)
|
||||||
|
|
||||||
|
|
||||||
async def send_request(
|
def calculate_metrics(
|
||||||
backend: str,
|
input_requests: List[Tuple[str, int, int]],
|
||||||
api_url: str,
|
outputs: List[RequestFuncOutput],
|
||||||
prompt: str,
|
dur_s: float,
|
||||||
prompt_len: int,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
output_len: int,
|
) -> BenchmarkMetrics:
|
||||||
best_of: int,
|
total_output = 0
|
||||||
use_beam_search: bool,
|
total_input = 0
|
||||||
) -> None:
|
completed = 0
|
||||||
request_start_time = time.perf_counter()
|
per_token_latencies = []
|
||||||
|
ttfts = []
|
||||||
|
for i in range(len(outputs)):
|
||||||
|
if outputs[i].success:
|
||||||
|
output_len = len(tokenizer.encode(outputs[i].generated_text))
|
||||||
|
total_output += output_len
|
||||||
|
total_input += input_requests[i][1]
|
||||||
|
per_token_latencies.append(outputs[i].latency / output_len)
|
||||||
|
ttfts.append(outputs[i].ttft)
|
||||||
|
completed += 1
|
||||||
|
|
||||||
headers = {"User-Agent": "Benchmark Client"}
|
metrics = BenchmarkMetrics(
|
||||||
if backend == "vllm":
|
completed=completed,
|
||||||
pload = {
|
total_input=total_input,
|
||||||
"prompt": prompt,
|
total_output=total_output,
|
||||||
"n": 1,
|
request_throughput=completed / dur_s,
|
||||||
"best_of": best_of,
|
input_throughput=total_input / dur_s,
|
||||||
"use_beam_search": use_beam_search,
|
output_throughput=total_output / dur_s,
|
||||||
"temperature": 0.0 if use_beam_search else 1.0,
|
mean_ttft_ms=np.mean(ttfts) * 1000,
|
||||||
"top_p": 1.0,
|
median_ttft_ms=np.median(ttfts) * 1000,
|
||||||
"max_tokens": output_len,
|
p99_ttft_ms=np.percentile(ttfts, 99) * 1000,
|
||||||
"ignore_eos": True,
|
mean_tpot_ms=np.mean(per_token_latencies) * 1000,
|
||||||
"stream": False,
|
median_tpot_ms=np.median(per_token_latencies) * 1000,
|
||||||
}
|
p99_tpot_ms=np.percentile(per_token_latencies, 99) * 1000,
|
||||||
elif backend == "tgi":
|
)
|
||||||
assert not use_beam_search
|
|
||||||
params = {
|
|
||||||
"best_of": best_of,
|
|
||||||
"max_new_tokens": output_len,
|
|
||||||
"do_sample": True,
|
|
||||||
}
|
|
||||||
pload = {
|
|
||||||
"inputs": prompt,
|
|
||||||
"parameters": params,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown backend: {backend}")
|
|
||||||
|
|
||||||
timeout = aiohttp.ClientTimeout(total=3 * 3600)
|
return metrics
|
||||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
|
||||||
while True:
|
|
||||||
async with session.post(api_url, headers=headers, json=pload) as response:
|
|
||||||
chunks = []
|
|
||||||
async for chunk, _ in response.content.iter_chunks():
|
|
||||||
chunks.append(chunk)
|
|
||||||
output = b"".join(chunks).decode("utf-8")
|
|
||||||
output = json.loads(output)
|
|
||||||
|
|
||||||
# Re-send the request if it failed.
|
|
||||||
if "error" not in output:
|
|
||||||
break
|
|
||||||
|
|
||||||
request_end_time = time.perf_counter()
|
|
||||||
request_latency = request_end_time - request_start_time
|
|
||||||
REQUEST_LATENCY.append((prompt_len, output_len, request_latency))
|
|
||||||
|
|
||||||
|
|
||||||
async def benchmark(
|
async def benchmark(
|
||||||
backend: str,
|
backend: str,
|
||||||
api_url: str,
|
api_url: str,
|
||||||
|
model_id: str,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
input_requests: List[Tuple[str, int, int]],
|
input_requests: List[Tuple[str, int, int]],
|
||||||
best_of: int,
|
best_of: int,
|
||||||
use_beam_search: bool,
|
use_beam_search: bool,
|
||||||
request_rate: float,
|
request_rate: float,
|
||||||
) -> None:
|
disable_tqdm: bool,
|
||||||
tasks: List[asyncio.Task] = []
|
):
|
||||||
|
if backend in ASYNC_REQUEST_FUNCS:
|
||||||
|
request_func = ASYNC_REQUEST_FUNCS.get(backend)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown backend: {backend}")
|
||||||
|
|
||||||
|
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
|
||||||
|
|
||||||
|
print(f"Traffic request rate: {request_rate}")
|
||||||
|
|
||||||
|
benchmark_start_time = time.perf_counter()
|
||||||
|
tasks = []
|
||||||
async for request in get_request(input_requests, request_rate):
|
async for request in get_request(input_requests, request_rate):
|
||||||
prompt, prompt_len, output_len = request
|
prompt, prompt_len, output_len = request
|
||||||
task = asyncio.create_task(send_request(backend, api_url, prompt,
|
request_func_input = RequestFuncInput(
|
||||||
prompt_len, output_len,
|
model=model_id,
|
||||||
best_of, use_beam_search))
|
prompt=prompt,
|
||||||
tasks.append(task)
|
api_url=api_url,
|
||||||
await asyncio.gather(*tasks)
|
prompt_len=prompt_len,
|
||||||
|
output_len=output_len,
|
||||||
|
best_of=best_of,
|
||||||
|
use_beam_search=use_beam_search,
|
||||||
|
)
|
||||||
|
tasks.append(
|
||||||
|
asyncio.create_task(
|
||||||
|
request_func(request_func_input=request_func_input,
|
||||||
|
pbar=pbar)))
|
||||||
|
outputs = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
if not disable_tqdm:
|
||||||
|
pbar.close()
|
||||||
|
|
||||||
|
benchmark_duration = time.perf_counter() - benchmark_start_time
|
||||||
|
|
||||||
|
metrics = calculate_metrics(
|
||||||
|
input_requests=input_requests,
|
||||||
|
outputs=outputs,
|
||||||
|
dur_s=benchmark_duration,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Successful requests: {metrics.completed}")
|
||||||
|
print(f"Benchmark duration: {benchmark_duration:.2f} s")
|
||||||
|
print(f"Total input tokens: {metrics.total_input}")
|
||||||
|
print(f"Total generated tokens: {metrics.total_output}")
|
||||||
|
print(f"Request throughput: {metrics.request_throughput:.2f} requests/s")
|
||||||
|
print(f"Input token throughput: {metrics.input_throughput:.2f} tokens/s")
|
||||||
|
print(f"Output token throughput: {metrics.output_throughput:.2f} tokens/s")
|
||||||
|
print(f"Mean TTFT: {metrics.mean_ttft_ms:.2f} ms")
|
||||||
|
print(f"Median TTFT: {metrics.median_ttft_ms:.2f} ms")
|
||||||
|
print(f"P99 TTFT: {metrics.p99_ttft_ms:.2f} ms")
|
||||||
|
print(f"Mean TPOT: {metrics.mean_tpot_ms:.2f} ms")
|
||||||
|
print(f"Median TPOT: {metrics.median_tpot_ms:.2f} ms")
|
||||||
|
print(f"P99 TPOT: {metrics.p99_tpot_ms:.2f} ms")
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"duration": benchmark_duration,
|
||||||
|
"completed": metrics.completed,
|
||||||
|
"total_input_tokens": metrics.total_input,
|
||||||
|
"total_output_tokens": metrics.total_output,
|
||||||
|
"request_throughput": metrics.request_throughput,
|
||||||
|
"input_throughput": metrics.input_throughput,
|
||||||
|
"output_throughput": metrics.output_throughput,
|
||||||
|
"mean_ttft_ms": metrics.mean_ttft_ms,
|
||||||
|
"median_ttft_ms": metrics.median_ttft_ms,
|
||||||
|
"p99_ttft_ms": metrics.p99_ttft_ms,
|
||||||
|
"mean_tpot_ms": metrics.mean_tpot_ms,
|
||||||
|
"median_tpot_ms": metrics.median_tpot_ms,
|
||||||
|
"p99_tpot_ms": metrics.p99_tpot_ms
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def main(args: argparse.Namespace):
|
def main(args: argparse.Namespace):
|
||||||
@@ -176,58 +243,145 @@ def main(args: argparse.Namespace):
|
|||||||
random.seed(args.seed)
|
random.seed(args.seed)
|
||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
|
|
||||||
api_url = f"http://{args.host}:{args.port}/generate"
|
backend = args.backend
|
||||||
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
|
model_id = args.model
|
||||||
|
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
|
||||||
|
|
||||||
|
if args.base_url is not None:
|
||||||
|
api_url = f"{args.base_url}{args.endpoint}"
|
||||||
|
else:
|
||||||
|
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
|
||||||
|
|
||||||
|
tokenizer = get_tokenizer(tokenizer_id,
|
||||||
|
trust_remote_code=args.trust_remote_code)
|
||||||
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
|
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
|
||||||
|
|
||||||
benchmark_start_time = time.perf_counter()
|
benchmark_result = asyncio.run(
|
||||||
asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of,
|
benchmark(
|
||||||
args.use_beam_search, args.request_rate))
|
backend=backend,
|
||||||
benchmark_end_time = time.perf_counter()
|
api_url=api_url,
|
||||||
benchmark_time = benchmark_end_time - benchmark_start_time
|
model_id=model_id,
|
||||||
print(f"Total time: {benchmark_time:.2f} s")
|
tokenizer=tokenizer,
|
||||||
print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s")
|
input_requests=input_requests,
|
||||||
|
best_of=args.best_of,
|
||||||
|
use_beam_search=args.use_beam_search,
|
||||||
|
request_rate=args.request_rate,
|
||||||
|
disable_tqdm=args.disable_tqdm,
|
||||||
|
))
|
||||||
|
|
||||||
# Compute the latency statistics.
|
# Save config and results to json
|
||||||
avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY])
|
if args.save_result:
|
||||||
print(f"Average latency: {avg_latency:.2f} s")
|
result_json = {}
|
||||||
avg_per_token_latency = np.mean([
|
|
||||||
latency / (prompt_len + output_len)
|
# Setup
|
||||||
for prompt_len, output_len, latency in REQUEST_LATENCY
|
current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||||
])
|
result_json["date"] = current_dt
|
||||||
print(f"Average latency per token: {avg_per_token_latency:.2f} s")
|
result_json["backend"] = backend
|
||||||
avg_per_output_token_latency = np.mean([
|
result_json["version"] = args.version
|
||||||
latency / output_len
|
result_json["model_id"] = model_id
|
||||||
for _, output_len, latency in REQUEST_LATENCY
|
result_json["tokenizer_id"] = tokenizer_id
|
||||||
])
|
result_json["best_of"] = args.best_of
|
||||||
print("Average latency per output token: "
|
result_json["use_beam_search"] = args.use_beam_search
|
||||||
f"{avg_per_output_token_latency:.2f} s")
|
result_json["num_prompts"] = args.num_prompts
|
||||||
|
|
||||||
|
# Traffic
|
||||||
|
result_json["request_rate"] = (
|
||||||
|
args.request_rate if args.request_rate < float("inf") else "inf")
|
||||||
|
|
||||||
|
# Merge with benchmark result
|
||||||
|
result_json = {**result_json, **benchmark_result}
|
||||||
|
|
||||||
|
# Save to file
|
||||||
|
base_model_id = model_id.split("/")[-1]
|
||||||
|
file_name = f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
|
||||||
|
with open(file_name, "w") as outfile:
|
||||||
|
json.dump(result_json, outfile)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Benchmark the online serving throughput.")
|
description="Benchmark the online serving throughput.")
|
||||||
parser.add_argument("--backend", type=str, default="vllm",
|
parser.add_argument(
|
||||||
choices=["vllm", "tgi"])
|
"--backend",
|
||||||
|
type=str,
|
||||||
|
default="vllm",
|
||||||
|
choices=list(ASYNC_REQUEST_FUNCS.keys()),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--version",
|
||||||
|
type=str,
|
||||||
|
default="N/A",
|
||||||
|
help="Version of the serving backend/engine.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--base-url",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Server or API base url if not using http host and port.",
|
||||||
|
)
|
||||||
parser.add_argument("--host", type=str, default="localhost")
|
parser.add_argument("--host", type=str, default="localhost")
|
||||||
parser.add_argument("--port", type=int, default=8000)
|
parser.add_argument("--port", type=int, default=8000)
|
||||||
parser.add_argument("--dataset", type=str, required=True,
|
parser.add_argument(
|
||||||
|
"--endpoint",
|
||||||
|
type=str,
|
||||||
|
default="/generate",
|
||||||
|
help="API endpoint.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--dataset",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
help="Path to the dataset.")
|
help="Path to the dataset.")
|
||||||
parser.add_argument("--tokenizer", type=str, required=True,
|
parser.add_argument(
|
||||||
help="Name or path of the tokenizer.")
|
"--model",
|
||||||
parser.add_argument("--best-of", type=int, default=1,
|
type=str,
|
||||||
help="Generates `best_of` sequences per prompt and "
|
required=True,
|
||||||
"returns the best one.")
|
help="Name of the model.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokenizer",
|
||||||
|
type=str,
|
||||||
|
help=
|
||||||
|
"Name or path of the tokenizer, if not using the default model tokenizer.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--best-of",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Generates `best_of` sequences per prompt and "
|
||||||
|
"returns the best one.",
|
||||||
|
)
|
||||||
parser.add_argument("--use-beam-search", action="store_true")
|
parser.add_argument("--use-beam-search", action="store_true")
|
||||||
parser.add_argument("--num-prompts", type=int, default=1000,
|
parser.add_argument(
|
||||||
help="Number of prompts to process.")
|
"--num-prompts",
|
||||||
parser.add_argument("--request-rate", type=float, default=float("inf"),
|
type=int,
|
||||||
help="Number of requests per second. If this is inf, "
|
default=1000,
|
||||||
"then all the requests are sent at time 0. "
|
help="Number of prompts to process.",
|
||||||
"Otherwise, we use Poisson process to synthesize "
|
)
|
||||||
"the request arrival times.")
|
parser.add_argument(
|
||||||
|
"--request-rate",
|
||||||
|
type=float,
|
||||||
|
default=float("inf"),
|
||||||
|
help="Number of requests per second. If this is inf, "
|
||||||
|
"then all the requests are sent at time 0. "
|
||||||
|
"Otherwise, we use Poisson process to synthesize "
|
||||||
|
"the request arrival times.",
|
||||||
|
)
|
||||||
parser.add_argument("--seed", type=int, default=0)
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
parser.add_argument('--trust-remote-code', action='store_true',
|
parser.add_argument(
|
||||||
help='trust remote code from huggingface')
|
"--trust-remote-code",
|
||||||
|
action="store_true",
|
||||||
|
help="Trust remote code from huggingface",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--disable-tqdm",
|
||||||
|
action="store_true",
|
||||||
|
help="Specify to disable tqdm progress bar.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save-result",
|
||||||
|
action="store_true",
|
||||||
|
help="Specify to save benchmark results to a json file",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
|||||||
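`calculate_metrics` in the serving benchmark diff above derives time per output token (TPOT) as whole-request latency divided by output length, then reports the mean, median and P99 in milliseconds. A minimal standard-library sketch of that arithmetic follows. The sample numbers and the helper names `percentile` and `summarize_ms` are made up for illustration, and `percentile` is only a stand-in intended to mirror the default linear interpolation of `np.percentile`.

```python
import math
import statistics
from typing import Dict, List


def percentile(values: List[float], q: float) -> float:
    """q-th percentile (0-100) with linear interpolation between ranks."""
    ordered = sorted(values)
    rank = (len(ordered) - 1) * q / 100.0
    lo, hi = math.floor(rank), math.ceil(rank)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (rank - lo)


def summarize_ms(samples_s: List[float]) -> Dict[str, float]:
    """Mean, median and P99 of samples given in seconds, reported in ms."""
    return {
        "mean_ms": statistics.mean(samples_s) * 1000,
        "median_ms": statistics.median(samples_s) * 1000,
        "p99_ms": percentile(samples_s, 99) * 1000,
    }


# Made-up (latency in seconds, output tokens) pairs for successful requests.
requests = [(2.0, 100), (3.0, 100), (1.0, 50)]
per_token_latencies = [latency / tokens for latency, tokens in requests]
print(summarize_ms(per_token_latencies))
```

Because the latency in the numerator includes the time to first token, this figure is an average over the whole request rather than a pure decode-step time.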
@@ -6,18 +6,20 @@ import time
|
|||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
|
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||||
|
PreTrainedTokenizerBase)
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from vllm import LLM, SamplingParams
|
|
||||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
def sample_requests(
|
def sample_requests(
|
||||||
dataset_path: str,
|
dataset_path: str,
|
||||||
num_requests: int,
|
num_requests: int,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
fixed_output_len: Optional[int],
|
||||||
) -> List[Tuple[str, int, int]]:
|
) -> List[Tuple[str, int, int]]:
|
||||||
|
if fixed_output_len is not None and fixed_output_len < 4:
|
||||||
|
raise ValueError("output_len too small")
|
||||||
|
|
||||||
# Load the dataset.
|
# Load the dataset.
|
||||||
with open(dataset_path) as f:
|
with open(dataset_path) as f:
|
||||||
dataset = json.load(f)
|
dataset = json.load(f)
|
||||||
@@ -35,6 +37,8 @@ def sample_requests(
|
|||||||
tokenized_dataset = []
|
tokenized_dataset = []
|
||||||
for i in range(len(dataset)):
|
for i in range(len(dataset)):
|
||||||
output_len = len(completion_token_ids[i])
|
output_len = len(completion_token_ids[i])
|
||||||
|
if fixed_output_len is not None:
|
||||||
|
output_len = fixed_output_len
|
||||||
tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
|
tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
|
||||||
|
|
||||||
# Filter out too long sequences.
|
# Filter out too long sequences.
|
||||||
@@ -65,7 +69,12 @@ def run_vllm(
|
|||||||
use_beam_search: bool,
|
use_beam_search: bool,
|
||||||
trust_remote_code: bool,
|
trust_remote_code: bool,
|
||||||
dtype: str,
|
dtype: str,
|
||||||
|
max_model_len: Optional[int],
|
||||||
|
enforce_eager: bool,
|
||||||
|
kv_cache_dtype: str,
|
||||||
|
device: str,
|
||||||
) -> float:
|
) -> float:
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
llm = LLM(
|
llm = LLM(
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
@@ -74,6 +83,10 @@ def run_vllm(
|
|||||||
seed=seed,
|
seed=seed,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
max_model_len=max_model_len,
|
||||||
|
enforce_eager=enforce_eager,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add the requests to the engine.
|
# Add the requests to the engine.
|
||||||
@@ -94,7 +107,7 @@ def run_vllm(
|
|||||||
)
|
)
|
||||||
|
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
# FIXME(woosuk): Do use internal method.
|
# FIXME(woosuk): Do not use internal method.
|
||||||
llm._run_engine(use_tqdm=True)
|
llm._run_engine(use_tqdm=True)
|
||||||
end = time.perf_counter()
|
end = time.perf_counter()
|
||||||
return end - start
|
return end - start
|
||||||
@@ -160,25 +173,53 @@ def run_hf(
|
|||||||
return end - start
|
return end - start
|
||||||
|
|
||||||
|
|
||||||
|
def run_mii(
    requests: List[Tuple[str, int, int]],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
    from mii import pipeline
    llm = pipeline(model, tensor_parallel=tensor_parallel_size)
    prompts = [prompt for prompt, _, _ in requests]

    start = time.perf_counter()
    llm(prompts, max_new_tokens=output_len)
    end = time.perf_counter()
    return end - start
|
||||||
|
|
||||||
|
|
||||||
def main(args: argparse.Namespace):
|
def main(args: argparse.Namespace):
|
||||||
print(args)
|
print(args)
|
||||||
random.seed(args.seed)
|
random.seed(args.seed)
|
||||||
|
|
||||||
# Sample the requests.
|
# Sample the requests.
|
||||||
tokenizer = get_tokenizer(args.tokenizer,
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
trust_remote_code=args.trust_remote_code)
|
args.tokenizer, trust_remote_code=args.trust_remote_code)
|
||||||
requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
|
if args.dataset is None:
|
||||||
|
# Synthesize a prompt with the given input length.
|
||||||
|
prompt = "hi" * (args.input_len - 1)
|
||||||
|
requests = [(prompt, args.input_len, args.output_len)
|
||||||
|
for _ in range(args.num_prompts)]
|
||||||
|
else:
|
||||||
|
requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
|
||||||
|
args.output_len)
|
||||||
|
|
||||||
if args.backend == "vllm":
|
if args.backend == "vllm":
|
||||||
elapsed_time = run_vllm(requests, args.model, args.tokenizer,
|
elapsed_time = run_vllm(requests, args.model, args.tokenizer,
|
||||||
args.quantization, args.tensor_parallel_size,
|
args.quantization, args.tensor_parallel_size,
|
||||||
args.seed, args.n, args.use_beam_search,
|
args.seed, args.n, args.use_beam_search,
|
||||||
args.trust_remote_code, args.dtype)
|
args.trust_remote_code, args.dtype,
|
||||||
|
args.max_model_len, args.enforce_eager,
|
||||||
|
args.kv_cache_dtype, args.device)
|
||||||
elif args.backend == "hf":
|
elif args.backend == "hf":
|
||||||
assert args.tensor_parallel_size == 1
|
assert args.tensor_parallel_size == 1
|
||||||
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
||||||
args.use_beam_search, args.hf_max_batch_size,
|
args.use_beam_search, args.hf_max_batch_size,
|
||||||
args.trust_remote_code)
|
args.trust_remote_code)
|
||||||
|
elif args.backend == "mii":
|
||||||
|
elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
|
||||||
|
args.output_len)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown backend: {args.backend}")
|
raise ValueError(f"Unknown backend: {args.backend}")
|
||||||
total_num_tokens = sum(prompt_len + output_len
|
total_num_tokens = sum(prompt_len + output_len
|
||||||
@@ -191,17 +232,26 @@ if __name__ == "__main__":
|
|||||||
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
|
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
|
||||||
parser.add_argument("--backend",
|
parser.add_argument("--backend",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["vllm", "hf"],
|
choices=["vllm", "hf", "mii"],
|
||||||
default="vllm")
|
default="vllm")
|
||||||
parser.add_argument("--dataset",
|
parser.add_argument("--dataset",
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
default=None,
|
||||||
help="Path to the dataset.")
|
help="Path to the dataset.")
|
||||||
|
parser.add_argument("--input-len",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Input prompt length for each request")
|
||||||
|
parser.add_argument("--output-len",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Output length for each request. Overrides the "
|
||||||
|
"output length from the dataset.")
|
||||||
parser.add_argument("--model", type=str, default="facebook/opt-125m")
|
parser.add_argument("--model", type=str, default="facebook/opt-125m")
|
||||||
parser.add_argument("--tokenizer", type=str, default=None)
|
parser.add_argument("--tokenizer", type=str, default=None)
|
||||||
parser.add_argument('--quantization',
|
parser.add_argument('--quantization',
|
||||||
'-q',
|
'-q',
|
||||||
choices=['awq', None],
|
choices=['awq', 'gptq', 'squeezellm', None],
|
||||||
default=None)
|
default=None)
|
||||||
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
|
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
|
||||||
parser.add_argument("--n",
|
parser.add_argument("--n",
|
||||||
@@ -221,6 +271,12 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument('--trust-remote-code',
|
parser.add_argument('--trust-remote-code',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help='trust remote code from huggingface')
|
help='trust remote code from huggingface')
|
||||||
|
parser.add_argument(
|
||||||
|
'--max-model-len',
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help='Maximum length of a sequence (including prompt and output). '
|
||||||
|
'If None, will be derived from the model.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--dtype',
|
'--dtype',
|
||||||
type=str,
|
type=str,
|
||||||
@@ -230,7 +286,30 @@ if __name__ == "__main__":
|
|||||||
'The "auto" option will use FP16 precision '
|
'The "auto" option will use FP16 precision '
|
||||||
'for FP32 and FP16 models, and BF16 precision '
|
'for FP32 and FP16 models, and BF16 precision '
|
||||||
'for BF16 models.')
|
'for BF16 models.')
|
||||||
|
parser.add_argument("--enforce-eager",
|
||||||
|
action="store_true",
|
||||||
|
help="enforce eager execution")
|
||||||
|
parser.add_argument(
|
||||||
|
"--kv-cache-dtype",
|
||||||
|
type=str,
|
||||||
|
choices=["auto", "fp8_e5m2"],
|
||||||
|
default="auto",
|
||||||
|
help=
|
||||||
|
'Data type for kv cache storage. If "auto", will use model data type.')
|
||||||
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
type=str,
|
||||||
|
default="cuda",
|
||||||
|
choices=["cuda"],
|
||||||
|
help='device type for vLLM execution, supporting CUDA only currently.')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
if args.tokenizer is None:
|
||||||
|
args.tokenizer = args.model
|
||||||
|
if args.dataset is None:
|
||||||
|
assert args.input_len is not None
|
||||||
|
assert args.output_len is not None
|
||||||
|
else:
|
||||||
|
assert args.input_len is None
|
||||||
|
|
||||||
if args.backend == "vllm":
|
if args.backend == "vllm":
|
||||||
if args.hf_max_batch_size is not None:
|
if args.hf_max_batch_size is not None:
|
||||||
@@ -240,7 +319,18 @@ if __name__ == "__main__":
|
|||||||
raise ValueError("HF max batch size is required for HF backend.")
|
raise ValueError("HF max batch size is required for HF backend.")
|
||||||
if args.quantization is not None:
|
if args.quantization is not None:
|
||||||
raise ValueError("Quantization is only for vLLM backend.")
|
raise ValueError("Quantization is only for vLLM backend.")
|
||||||
if args.tokenizer is None:
|
elif args.backend == "mii":
|
||||||
args.tokenizer = args.model
|
if args.dtype != "auto":
|
||||||
|
raise ValueError("dtype must be auto for MII backend.")
|
||||||
|
if args.n != 1:
|
||||||
|
raise ValueError("n must be 1 for MII backend.")
|
||||||
|
if args.use_beam_search:
|
||||||
|
raise ValueError("Beam search is not supported for MII backend.")
|
||||||
|
if args.quantization is not None:
|
||||||
|
raise ValueError("Quantization is only for vLLM backend.")
|
||||||
|
if args.hf_max_batch_size is not None:
|
||||||
|
raise ValueError("HF max batch size is only for HF backend.")
|
||||||
|
if args.tokenizer != args.model:
|
||||||
|
raise ValueError("Tokenizer must be the same as the model for MII "
|
||||||
|
"backend.")
|
||||||
main(args)
|
main(args)
|
||||||
|
|||||||
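Note on the `--dataset` / `--input-len` / `--output-len` checks in the hunk above: the diff enforces the valid combinations with bare `assert` statements. The following is a minimal sketch of the same rule written as an explicit check. The function name `check_length_args` and its return labels are hypothetical and do not appear in the diff.

```python
from typing import Optional


def check_length_args(dataset: Optional[str], input_len: Optional[int],
                      output_len: Optional[int]) -> str:
    """Return "synthetic" or "dataset"; raise ValueError on a bad combination.

    Hypothetical helper mirroring the asserts in the diff:
    - without a dataset, both lengths must be given;
    - with a dataset, an input length must not be given
      (an output length is still allowed and overrides the dataset's).
    """
    if dataset is None:
        if input_len is None or output_len is None:
            raise ValueError(
                "--input-len and --output-len are required without --dataset")
        return "synthetic"
    if input_len is not None:
        raise ValueError("--input-len cannot be combined with --dataset")
    return "dataset"
```

Raising `ValueError` instead of asserting is a design choice of this sketch: the checks then survive `python -O`, which strips `assert` statements.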
172
benchmarks/kernels/benchmark_mixtral_moe.py
Normal file
@@ -0,0 +1,172 @@
import json
import os
import sys

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from vllm.model_executor.layers.fused_moe import fused_moe
import torch
import torch.nn.functional as F
import triton


def main():
    method = fused_moe
    for bs in [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
            2048, 3072, 4096
    ]:
        run_grid(bs, method=method)


def run_grid(bs, method):
    d_model = 4096
    num_total_experts = 8
    top_k = 2
    tp_size = 2
    model_intermediate_size = 14336
    num_layers = 32
    num_calls = 100

    num_warmup_trials = 1
    num_trials = 1

    configs = []
    if bs <= 16:
        BLOCK_SIZES_M = [16]
    elif bs <= 32:
        BLOCK_SIZES_M = [16, 32]
    elif bs <= 64:
        BLOCK_SIZES_M = [16, 32, 64]
    elif bs <= 128:
        BLOCK_SIZES_M = [16, 32, 64, 128]
    else:
        BLOCK_SIZES_M = [16, 32, 64, 128, 256]

    for block_size_n in [32, 64, 128, 256]:
        for block_size_m in BLOCK_SIZES_M:
            for block_size_k in [64, 128, 256]:
                for group_size_m in [1, 16, 32, 64]:
                    for num_warps in [4, 8]:
                        configs.append({
                            "BLOCK_SIZE_M": block_size_m,
                            "BLOCK_SIZE_N": block_size_n,
                            "BLOCK_SIZE_K": block_size_k,
                            "GROUP_SIZE_M": group_size_m,
                            "num_warps": num_warps,
                            "num_stages": 4,
                        })

    best_config = None
    best_time_us = 1e20

    for config in configs:
        print(f'{tp_size=} {bs=}')
        print(f'{config}')
        # warmup
        print(f'warming up')
        try:
            for _ in range(num_warmup_trials):
                run_timing(
                    num_calls=num_calls,
                    bs=bs,
                    d_model=d_model,
                    num_total_experts=num_total_experts,
                    top_k=top_k,
                    tp_size=tp_size,
                    model_intermediate_size=model_intermediate_size,
                    method=method,
                    config=config,
                )
        except triton.runtime.autotuner.OutOfResources:
            continue

        # trial
        print(f'benchmarking')
        for _ in range(num_trials):
            kernel_dur_ms = run_timing(
                num_calls=num_calls,
                bs=bs,
                d_model=d_model,
                num_total_experts=num_total_experts,
                top_k=top_k,
                tp_size=tp_size,
                model_intermediate_size=model_intermediate_size,
                method=method,
                config=config,
            )

            kernel_dur_us = 1000 * kernel_dur_ms
            model_dur_ms = kernel_dur_ms * num_layers

            if kernel_dur_us < best_time_us:
                best_config = config
                best_time_us = kernel_dur_us

            print(
                f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f} {bs=} {tp_size=} {top_k=} {num_total_experts=} {d_model=} {model_intermediate_size=} {num_layers=}'
            )

    print("best_time_us", best_time_us)
    print("best_config", best_config)

    filename = "/tmp/config.jsonl"
    print(f"writing config to file {filename}")
    with open(filename, "a") as f:
        f.write(json.dumps({str(bs): best_config}) + "\n")


def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
               top_k: int, tp_size: int, model_intermediate_size: int, method,
               config) -> float:
    shard_intermediate_size = model_intermediate_size // tp_size

    hidden_states = torch.rand(
        (bs, d_model),
        device="cuda:0",
        dtype=torch.bfloat16,
    )

    ws = torch.rand(
        (num_total_experts, 2 * shard_intermediate_size, d_model),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    w2s = torch.rand(
        (num_total_experts, d_model, shard_intermediate_size),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    gating_output = F.softmax(torch.rand(
        (num_calls, bs, num_total_experts),
        device=hidden_states.device,
        dtype=torch.float32,
    ),
                              dim=-1)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for i in range(num_calls):
        hidden_states = method(
            hidden_states=hidden_states,
            w1=ws,
            w2=w2s,
            gating_output=gating_output[i],
            topk=2,
            renormalize=True,
            inplace=True,
            override_config=config,
        )
    end_event.record()
    end_event.synchronize()

    dur_ms = start_event.elapsed_time(end_event) / num_calls
    return dur_ms


if __name__ == "__main__":
    sys.exit(main())
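Note on the search space in `run_grid`: the five nested loops enumerate a Cartesian product of tile sizes, so the number of kernel configurations timed per batch size can be worked out ahead of a run. The sketch below rebuilds the same candidate list with `itertools.product`. The helper names `block_sizes_m` and `build_configs` are hypothetical and not part of the benchmark; the value lists are copied from the file above.

```python
import itertools


def block_sizes_m(bs: int) -> list:
    """Candidate BLOCK_SIZE_M values for a batch size (same rule as the if/elif chain)."""
    candidates = [16, 32, 64, 128, 256]
    for count, limit in enumerate([16, 32, 64, 128], start=1):
        if bs <= limit:
            return candidates[:count]
    return candidates


def build_configs(bs: int) -> list:
    """Enumerate configs in the same order as the nested loops (N outermost, warps innermost)."""
    grid = itertools.product(
        [32, 64, 128, 256],   # BLOCK_SIZE_N
        block_sizes_m(bs),    # BLOCK_SIZE_M
        [64, 128, 256],       # BLOCK_SIZE_K
        [1, 16, 32, 64],      # GROUP_SIZE_M
        [4, 8],               # num_warps
    )
    return [{
        "BLOCK_SIZE_M": m,
        "BLOCK_SIZE_N": n,
        "BLOCK_SIZE_K": k,
        "GROUP_SIZE_M": g,
        "num_warps": w,
        "num_stages": 4,
    } for n, m, k, g, w in grid]


if __name__ == "__main__":
    for bs in (1, 24, 4096):
        print(bs, len(build_configs(bs)))
```

For small batches the grid has 4 x 1 x 3 x 4 x 2 = 96 entries, and for batches above 128 it has 4 x 5 x 3 x 4 x 2 = 480, each of which is run `num_calls` times for warmup and again for timing unless the warmup raises `OutOfResources`.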
@@ -1,10 +1,12 @@
|
|||||||
|
from typing import Optional
|
||||||
import argparse
|
import argparse
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm import attention_ops
|
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
|
||||||
|
from vllm._C import ops
|
||||||
|
|
||||||
NUM_BLOCKS = 1024
|
NUM_BLOCKS = 1024
|
||||||
PARTITION_SIZE = 512
|
PARTITION_SIZE = 512
|
||||||
@@ -23,33 +25,32 @@ def main(
|
|||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
seed: int,
|
seed: int,
|
||||||
do_profile: bool,
|
do_profile: bool,
|
||||||
|
device: str = "cuda",
|
||||||
|
kv_cache_dtype: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
torch.random.manual_seed(seed)
|
torch.random.manual_seed(seed)
|
||||||
torch.cuda.manual_seed(seed)
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
|
||||||
scale = float(1.0 / (head_size**0.5))
|
scale = float(1.0 / (head_size**0.5))
|
||||||
query = torch.empty(num_seqs,
|
query = torch.empty(num_seqs,
|
||||||
num_query_heads,
|
num_query_heads,
|
||||||
head_size,
|
head_size,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device="cuda")
|
device=device)
|
||||||
query.uniform_(-scale, scale)
|
query.uniform_(-scale, scale)
|
||||||
|
|
||||||
assert num_query_heads % num_kv_heads == 0
|
assert num_query_heads % num_kv_heads == 0
|
||||||
num_queries_per_kv = num_query_heads // num_kv_heads
|
|
||||||
head_mapping = torch.repeat_interleave(
|
|
||||||
torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
|
|
||||||
num_queries_per_kv)
|
|
||||||
alibi_slopes = None
|
alibi_slopes = None
|
||||||
if use_alibi:
|
if use_alibi:
|
||||||
alibi_slopes = torch.randn(num_query_heads,
|
alibi_slopes = torch.randn(num_query_heads,
|
||||||
dtype=torch.float,
|
dtype=torch.float,
|
||||||
device="cuda")
|
device=device)
|
||||||
|
|
||||||
context_lens = [context_len for _ in range(num_seqs)]
|
context_lens = [context_len for _ in range(num_seqs)]
|
||||||
max_context_len = max(context_lens)
|
max_context_len = max(context_lens)
|
||||||
context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
|
context_lens = torch.tensor(context_lens, dtype=torch.int, device=device)
|
||||||
|
|
||||||
# Create the block tables.
|
# Create the block tables.
|
||||||
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
|
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
|
||||||
@@ -60,18 +61,18 @@ def main(
|
|||||||
for _ in range(max_num_blocks_per_seq)
|
for _ in range(max_num_blocks_per_seq)
|
||||||
]
|
]
|
||||||
block_tables.append(block_table)
|
block_tables.append(block_table)
|
||||||
block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
|
block_tables = torch.tensor(block_tables, dtype=torch.int, device=device)
|
||||||
|
|
||||||
# Create the KV cache.
|
# Create the KV cache.
|
||||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS,
|
||||||
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
|
block_size,
|
||||||
key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda")
|
1,
|
||||||
key_cache.uniform_(-scale, scale)
|
num_kv_heads,
|
||||||
value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size)
|
head_size,
|
||||||
value_cache = torch.empty(size=value_cache_shape,
|
kv_cache_dtype,
|
||||||
dtype=dtype,
|
dtype,
|
||||||
device="cuda")
|
device=device)
|
||||||
value_cache.uniform_(-scale, scale)
|
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||||
|
|
||||||
# Prepare for the paged attention kernel.
|
# Prepare for the paged attention kernel.
|
||||||
output = torch.empty_like(query)
|
output = torch.empty_like(query)
|
||||||
@@ -90,7 +91,7 @@ def main(
|
|||||||
)
|
)
|
||||||
max_logits = torch.empty_like(exp_sums)
|
max_logits = torch.empty_like(exp_sums)
|
||||||
|
|
||||||
def run_benchmark(num_iters: int, profile: bool = False) -> float:
|
def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
if profile:
|
if profile:
|
||||||
torch.cuda.cudart().cudaProfilerStart()
|
torch.cuda.cudart().cudaProfilerStart()
|
||||||
@@ -98,21 +99,22 @@ def main(
|
|||||||
|
|
||||||
for _ in range(num_iters):
|
for _ in range(num_iters):
|
||||||
if version == "v1":
|
if version == "v1":
|
||||||
attention_ops.paged_attention_v1(
|
ops.paged_attention_v1(
|
||||||
output,
|
output,
|
||||||
query,
|
query,
|
||||||
key_cache,
|
key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
head_mapping,
|
num_kv_heads,
|
||||||
scale,
|
scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
context_lens,
|
context_lens,
|
||||||
block_size,
|
block_size,
|
||||||
max_context_len,
|
max_context_len,
|
||||||
alibi_slopes,
|
alibi_slopes,
|
||||||
|
kv_cache_dtype,
|
||||||
)
|
)
|
||||||
elif version == "v2":
|
elif version == "v2":
|
||||||
attention_ops.paged_attention_v2(
|
ops.paged_attention_v2(
|
||||||
output,
|
output,
|
||||||
exp_sums,
|
exp_sums,
|
||||||
max_logits,
|
max_logits,
|
||||||
@@ -120,13 +122,14 @@ def main(
|
|||||||
query,
|
query,
|
||||||
key_cache,
|
key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
head_mapping,
|
num_kv_heads,
|
||||||
scale,
|
scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
context_lens,
|
context_lens,
|
||||||
block_size,
|
block_size,
|
||||||
max_context_len,
|
max_context_len,
|
||||||
alibi_slopes,
|
alibi_slopes,
|
||||||
|
kv_cache_dtype,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid version: {version}")
|
raise ValueError(f"Invalid version: {version}")
|
||||||
@@ -139,6 +142,7 @@ def main(
|
|||||||
|
|
||||||
# Warmup.
|
# Warmup.
|
||||||
print("Warming up...")
|
print("Warming up...")
|
||||||
|
run_benchmark = run_cuda_benchmark
|
||||||
run_benchmark(num_iters=3, profile=False)
|
run_benchmark(num_iters=3, profile=False)
|
||||||
|
|
||||||
# Benchmark.
|
# Benchmark.
|
||||||
@@ -172,16 +176,19 @@ if __name__ == '__main__':
|
|||||||
default="half")
|
default="half")
|
||||||
parser.add_argument("--seed", type=int, default=0)
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
parser.add_argument("--profile", action="store_true")
|
parser.add_argument("--profile", action="store_true")
|
||||||
|
parser.add_argument(
|
||||||
|
"--kv-cache-dtype",
|
||||||
|
type=str,
|
||||||
|
choices=["auto", "fp8_e5m2"],
|
||||||
|
default="auto",
|
||||||
|
help=
|
||||||
|
'Data type for kv cache storage. If "auto", will use model data type.')
|
||||||
|
parser.add_argument("--device", type=str, choices=["cuda"], default="cuda")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
print(args)
|
print(args)
|
||||||
|
|
||||||
if args.num_query_heads % args.num_kv_heads != 0:
|
if args.num_query_heads % args.num_kv_heads != 0:
|
||||||
raise ValueError("num_query_heads must be divisible by num_kv_heads")
|
raise ValueError("num_query_heads must be divisible by num_kv_heads")
|
||||||
dtype_to_torch_dtype = {
|
|
||||||
"half": torch.half,
|
|
||||||
"bfloat16": torch.bfloat16,
|
|
||||||
"float": torch.float,
|
|
||||||
}
|
|
||||||
main(
|
main(
|
||||||
version=args.version,
|
version=args.version,
|
||||||
num_seqs=args.batch_size,
|
num_seqs=args.batch_size,
|
||||||
@@ -191,7 +198,8 @@ if __name__ == '__main__':
|
|||||||
head_size=args.head_size,
|
head_size=args.head_size,
|
||||||
block_size=args.block_size,
|
block_size=args.block_size,
|
||||||
use_alibi=args.use_alibi,
|
use_alibi=args.use_alibi,
|
||||||
dtype=dtype_to_torch_dtype[args.dtype],
|
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
|
||||||
seed=args.seed,
|
seed=args.seed,
|
||||||
do_profile=args.profile,
|
do_profile=args.profile,
|
||||||
|
kv_cache_dtype=args.kv_cache_dtype,
|
||||||
)
|
)
|
||||||
|
|||||||
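Note on the removed `head_mapping` tensor: the benchmark used to build it with `torch.repeat_interleave` and pass it to the kernel, while the new call passes `num_kv_heads` and the kernel derives the KV head with an integer division. The sketch below checks that the two mappings agree, using plain Python lists in place of tensors. The function names are hypothetical.

```python
def old_head_mapping(num_query_heads: int, num_kv_heads: int) -> list:
    """List equivalent of repeat_interleave(arange(num_kv_heads), num_queries_per_kv)."""
    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    return [kv for kv in range(num_kv_heads) for _ in range(num_queries_per_kv)]


def new_kv_head_idx(head_idx: int, num_query_heads: int, num_kv_heads: int) -> int:
    """Per-head computation matching the updated kernel: head_idx / num_queries_per_kv."""
    num_queries_per_kv = num_query_heads // num_kv_heads
    return head_idx // num_queries_per_kv


if __name__ == "__main__":
    for q_heads, kv_heads in [(32, 32), (32, 8), (8, 1)]:
        derived = [new_kv_head_idx(h, q_heads, kv_heads) for h in range(q_heads)]
        print(q_heads, kv_heads, derived == old_head_mapping(q_heads, kv_heads))
```

The equivalence relies on `num_query_heads` being divisible by `num_kv_heads`, which the benchmark already validates before calling `main`.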
@@ -6,7 +6,7 @@ TOKENS=$2
|
|||||||
|
|
||||||
docker run --gpus all --shm-size 1g -p $PORT:80 \
|
docker run --gpus all --shm-size 1g -p $PORT:80 \
|
||||||
-v $PWD/data:/data \
|
-v $PWD/data:/data \
|
||||||
ghcr.io/huggingface/text-generation-inference:0.8 \
|
ghcr.io/huggingface/text-generation-inference:1.4.0 \
|
||||||
--model-id $MODEL \
|
--model-id $MODEL \
|
||||||
--sharded false \
|
--sharded false \
|
||||||
--max-input-length 1024 \
|
--max-input-length 1024 \
|
||||||
|
|||||||
@@ -1,28 +0,0 @@
#include <torch/extension.h>

void silu_and_mul(
  torch::Tensor& out,
  torch::Tensor& input);

void gelu_new(
  torch::Tensor& out,
  torch::Tensor& input);

void gelu_fast(
  torch::Tensor& out,
  torch::Tensor& input);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "silu_and_mul",
    &silu_and_mul,
    "Activation function used in SwiGLU.");
  m.def(
    "gelu_new",
    &gelu_new,
    "GELU implementation used in GPT-2.");
  m.def(
    "gelu_fast",
    &gelu_fast,
    "Approximate GELU implementation.");
}
@@ -1,50 +1,76 @@
|
|||||||
#include <torch/extension.h>
|
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <torch/extension.h>
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
#include "cuda_compat.h"
|
||||||
#include "dispatch_utils.h"
|
#include "dispatch_utils.h"
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
|
||||||
|
// Activation and gating kernel template.
|
||||||
|
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
|
||||||
|
__global__ void act_and_mul_kernel(
|
||||||
|
scalar_t* __restrict__ out, // [..., d]
|
||||||
|
const scalar_t* __restrict__ input, // [..., 2, d]
|
||||||
|
const int d) {
|
||||||
|
const int64_t token_idx = blockIdx.x;
|
||||||
|
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||||
|
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
|
||||||
|
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
|
||||||
|
out[token_idx * d + idx] = ACT_FN(x) * y;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
__device__ __forceinline__ T silu(const T& x) {
|
__device__ __forceinline__ T silu_kernel(const T& x) {
|
||||||
// x * sigmoid(x)
|
// x * sigmoid(x)
|
||||||
return (T) (((float) x) / (1.0f + expf((float) -x)));
|
return (T) (((float) x) / (1.0f + expf((float) -x)));
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename scalar_t>
|
template<typename T>
|
||||||
__global__ void silu_and_mul_kernel(
|
__device__ __forceinline__ T gelu_kernel(const T& x) {
|
||||||
scalar_t* __restrict__ out, // [num_tokens, d]
|
// Equivalent to PyTorch GELU with 'none' approximation.
|
||||||
const scalar_t* __restrict__ input, // [num_tokens, 2, d]
|
// Refer to:
|
||||||
const int d) {
|
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L38
|
||||||
const int token_idx = blockIdx.x;
|
const float f = (float) x;
|
||||||
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
constexpr float ALPHA = M_SQRT1_2;
|
||||||
const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
|
return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA)));
|
||||||
const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
|
|
||||||
out[token_idx * d + idx] = silu(x) * y;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
void silu_and_mul(
|
// Launch activation and gating kernel.
|
||||||
torch::Tensor& out, // [num_tokens, d]
|
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
|
||||||
torch::Tensor& input) // [num_tokens, 2 * d]
|
int d = input.size(-1) / 2; \
|
||||||
{
|
int64_t num_tokens = input.numel() / input.size(-1); \
|
||||||
int num_tokens = input.size(0);
|
dim3 grid(num_tokens); \
|
||||||
int d = input.size(1) / 2;
|
dim3 block(std::min(d, 1024)); \
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
||||||
dim3 grid(num_tokens);
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
||||||
dim3 block(std::min(d, 1024));
|
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
input.scalar_type(), \
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(
|
"act_and_mul_kernel", \
|
||||||
input.scalar_type(),
|
[&] { \
|
||||||
"silu_and_mul_kernel",
|
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
|
||||||
[&] {
|
out.data_ptr<scalar_t>(), \
|
||||||
vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
input.data_ptr<scalar_t>(), \
|
||||||
out.data_ptr<scalar_t>(),
|
d); \
|
||||||
input.data_ptr<scalar_t>(),
|
|
||||||
d);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
|
void silu_and_mul(
|
||||||
|
torch::Tensor& out, // [..., d]
|
||||||
|
torch::Tensor& input) // [..., 2 * d]
|
||||||
|
{
|
||||||
|
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
|
||||||
|
}
|
||||||
|
|
||||||
|
void gelu_and_mul(
|
||||||
|
torch::Tensor& out, // [..., d]
|
||||||
|
torch::Tensor& input) // [..., 2 * d]
|
||||||
|
{
|
||||||
|
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
@@ -52,12 +78,12 @@ namespace vllm {
|
|||||||
// Element-wise activation kernel template.
|
// Element-wise activation kernel template.
|
||||||
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
|
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
|
||||||
__global__ void activation_kernel(
|
__global__ void activation_kernel(
|
||||||
scalar_t* __restrict__ out, // [num_tokens, d]
|
scalar_t* __restrict__ out, // [..., d]
|
||||||
const scalar_t* __restrict__ input, // [num_tokens, d]
|
const scalar_t* __restrict__ input, // [..., d]
|
||||||
const int d) {
|
const int d) {
|
||||||
const int token_idx = blockIdx.x;
|
const int64_t token_idx = blockIdx.x;
|
||||||
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||||
const scalar_t x = __ldg(&input[token_idx * d + idx]);
|
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
|
||||||
out[token_idx * d + idx] = ACT_FN(x);
|
out[token_idx * d + idx] = ACT_FN(x);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -66,10 +92,11 @@ __global__ void activation_kernel(
|
|||||||
|
|
||||||
// Launch element-wise activation kernel.
|
// Launch element-wise activation kernel.
|
||||||
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
|
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
|
||||||
int num_tokens = input.size(0); \
|
int d = input.size(-1); \
|
||||||
int d = input.size(1); \
|
int64_t num_tokens = input.numel() / d; \
|
||||||
dim3 grid(num_tokens); \
|
dim3 grid(num_tokens); \
|
||||||
dim3 block(std::min(d, 1024)); \
|
dim3 block(std::min(d, 1024)); \
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
||||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||||
input.scalar_type(), \
|
input.scalar_type(), \
|
||||||
@@ -100,15 +127,15 @@ __device__ __forceinline__ T gelu_fast_kernel(const T& x) {
|
|||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
void gelu_new(
|
void gelu_new(
|
||||||
torch::Tensor& out, // [num_tokens, d]
|
torch::Tensor& out, // [..., d]
|
||||||
torch::Tensor& input) // [num_tokens, d]
|
torch::Tensor& input) // [..., d]
|
||||||
{
|
{
|
||||||
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
|
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
|
||||||
}
|
}
|
||||||
|
|
||||||
void gelu_fast(
|
void gelu_fast(
|
||||||
torch::Tensor& out, // [num_tokens, d]
|
torch::Tensor& out, // [..., d]
|
||||||
torch::Tensor& input) // [num_tokens, d]
|
torch::Tensor& input) // [..., d]
|
||||||
{
|
{
|
||||||
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
|
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
|
||||||
}
|
}
|
||||||
|
|||||||
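Note on `act_and_mul_kernel`: each input row of width `2 * d` is split into two halves, the activation is applied to the first half, and the result is multiplied element-wise by the second half. The following is a minimal pure-Python reference for one row, assuming the same formulas as `silu_kernel` and `gelu_kernel` in the diff. The Python function names are hypothetical, and the sketch computes in double precision rather than the kernel's float.

```python
import math


def silu(x: float) -> float:
    """x * sigmoid(x), as in silu_kernel."""
    return x / (1.0 + math.exp(-x))


def gelu(x: float) -> float:
    """Exact (erf-based) GELU, as in gelu_kernel; sqrt(0.5) plays the role of M_SQRT1_2."""
    return x * 0.5 * (1.0 + math.erf(x * math.sqrt(0.5)))


def act_and_mul(row: list, act) -> list:
    """Reference for one token: out[i] = act(row[i]) * row[d + i], with d = len(row) // 2."""
    d = len(row) // 2
    return [act(row[i]) * row[d + i] for i in range(d)]


if __name__ == "__main__":
    row = [0.0, 1.0, 2.0, 3.0]  # first half [0, 1] is gated by second half [2, 3]
    print(act_and_mul(row, silu))
    print(act_and_mul(row, gelu))
```

In the CUDA launcher, leading dimensions are flattened into `num_tokens = input.numel() / input.size(-1)`, so this per-row rule applies to every row of an input of shape `[..., 2 * d]`.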
@@ -1,42 +0,0 @@
#include <torch/extension.h>
#include <c10/util/Optional.h>

void paged_attention_v1(
  torch::Tensor& out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  torch::Tensor& head_mapping,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes);

void paged_attention_v2(
  torch::Tensor& out,
  torch::Tensor& exp_sums,
  torch::Tensor& max_logits,
  torch::Tensor& tmp_out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  torch::Tensor& head_mapping,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "paged_attention_v1",
    &paged_attention_v1,
    "Compute the attention between an input query and the cached keys/values using PagedAttention.");
  m.def(
    "paged_attention_v2",
    &paged_attention_v2,
    "PagedAttention V2.");
}
@@ -4,3 +4,4 @@
|
|||||||
#include "dtype_float16.cuh"
|
#include "dtype_float16.cuh"
|
||||||
#include "dtype_float32.cuh"
|
#include "dtype_float32.cuh"
|
||||||
#include "dtype_bfloat16.cuh"
|
#include "dtype_bfloat16.cuh"
|
||||||
|
#include "dtype_fp8_e5m2.cuh"
|
||||||
|
|||||||
@@ -15,15 +15,27 @@
|
|||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
#ifdef USE_ROCM
|
||||||
|
#include <hip/hip_runtime.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|
||||||
#include "attention_dtypes.h"
|
#include "attention_dtypes.h"
|
||||||
#include "attention_utils.cuh"
|
#include "attention_utils.cuh"
|
||||||
|
#ifdef ENABLE_FP8_E5M2
|
||||||
|
#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
|
||||||
|
#endif
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
|
#ifndef USE_ROCM
|
||||||
#define WARP_SIZE 32
|
#define WARP_SIZE 32
|
||||||
|
#else
|
||||||
|
#define WARP_SIZE warpSize
|
||||||
|
#endif
|
||||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||||
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
|
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
|
||||||
@@ -40,7 +52,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
|
|||||||
// Compute the sum per warp.
|
// Compute the sum per warp.
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
|
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
|
||||||
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Warp leaders store the data to shared memory.
|
// Warp leaders store the data to shared memory.
|
||||||
@@ -59,29 +71,31 @@ inline __device__ float block_sum(float* red_smem, float sum) {
|
|||||||
// Parallel reduction inside the warp.
|
// Parallel reduction inside the warp.
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
||||||
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Broadcast to other threads.
|
// Broadcast to other threads.
|
||||||
return __shfl_sync(uint32_t(-1), sum, 0);
|
return VLLM_SHFL_SYNC(sum, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(woosuk): Merge the last two dimensions of the grid.
|
// TODO(woosuk): Merge the last two dimensions of the grid.
|
||||||
// Grid: (num_heads, num_seqs, max_num_partitions).
|
// Grid: (num_heads, num_seqs, max_num_partitions).
|
||||||
template<
|
template<
|
||||||
typename scalar_t,
|
typename scalar_t,
|
||||||
|
typename cache_t,
|
||||||
int HEAD_SIZE,
|
int HEAD_SIZE,
|
||||||
int BLOCK_SIZE,
|
int BLOCK_SIZE,
|
||||||
int NUM_THREADS,
|
int NUM_THREADS,
|
||||||
|
bool IS_FP8_E5M2_KV_CACHE,
|
||||||
int PARTITION_SIZE = 0> // Zero means no partitioning.
|
int PARTITION_SIZE = 0> // Zero means no partitioning.
|
||||||
__device__ void paged_attention_kernel(
|
__device__ void paged_attention_kernel(
|
||||||
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||||
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
|
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
|
||||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
||||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||||
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||||
const int* __restrict__ head_mapping, // [num_heads]
|
const int num_kv_heads, // scalar: number of KV heads
|
||||||
const float scale,
|
const float scale,
|
||||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||||
const int* __restrict__ context_lens, // [num_seqs]
|
const int* __restrict__ context_lens, // [num_seqs]
|
||||||
@@ -124,7 +138,8 @@ __device__ void paged_attention_kernel(
|
|||||||
|
|
||||||
const int head_idx = blockIdx.x;
|
const int head_idx = blockIdx.x;
|
||||||
const int num_heads = gridDim.x;
|
const int num_heads = gridDim.x;
|
||||||
const int kv_head_idx = head_mapping[head_idx];
|
const int num_queries_per_kv = num_heads / num_kv_heads;
|
||||||
|
const int kv_head_idx = head_idx / num_queries_per_kv;
|
||||||
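This change replaces the `head_mapping` lookup with integer division. The two agree whenever the old table assigned consecutive query heads to each KV head, which is an assumption here because the code that built the table is not part of this diff. A small host-side sketch of the equivalence, with hypothetical helper names:

```cpp
#include <cassert>
#include <vector>

// Assumed shape of the old lookup table: each KV head index repeated
// num_queries_per_kv times, in order (e.g. 0,0,0,0,1,1,1,1 for 8 heads, 2 KV heads).
std::vector<int> make_head_mapping(int num_heads, int num_kv_heads) {
  assert(num_kv_heads > 0 && num_heads % num_kv_heads == 0);
  const int num_queries_per_kv = num_heads / num_kv_heads;
  std::vector<int> mapping;
  mapping.reserve(num_heads);
  for (int kv = 0; kv < num_kv_heads; ++kv) {
    for (int r = 0; r < num_queries_per_kv; ++r) {
      mapping.push_back(kv);
    }
  }
  return mapping;
}

// New scheme from the diff: derive the KV head index arithmetically.
int kv_head_index(int head_idx, int num_heads, int num_kv_heads) {
  const int num_queries_per_kv = num_heads / num_kv_heads;
  return head_idx / num_queries_per_kv;
}
```

The arithmetic form saves one global-memory read per thread block and drops a tensor argument from every launcher signature.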
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
|
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
|
||||||
|
|
||||||
// A vector type to store a part of a key or a query.
|
// A vector type to store a part of a key or a query.
|
||||||
@@ -135,6 +150,9 @@ __device__ void paged_attention_kernel(
|
|||||||
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
|
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
|
||||||
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
||||||
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
||||||
|
#ifdef ENABLE_FP8_E5M2
|
||||||
|
using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
|
||||||
|
#endif
|
||||||
|
|
||||||
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
|
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
|
||||||
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
|
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
|
||||||
@@ -166,7 +184,7 @@ __device__ void paged_attention_kernel(
|
|||||||
|
|
||||||
// x == THREAD_GROUP_SIZE * VEC_SIZE
|
// x == THREAD_GROUP_SIZE * VEC_SIZE
|
||||||
// Each thread group fetches x elements from the key at a time.
|
// Each thread group fetches x elements from the key at a time.
|
||||||
constexpr int x = 16 / sizeof(scalar_t);
|
constexpr int x = 16 / sizeof(cache_t);
|
||||||
float qk_max = -FLT_MAX;
|
float qk_max = -FLT_MAX;
|
||||||
|
|
||||||
// Iterate over the key blocks.
|
// Iterate over the key blocks.
|
||||||
@@ -175,7 +193,10 @@ __device__ void paged_attention_kernel(
|
|||||||
// dot product with the query.
|
// dot product with the query.
|
||||||
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
|
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
|
||||||
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
|
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
|
||||||
const int physical_block_number = block_table[block_idx];
|
// NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
|
||||||
|
// because int32 can lead to overflow when this variable is multiplied by large numbers
|
||||||
|
// (e.g., kv_block_stride).
|
||||||
|
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
|
||||||
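The overflow described in the note is easy to reach: a block index multiplied by a per-block stride can exceed the 32-bit range even though each factor fits comfortably. A host-side sketch with made-up numbers; `block_base_offset` and `fits_in_int32` are illustrative helpers, not functions from this codebase:

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Element offset of a KV-cache block. Widening one operand to 64 bits
// before the multiplication makes the whole product 64-bit.
std::int64_t block_base_offset(std::int32_t physical_block_number,
                               std::int32_t kv_block_stride) {
  assert(physical_block_number >= 0 && kv_block_stride >= 0);
  return static_cast<std::int64_t>(physical_block_number) * kv_block_stride;
}

// True when the offset could have been held in a 32-bit int.
bool fits_in_int32(std::int64_t v) {
  return v >= std::numeric_limits<std::int32_t>::min() &&
         v <= std::numeric_limits<std::int32_t>::max();
}
```

With 40000 blocks and a stride of 65536 elements, the product is 2621440000, which is above the 32-bit maximum of 2147483647.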
|
|
||||||
// Load a key to registers.
|
// Load a key to registers.
|
||||||
// Each thread in a thread group has a different part of the key.
|
// Each thread in a thread group has a different part of the key.
|
||||||
@@ -189,13 +210,23 @@ __device__ void paged_attention_kernel(
|
|||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
|
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
|
||||||
const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
|
const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
|
||||||
+ kv_head_idx * kv_head_stride
|
+ kv_head_idx * kv_head_stride
|
||||||
+ physical_block_offset * x;
|
+ physical_block_offset * x;
|
||||||
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
|
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
|
||||||
const int offset1 = (vec_idx * VEC_SIZE) / x;
|
const int offset1 = (vec_idx * VEC_SIZE) / x;
|
||||||
const int offset2 = (vec_idx * VEC_SIZE) % x;
|
const int offset2 = (vec_idx * VEC_SIZE) % x;
|
||||||
k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
if constexpr (IS_FP8_E5M2_KV_CACHE) {
|
||||||
|
#ifdef ENABLE_FP8_E5M2
|
||||||
|
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
||||||
|
// Vector conversion from Quant_vec to K_vec.
|
||||||
|
k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
|
||||||
|
#else
|
||||||
|
assert(false);
|
||||||
|
#endif
|
||||||
|
} else {
|
||||||
|
k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
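The `offset1` / `offset2` arithmetic follows from the key cache layout `[num_blocks, num_kv_heads, head_size/x, block_size, x]`: the head dimension is split into chunks of `x` elements, and each chunk is stored contiguously per token. A host-side sketch of the offset inside one head's slab, using a hypothetical helper `key_elem_offset` (the block and head strides are left out):

```cpp
#include <cassert>
#include <cstdint>

// Offset, in elements, of head dimension `dim` for the token stored at slot
// `token_in_block`, inside one [head_size/x, block_size, x] slab.
std::int64_t key_elem_offset(int dim, int token_in_block, int block_size, int x) {
  assert(x > 0 && dim >= 0);
  assert(token_in_block >= 0 && token_in_block < block_size);
  const int offset1 = dim / x;  // which x-wide chunk of the head dimension
  const int offset2 = dim % x;  // position inside that chunk
  return static_cast<std::int64_t>(offset1) * block_size * x
       + static_cast<std::int64_t>(token_in_block) * x
       + offset2;
}
```

Since `x` is now derived from `sizeof(cache_t)`, a one-byte cache element gives `x = 16`, while a two-byte element gives `x = 8`; the chunk always spans 16 bytes.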
|
|
||||||
// Compute dot product.
|
// Compute dot product.
|
||||||
@@ -220,7 +251,7 @@ __device__ void paged_attention_kernel(
|
|||||||
// The 0-th thread of each thread group already has its max qk value.
|
// The 0-th thread of each thread group already has its max qk value.
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
|
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
|
||||||
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
|
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
|
||||||
}
|
}
|
||||||
if (lane == 0) {
|
if (lane == 0) {
|
||||||
red_smem[warp_idx] = qk_max;
|
red_smem[warp_idx] = qk_max;
|
||||||
@@ -232,10 +263,10 @@ __device__ void paged_attention_kernel(
|
|||||||
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
|
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
||||||
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
|
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
|
||||||
}
|
}
|
||||||
// Broadcast the max qk value to all threads.
|
// Broadcast the max qk value to all threads.
|
||||||
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
|
qk_max = VLLM_SHFL_SYNC(qk_max, 0);
|
||||||
|
|
||||||
// Get the sum of the exp values.
|
// Get the sum of the exp values.
|
||||||
float exp_sum = 0.f;
|
float exp_sum = 0.f;
|
||||||
@@ -269,6 +300,9 @@ __device__ void paged_attention_kernel(
|
|||||||
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
|
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
|
||||||
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
|
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
|
||||||
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
|
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
|
||||||
|
#ifdef ENABLE_FP8_E5M2
|
||||||
|
using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
|
||||||
|
#endif
|
||||||
using Float_L_vec = typename FloatVec<L_vec>::Type;
|
using Float_L_vec = typename FloatVec<L_vec>::Type;
|
||||||
|
|
||||||
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
|
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
|
||||||
@@ -285,20 +319,34 @@ __device__ void paged_attention_kernel(
|
|||||||
scalar_t zero_value;
|
scalar_t zero_value;
|
||||||
zero(zero_value);
|
zero(zero_value);
|
||||||
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
|
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
|
||||||
const int physical_block_number = block_table[block_idx];
|
// NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
|
||||||
|
// because int32 can lead to overflow when this variable is multiplied by large numbers
|
||||||
|
// (e.g., kv_block_stride).
|
||||||
|
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
|
||||||
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
|
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
|
||||||
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
|
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
|
||||||
L_vec logits_vec;
|
L_vec logits_vec;
|
||||||
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
|
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
|
||||||
|
|
||||||
const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
|
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
|
||||||
+ kv_head_idx * kv_head_stride;
|
+ kv_head_idx * kv_head_stride;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||||
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
||||||
if (row_idx < HEAD_SIZE) {
|
if (row_idx < HEAD_SIZE) {
|
||||||
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
|
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
|
||||||
V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
|
V_vec v_vec;
|
||||||
|
if constexpr (IS_FP8_E5M2_KV_CACHE) {
|
||||||
|
#ifdef ENABLE_FP8_E5M2
|
||||||
|
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
|
||||||
|
// Vector conversion from V_quant_vec to V_vec.
|
||||||
|
v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
|
||||||
|
#else
|
||||||
|
assert(false);
|
||||||
|
#endif
|
||||||
|
} else {
|
||||||
|
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
|
||||||
|
}
|
||||||
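For readers unfamiliar with the format: an FP8 E5M2 value has 1 sign bit, 5 exponent bits (bias 15), and 2 mantissa bits, which is the top byte of an IEEE half-precision number. The sketch below decodes one byte on the host. It is not the `fp8_e5m2_unscaled::vec_conversion` implementation, which is not shown in this diff; `e5m2_to_float` and `e5m2_vec_to_float` are hypothetical names.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

// Decode one E5M2 byte (1 sign, 5 exponent, 2 mantissa bits) to float.
float e5m2_to_float(std::uint8_t bits) {
  const int sign = bits >> 7;
  const int exponent = (bits >> 2) & 0x1F;
  const int mantissa = bits & 0x3;
  float magnitude;
  if (exponent == 0x1F) {
    // All-ones exponent: infinity if the mantissa is zero, NaN otherwise.
    magnitude = (mantissa == 0) ? std::numeric_limits<float>::infinity()
                                : std::numeric_limits<float>::quiet_NaN();
  } else if (exponent == 0) {
    // Subnormal: (mantissa / 4) * 2^-14 == mantissa * 2^-16.
    magnitude = std::ldexp(static_cast<float>(mantissa), -16);
  } else {
    // Normal: (1 + mantissa / 4) * 2^(exponent - 15).
    magnitude = std::ldexp(1.0f + mantissa / 4.0f, exponent - 15);
  }
  return sign ? -magnitude : magnitude;
}

// Element-wise conversion of a small vector, analogous in spirit to a
// quantized-vector to compute-vector conversion.
void e5m2_vec_to_float(const std::uint8_t* src, float* dst, int n) {
  assert(src != nullptr && dst != nullptr && n >= 0);
  for (int i = 0; i < n; ++i) {
    dst[i] = e5m2_to_float(src[i]);
  }
}
```

The "unscaled" in the namespace name suggests no per-tensor scale factor is applied, which matches this plain bit-level decode, though that reading is an inference from the name alone.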
if (block_idx == num_context_blocks - 1) {
|
if (block_idx == num_context_blocks - 1) {
|
||||||
// NOTE(woosuk): When v_vec contains the tokens that are out of the context,
|
// NOTE(woosuk): When v_vec contains the tokens that are out of the context,
|
||||||
// we should explicitly zero out the values since they may contain NaNs.
|
// we should explicitly zero out the values since they may contain NaNs.
|
||||||
@@ -320,7 +368,7 @@ __device__ void paged_attention_kernel(
|
|||||||
float acc = accs[i];
|
float acc = accs[i];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
|
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
|
||||||
acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
|
acc += VLLM_SHFL_XOR_SYNC(acc, mask);
|
||||||
}
|
}
|
||||||
accs[i] = acc;
|
accs[i] = acc;
|
||||||
}
|
}
|
||||||
@@ -379,15 +427,17 @@ __device__ void paged_attention_kernel(
|
|||||||
// Grid: (num_heads, num_seqs, 1).
|
// Grid: (num_heads, num_seqs, 1).
|
||||||
template<
|
template<
|
||||||
typename scalar_t,
|
typename scalar_t,
|
||||||
|
typename cache_t,
|
||||||
int HEAD_SIZE,
|
int HEAD_SIZE,
|
||||||
int BLOCK_SIZE,
|
int BLOCK_SIZE,
|
||||||
int NUM_THREADS>
|
int NUM_THREADS,
|
||||||
|
bool IS_FP8_E5M2_KV_CACHE>
|
||||||
__global__ void paged_attention_v1_kernel(
|
__global__ void paged_attention_v1_kernel(
|
||||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
||||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||||
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||||
const int* __restrict__ head_mapping, // [num_heads]
|
const int num_kv_heads, // scalar: number of KV heads
|
||||||
const float scale,
|
const float scale,
|
||||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||||
const int* __restrict__ context_lens, // [num_seqs]
|
const int* __restrict__ context_lens, // [num_seqs]
|
||||||
@@ -396,27 +446,29 @@ __global__ void paged_attention_v1_kernel(
|
|||||||
const int q_stride,
|
const int q_stride,
|
||||||
const int kv_block_stride,
|
const int kv_block_stride,
|
||||||
const int kv_head_stride) {
|
const int kv_head_stride) {
|
||||||
paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
|
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE>(
|
||||||
/* exp_sums */ nullptr, /* max_logits */ nullptr,
|
/* exp_sums */ nullptr, /* max_logits */ nullptr,
|
||||||
out, q, k_cache, v_cache, head_mapping, scale, block_tables, context_lens,
|
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
|
||||||
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
|
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Grid: (num_heads, num_seqs, max_num_partitions).
|
// Grid: (num_heads, num_seqs, max_num_partitions).
|
||||||
template<
|
template<
|
||||||
typename scalar_t,
|
typename scalar_t,
|
||||||
|
typename cache_t,
|
||||||
int HEAD_SIZE,
|
int HEAD_SIZE,
|
||||||
int BLOCK_SIZE,
|
int BLOCK_SIZE,
|
||||||
int NUM_THREADS,
|
int NUM_THREADS,
|
||||||
|
bool IS_FP8_E5M2_KV_CACHE,
|
||||||
int PARTITION_SIZE>
|
int PARTITION_SIZE>
|
||||||
__global__ void paged_attention_v2_kernel(
|
__global__ void paged_attention_v2_kernel(
|
||||||
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||||
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
|
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
|
||||||
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
||||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||||
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||||
const int* __restrict__ head_mapping, // [num_heads]
|
const int num_kv_heads, // scalar: number of KV heads
|
||||||
const float scale,
|
const float scale,
|
||||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||||
const int* __restrict__ context_lens, // [num_seqs]
|
const int* __restrict__ context_lens, // [num_seqs]
|
||||||
@@ -425,8 +477,8 @@ __global__ void paged_attention_v2_kernel(
|
|||||||
const int q_stride,
|
const int q_stride,
|
||||||
const int kv_block_stride,
|
const int kv_block_stride,
|
||||||
const int kv_head_stride) {
|
const int kv_head_stride) {
|
||||||
paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>(
|
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>(
|
||||||
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, head_mapping, scale,
|
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
|
||||||
block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
|
block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
|
||||||
q_stride, kv_block_stride, kv_head_stride);
|
q_stride, kv_block_stride, kv_head_stride);
|
||||||
}
|
}
|
||||||
@@ -486,7 +538,7 @@ __global__ void paged_attention_v2_reduce_kernel(
|
|||||||
// Reduce within the warp.
|
// Reduce within the warp.
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
|
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
|
||||||
max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
|
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
|
||||||
}
|
}
|
||||||
if (lane == 0) {
|
if (lane == 0) {
|
||||||
red_smem[warp_idx] = max_logit;
|
red_smem[warp_idx] = max_logit;
|
||||||
@@ -496,10 +548,10 @@ __global__ void paged_attention_v2_reduce_kernel(
|
|||||||
max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
|
max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
||||||
max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
|
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
|
||||||
}
|
}
|
||||||
// Broadcast the max value to all threads.
|
// Broadcast the max value to all threads.
|
||||||
max_logit = __shfl_sync(uint32_t(-1), max_logit, 0);
|
max_logit = VLLM_SHFL_SYNC(max_logit, 0);
|
||||||
|
|
||||||
// Load rescaled exp sums to shared memory.
|
// Load rescaled exp sums to shared memory.
|
||||||
float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
|
float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
|
||||||
@@ -533,16 +585,16 @@ __global__ void paged_attention_v2_reduce_kernel(
|
|||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
|
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
|
||||||
cudaFuncSetAttribute( \
|
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
|
||||||
vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>, \
|
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
||||||
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \
|
IS_FP8_E5M2_KV_CACHE>), shared_mem_size); \
|
||||||
vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
|
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
||||||
<<<grid, block, shared_mem_size, stream>>>( \
|
IS_FP8_E5M2_KV_CACHE><<<grid, block, shared_mem_size, stream>>>( \
|
||||||
out_ptr, \
|
out_ptr, \
|
||||||
query_ptr, \
|
query_ptr, \
|
||||||
key_cache_ptr, \
|
key_cache_ptr, \
|
||||||
value_cache_ptr, \
|
value_cache_ptr, \
|
||||||
head_mapping_ptr, \
|
num_kv_heads, \
|
||||||
scale, \
|
scale, \
|
||||||
block_tables_ptr, \
|
block_tables_ptr, \
|
||||||
context_lens_ptr, \
|
context_lens_ptr, \
|
||||||
@@ -555,14 +607,16 @@ __global__ void paged_attention_v2_reduce_kernel(
|
|||||||
// TODO(woosuk): Tune NUM_THREADS.
|
// TODO(woosuk): Tune NUM_THREADS.
|
||||||
template<
|
template<
|
||||||
typename T,
|
typename T,
|
||||||
|
typename CACHE_T,
|
||||||
int BLOCK_SIZE,
|
int BLOCK_SIZE,
|
||||||
|
bool IS_FP8_E5M2_KV_CACHE,
|
||||||
int NUM_THREADS = 128>
|
int NUM_THREADS = 128>
|
||||||
void paged_attention_v1_launcher(
|
void paged_attention_v1_launcher(
|
||||||
torch::Tensor& out,
|
torch::Tensor& out,
|
||||||
torch::Tensor& query,
|
torch::Tensor& query,
|
||||||
torch::Tensor& key_cache,
|
torch::Tensor& key_cache,
|
||||||
torch::Tensor& value_cache,
|
torch::Tensor& value_cache,
|
||||||
torch::Tensor& head_mapping,
|
int num_kv_heads,
|
||||||
float scale,
|
float scale,
|
||||||
torch::Tensor& block_tables,
|
torch::Tensor& block_tables,
|
||||||
torch::Tensor& context_lens,
|
torch::Tensor& context_lens,
|
||||||
@@ -586,9 +640,8 @@ void paged_attention_v1_launcher(
|
|||||||
|
|
||||||
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
|
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
|
||||||
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
||||||
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
|
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
|
||||||
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
|
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
|
||||||
int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
|
|
||||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||||
int* context_lens_ptr = context_lens.data_ptr<int>();
|
int* context_lens_ptr = context_lens.data_ptr<int>();
|
||||||
|
|
||||||
@@ -602,6 +655,7 @@ void paged_attention_v1_launcher(
|
|||||||
|
|
||||||
dim3 grid(num_heads, num_seqs, 1);
|
dim3 grid(num_heads, num_seqs, 1);
|
||||||
dim3 block(NUM_THREADS);
|
dim3 block(NUM_THREADS);
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
switch (head_size) {
|
switch (head_size) {
|
||||||
// NOTE(woosuk): To reduce the compilation time, we only compile for the
|
// NOTE(woosuk): To reduce the compilation time, we only compile for the
|
||||||
@@ -631,35 +685,35 @@ void paged_attention_v1_launcher(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#define CALL_V1_LAUNCHER(T, BLOCK_SIZE) \
|
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
|
||||||
paged_attention_v1_launcher<T, BLOCK_SIZE>( \
|
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
|
||||||
out, \
|
out, \
|
||||||
query, \
|
query, \
|
||||||
key_cache, \
|
key_cache, \
|
||||||
value_cache, \
|
value_cache, \
|
||||||
head_mapping, \
|
num_kv_heads, \
|
||||||
scale, \
|
scale, \
|
||||||
block_tables, \
|
block_tables, \
|
||||||
context_lens, \
|
context_lens, \
|
||||||
max_context_len, \
|
max_context_len, \
|
||||||
alibi_slopes);
|
alibi_slopes);
|
||||||
|
|
||||||
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
||||||
// 1, 2, 4, 64, 128, 256.
|
// 1, 2, 4, 64, 128, 256.
|
||||||
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T) \
|
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
|
||||||
switch (block_size) { \
|
switch (block_size) { \
|
||||||
case 8: \
|
case 8: \
|
||||||
CALL_V1_LAUNCHER(T, 8); \
|
CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
|
||||||
break; \
|
break; \
|
||||||
case 16: \
|
case 16: \
|
||||||
CALL_V1_LAUNCHER(T, 16); \
|
CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
|
||||||
break; \
|
break; \
|
||||||
case 32: \
|
case 32: \
|
||||||
CALL_V1_LAUNCHER(T, 32); \
|
CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
|
||||||
break; \
|
break; \
|
||||||
default: \
|
default: \
|
||||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||||
break; \
|
break; \
|
||||||
}
|
}
|
||||||
|
|
||||||
void paged_attention_v1(
|
void paged_attention_v1(
|
||||||
@@ -667,26 +721,42 @@ void paged_attention_v1(
|
|||||||
torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
||||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||||
torch::Tensor& head_mapping, // [num_heads]
|
int num_kv_heads, // scalar: number of KV heads
|
||||||
float scale,
|
float scale,
|
||||||
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||||
torch::Tensor& context_lens, // [num_seqs]
|
torch::Tensor& context_lens, // [num_seqs]
|
||||||
int block_size,
|
int block_size,
|
||||||
int max_context_len,
|
int max_context_len,
|
||||||
const c10::optional<torch::Tensor>& alibi_slopes) {
|
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
if (query.dtype() == at::ScalarType::Float) {
|
const std::string& kv_cache_dtype) {
|
||||||
CALL_V1_LAUNCHER_BLOCK_SIZE(float);
|
if (kv_cache_dtype == "auto") {
|
||||||
} else if (query.dtype() == at::ScalarType::Half) {
|
if (query.dtype() == at::ScalarType::Float) {
|
||||||
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t);
|
CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
|
||||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
} else if (query.dtype() == at::ScalarType::Half) {
|
||||||
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
|
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
|
||||||
|
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||||
|
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
|
||||||
|
} else {
|
||||||
|
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||||
|
}
|
||||||
|
} else if (kv_cache_dtype == "fp8_e5m2") {
|
||||||
|
if (query.dtype() == at::ScalarType::Float) {
|
||||||
|
CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
|
||||||
|
} else if (query.dtype() == at::ScalarType::Half) {
|
||||||
|
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
|
||||||
|
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||||
|
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
|
||||||
|
} else {
|
||||||
|
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
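The new `paged_attention_v1` body turns a runtime string into compile-time template arguments: `"auto"` keeps the cache element type equal to the query type, while `"fp8_e5m2"` selects a one-byte cache type and sets the FP8 flag. A stripped-down host-side sketch of that dispatch pattern; `launcher_stub` and `dispatch_kv_cache` are stand-ins invented for illustration, and the exception replaces `TORCH_CHECK` so the example has no PyTorch dependency:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>

// Stand-in for a templated launcher: reports the cache element size
// it was instantiated with instead of launching a kernel.
template <typename T, typename CACHE_T, bool IS_FP8_E5M2_KV_CACHE>
std::size_t launcher_stub() {
  static_assert(!IS_FP8_E5M2_KV_CACHE || sizeof(CACHE_T) == 1,
                "FP8 cache elements are one byte wide");
  return sizeof(CACHE_T);
}

// Runtime string -> compile-time template arguments.
template <typename T>
std::size_t dispatch_kv_cache(const std::string& kv_cache_dtype) {
  if (kv_cache_dtype == "auto") {
    return launcher_stub<T, T, false>();
  }
  if (kv_cache_dtype == "fp8_e5m2") {
    return launcher_stub<T, std::uint8_t, true>();
  }
  throw std::invalid_argument("Unsupported data type of kv cache: " + kv_cache_dtype);
}
```

Each supported combination is a separate template instantiation, so adding a cache dtype multiplies the number of compiled kernels; that cost is the reason the surrounding notes restrict the set of block sizes.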
|
|
||||||
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
|
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
|
||||||
vllm::paged_attention_v2_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE> \
|
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
||||||
|
IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE> \
|
||||||
<<<grid, block, shared_mem_size, stream>>>( \
|
<<<grid, block, shared_mem_size, stream>>>( \
|
||||||
exp_sums_ptr, \
|
exp_sums_ptr, \
|
||||||
max_logits_ptr, \
|
max_logits_ptr, \
|
||||||
@@ -694,7 +764,7 @@ void paged_attention_v1(
|
|||||||
query_ptr, \
|
query_ptr, \
|
||||||
key_cache_ptr, \
|
key_cache_ptr, \
|
||||||
value_cache_ptr, \
|
value_cache_ptr, \
|
||||||
head_mapping_ptr, \
|
num_kv_heads, \
|
||||||
scale, \
|
scale, \
|
||||||
block_tables_ptr, \
|
block_tables_ptr, \
|
||||||
context_lens_ptr, \
|
context_lens_ptr, \
|
||||||
@@ -714,7 +784,9 @@ void paged_attention_v1(
|
|||||||
|
|
||||||
template<
|
template<
|
||||||
typename T,
|
typename T,
|
||||||
|
typename CACHE_T,
|
||||||
int BLOCK_SIZE,
|
int BLOCK_SIZE,
|
||||||
|
bool IS_FP8_E5M2_KV_CACHE,
|
||||||
int NUM_THREADS = 128,
|
int NUM_THREADS = 128,
|
||||||
int PARTITION_SIZE = 512>
|
int PARTITION_SIZE = 512>
|
||||||
void paged_attention_v2_launcher(
|
void paged_attention_v2_launcher(
|
||||||
@@ -725,7 +797,7 @@ void paged_attention_v2_launcher(
|
|||||||
torch::Tensor& query,
|
torch::Tensor& query,
|
||||||
torch::Tensor& key_cache,
|
torch::Tensor& key_cache,
|
||||||
torch::Tensor& value_cache,
|
torch::Tensor& value_cache,
|
||||||
torch::Tensor& head_mapping,
|
int num_kv_heads,
|
||||||
float scale,
|
float scale,
|
||||||
torch::Tensor& block_tables,
|
torch::Tensor& block_tables,
|
||||||
torch::Tensor& context_lens,
|
torch::Tensor& context_lens,
|
||||||
@@ -752,9 +824,8 @@ void paged_attention_v2_launcher(
|
|||||||
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
|
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
|
||||||
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
|
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
|
||||||
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
||||||
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
|
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
|
||||||
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
|
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
|
||||||
int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
|
|
||||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||||
int* context_lens_ptr = context_lens.data_ptr<int>();
|
int* context_lens_ptr = context_lens.data_ptr<int>();
|
||||||
|
|
||||||
@@ -771,6 +842,7 @@ void paged_attention_v2_launcher(
|
|||||||
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
|
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
|
||||||
|
|
||||||
dim3 block(NUM_THREADS);
|
dim3 block(NUM_THREADS);
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
switch (head_size) {
|
switch (head_size) {
|
||||||
// NOTE(woosuk): To reduce the compilation time, we only compile for the
|
// NOTE(woosuk): To reduce the compilation time, we only compile for the
|
||||||
@@ -800,38 +872,38 @@ void paged_attention_v2_launcher(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#define CALL_V2_LAUNCHER(T, BLOCK_SIZE) \
|
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
|
||||||
paged_attention_v2_launcher<T, BLOCK_SIZE>( \
|
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
|
||||||
out, \
|
out, \
|
||||||
exp_sums, \
|
exp_sums, \
|
||||||
max_logits, \
|
max_logits, \
|
||||||
tmp_out, \
|
tmp_out, \
|
||||||
query, \
|
query, \
|
||||||
key_cache, \
|
key_cache, \
|
||||||
value_cache, \
|
value_cache, \
|
||||||
head_mapping, \
|
num_kv_heads, \
|
||||||
scale, \
|
scale, \
|
||||||
block_tables, \
|
block_tables, \
|
||||||
context_lens, \
|
context_lens, \
|
||||||
max_context_len, \
|
max_context_len, \
|
||||||
alibi_slopes);
|
alibi_slopes);
|
||||||
|
|
||||||
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
||||||
// 1, 2, 4, 64, 128, 256.
|
// 1, 2, 4, 64, 128, 256.
|
||||||
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T) \
|
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
|
||||||
switch (block_size) { \
|
switch (block_size) { \
|
||||||
case 8: \
|
case 8: \
|
||||||
CALL_V2_LAUNCHER(T, 8); \
|
CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
|
||||||
break; \
|
break; \
|
||||||
case 16: \
|
case 16: \
|
||||||
CALL_V2_LAUNCHER(T, 16); \
|
CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
|
||||||
break; \
|
break; \
|
||||||
case 32: \
|
case 32: \
|
||||||
CALL_V2_LAUNCHER(T, 32); \
|
CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
|
||||||
break; \
|
break; \
|
||||||
default: \
|
default: \
|
||||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||||
break; \
|
break; \
|
||||||
}
|
}
|
||||||
|
|
||||||
void paged_attention_v2(
|
void paged_attention_v2(
|
||||||
@@ -842,21 +914,36 @@ void paged_attention_v2(
|
|||||||
torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
||||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||||
torch::Tensor& head_mapping, // [num_heads]
|
int num_kv_heads, // [num_heads]
|
||||||
float scale,
|
float scale,
|
||||||
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||||
torch::Tensor& context_lens, // [num_seqs]
|
torch::Tensor& context_lens, // [num_seqs]
|
||||||
int block_size,
|
int block_size,
|
||||||
int max_context_len,
|
int max_context_len,
|
||||||
const c10::optional<torch::Tensor>& alibi_slopes) {
|
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
if (query.dtype() == at::ScalarType::Float) {
|
const std::string& kv_cache_dtype) {
|
||||||
CALL_V2_LAUNCHER_BLOCK_SIZE(float);
|
if (kv_cache_dtype == "auto") {
|
||||||
} else if (query.dtype() == at::ScalarType::Half) {
|
if (query.dtype() == at::ScalarType::Float) {
|
||||||
CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t);
|
CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
|
||||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
} else if (query.dtype() == at::ScalarType::Half) {
|
||||||
CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
|
CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
|
||||||
|
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||||
|
CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
|
||||||
|
} else {
|
||||||
|
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||||
|
}
|
||||||
|
} else if (kv_cache_dtype == "fp8_e5m2") {
|
||||||
|
if (query.dtype() == at::ScalarType::Float) {
|
||||||
|
CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
|
||||||
|
} else if (query.dtype() == at::ScalarType::Half) {
|
||||||
|
CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
|
||||||
|
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||||
|
CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
|
||||||
|
} else {
|
||||||
|
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "../cuda_compat.h"
|
||||||
#include "attention_dtypes.h"
|
#include "attention_dtypes.h"
|
||||||
|
|
||||||
#include <float.h>
|
#include <float.h>
|
||||||
@@ -39,7 +40,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
|
|||||||
float qk = sum(qk_vec);
|
float qk = sum(qk_vec);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
|
for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
|
||||||
qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
|
qk += VLLM_SHFL_XOR_SYNC(qk, mask);
|
||||||
}
|
}
|
||||||
return qk;
|
return qk;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,8 +21,17 @@
|
|||||||
#include "attention_generic.cuh"
|
#include "attention_generic.cuh"
|
||||||
#include "dtype_float32.cuh"
|
#include "dtype_float32.cuh"
|
||||||
|
|
||||||
#include <cuda_bf16.h>
|
#ifndef USE_ROCM
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_bf16.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
#else
|
||||||
|
#include <hip/hip_bf16.h>
|
||||||
|
#include <hip/hip_fp16.h>
|
||||||
|
|
||||||
|
typedef __hip_bfloat162 __nv_bfloat162;
|
||||||
|
typedef __hip_bfloat16 __nv_bfloat16;
|
||||||
|
#endif
|
||||||
|
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
@@ -98,7 +107,11 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
|
|||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||||
assert(false);
|
assert(false);
|
||||||
#else
|
#else
|
||||||
return a + b;
|
#ifndef USE_ROCM
|
||||||
|
return a + b;
|
||||||
|
#else
|
||||||
|
return __hadd(a, b);
|
||||||
|
#endif
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,10 @@
|
|||||||
#include "attention_generic.cuh"
|
#include "attention_generic.cuh"
|
||||||
#include "dtype_float32.cuh"
|
#include "dtype_float32.cuh"
|
||||||
|
|
||||||
|
#ifdef USE_ROCM
|
||||||
|
#include <hip/hip_fp16.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
@@ -63,21 +67,47 @@ struct FloatVec<uint4> {
|
|||||||
|
|
||||||
// Utility functions for type conversions.
|
// Utility functions for type conversions.
|
||||||
inline __device__ uint32_t h0_h0(uint16_t a) {
|
inline __device__ uint32_t h0_h0(uint16_t a) {
|
||||||
|
#ifndef USE_ROCM
|
||||||
uint32_t b;
|
uint32_t b;
|
||||||
asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
|
asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
|
||||||
return b;
|
return b;
|
||||||
|
#else
|
||||||
|
union {
|
||||||
|
uint32_t u32;
|
||||||
|
uint16_t u16[2];
|
||||||
|
} tmp;
|
||||||
|
tmp.u16[0] = a;
|
||||||
|
tmp.u16[1] = a;
|
||||||
|
return tmp.u32;
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
inline __device__ float half_to_float(uint16_t h) {
|
inline __device__ float half_to_float(uint16_t h) {
|
||||||
float f;
|
float f;
|
||||||
|
#ifndef USE_ROCM
|
||||||
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
|
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
|
||||||
|
#else
|
||||||
|
asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
|
||||||
|
#endif
|
||||||
return f;
|
return f;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline __device__ float2 half2_to_float2(uint32_t v) {
|
inline __device__ float2 half2_to_float2(uint32_t v) {
|
||||||
|
#ifndef USE_ROCM
|
||||||
uint16_t lo, hi;
|
uint16_t lo, hi;
|
||||||
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
|
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
|
||||||
return make_float2(half_to_float(lo), half_to_float(hi));
|
return make_float2(half_to_float(lo), half_to_float(hi));
|
||||||
|
#else
|
||||||
|
union {
|
||||||
|
uint32_t u32;
|
||||||
|
uint16_t u16[2];
|
||||||
|
} tmp;
|
||||||
|
tmp.u32 = v;
|
||||||
|
float2 ret;
|
||||||
|
ret.x = half_to_float(tmp.u16[0]);
|
||||||
|
ret.y = half_to_float(tmp.u16[1]);
|
||||||
|
return ret;
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
inline __device__ uint16_t float_to_half(float f) {
|
inline __device__ uint16_t float_to_half(float f) {
|
||||||
@@ -85,7 +115,11 @@ inline __device__ uint16_t float_to_half(float f) {
|
|||||||
uint32_t u32;
|
uint32_t u32;
|
||||||
uint16_t u16[2];
|
uint16_t u16[2];
|
||||||
} tmp;
|
} tmp;
|
||||||
|
#ifndef USE_ROCM
|
||||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
|
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
|
||||||
|
#else
|
||||||
|
asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
|
||||||
|
#endif
|
||||||
return tmp.u16[0];
|
return tmp.u16[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -94,12 +128,16 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
|
|||||||
uint32_t u32;
|
uint32_t u32;
|
||||||
uint16_t u16[2];
|
uint16_t u16[2];
|
||||||
} tmp;
|
} tmp;
|
||||||
|
#ifndef USE_ROCM
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||||
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
|
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
|
||||||
|
#else
|
||||||
|
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
|
||||||
|
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
|
||||||
|
#endif
|
||||||
#else
|
#else
|
||||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
|
tmp.u16[0] = float_to_half(f.x);
|
||||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
|
tmp.u16[1] = float_to_half(f.y);
|
||||||
#endif
|
#endif
|
||||||
return tmp.u32;
|
return tmp.u32;
|
||||||
}
|
}
|
||||||
@@ -107,13 +145,21 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
|
|||||||
// Vector addition.
|
// Vector addition.
|
||||||
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
|
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
|
||||||
uint16_t c;
|
uint16_t c;
|
||||||
|
#ifndef USE_ROCM
|
||||||
asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
|
asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
|
||||||
|
#else
|
||||||
|
asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
|
||||||
|
#endif
|
||||||
return c;
|
return c;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline __device__ uint32_t add(uint32_t a, uint32_t b) {
|
inline __device__ uint32_t add(uint32_t a, uint32_t b) {
|
||||||
uint32_t c;
|
uint32_t c;
|
||||||
|
#ifndef USE_ROCM
|
||||||
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
|
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
|
||||||
|
#else
|
||||||
|
asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
|
||||||
|
#endif
|
||||||
return c;
|
return c;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -158,14 +204,22 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) {
|
|||||||
template<>
|
template<>
|
||||||
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
|
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
|
||||||
uint16_t c;
|
uint16_t c;
|
||||||
|
#ifndef USE_ROCM
|
||||||
asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
|
asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
|
||||||
|
#else
|
||||||
|
asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
|
||||||
|
#endif
|
||||||
return c;
|
return c;
|
||||||
}
|
}
|
||||||
|
|
||||||
template<>
|
template<>
|
||||||
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
|
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
|
||||||
uint32_t c;
|
uint32_t c;
|
||||||
|
#ifndef USE_ROCM
|
||||||
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
|
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
|
||||||
|
#else
|
||||||
|
asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
|
||||||
|
#endif
|
||||||
return c;
|
return c;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -272,7 +326,11 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) {
|
|||||||
// Vector fused multiply-add.
|
// Vector fused multiply-add.
|
||||||
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
|
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
|
||||||
uint32_t d;
|
uint32_t d;
|
||||||
|
#ifndef USE_ROCM
|
||||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
|
||||||
|
#else
|
||||||
|
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
|
||||||
|
#endif
|
||||||
return d;
|
return d;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
35
csrc/attention/dtype_fp8_e5m2.cuh
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "attention_generic.cuh"
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
#ifdef ENABLE_FP8_E5M2
|
||||||
|
#include <cuda_fp8.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
#ifdef ENABLE_FP8_E5M2
|
||||||
|
// fp8 vector types for quantization of kv cache
|
||||||
|
|
||||||
|
template<>
|
||||||
|
struct Vec<uint8_t, 1> {
|
||||||
|
using Type = uint8_t;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<>
|
||||||
|
struct Vec<uint8_t, 2> {
|
||||||
|
using Type = uint16_t;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<>
|
||||||
|
struct Vec<uint8_t, 4> {
|
||||||
|
using Type = uint32_t;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<>
|
||||||
|
struct Vec<uint8_t, 8> {
|
||||||
|
using Type = uint2;
|
||||||
|
};
|
||||||
|
#endif // ENABLE_FP8_E5M2
|
||||||
|
|
||||||
|
} // namespace vllm
|
||||||
@@ -1,47 +0,0 @@
|
|||||||
#include <torch/extension.h>
|
|
||||||
|
|
||||||
#include <map>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
void swap_blocks(
|
|
||||||
torch::Tensor& src,
|
|
||||||
torch::Tensor& dst,
|
|
||||||
const std::map<int64_t, int64_t>& block_mapping);
|
|
||||||
|
|
||||||
void copy_blocks(
|
|
||||||
std::vector<torch::Tensor>& key_caches,
|
|
||||||
std::vector<torch::Tensor>& value_caches,
|
|
||||||
const std::map<int64_t, std::vector<int64_t>>& block_mapping);
|
|
||||||
|
|
||||||
void reshape_and_cache(
|
|
||||||
torch::Tensor& key,
|
|
||||||
torch::Tensor& value,
|
|
||||||
torch::Tensor& key_cache,
|
|
||||||
torch::Tensor& value_cache,
|
|
||||||
torch::Tensor& slot_mapping);
|
|
||||||
|
|
||||||
void gather_cached_kv(
|
|
||||||
torch::Tensor& key,
|
|
||||||
torch::Tensor& value,
|
|
||||||
torch::Tensor& key_cache,
|
|
||||||
torch::Tensor& value_cache,
|
|
||||||
torch::Tensor& slot_mapping);
|
|
||||||
|
|
||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|
||||||
m.def(
|
|
||||||
"swap_blocks",
|
|
||||||
&swap_blocks,
|
|
||||||
"Swap in (out) the cache blocks from src to dst");
|
|
||||||
m.def(
|
|
||||||
"copy_blocks",
|
|
||||||
&copy_blocks,
|
|
||||||
"Copy the cache blocks from src to dst");
|
|
||||||
m.def(
|
|
||||||
"reshape_and_cache",
|
|
||||||
&reshape_and_cache,
|
|
||||||
"Reshape the key and value tensors and cache them");
|
|
||||||
m.def(
|
|
||||||
"gather_cached_kv",
|
|
||||||
&gather_cached_kv,
|
|
||||||
"Gather key and value from the cache into contiguous QKV tensors");
|
|
||||||
}
|
|
||||||
29
csrc/cache.h
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
void swap_blocks(
|
||||||
|
torch::Tensor& src,
|
||||||
|
torch::Tensor& dst,
|
||||||
|
const std::map<int64_t, int64_t>& block_mapping);
|
||||||
|
|
||||||
|
void copy_blocks(
|
||||||
|
std::vector<torch::Tensor>& key_caches,
|
||||||
|
std::vector<torch::Tensor>& value_caches,
|
||||||
|
const std::map<int64_t, std::vector<int64_t>>& block_mapping);
|
||||||
|
|
||||||
|
void reshape_and_cache(
|
||||||
|
torch::Tensor& key,
|
||||||
|
torch::Tensor& value,
|
||||||
|
torch::Tensor& key_cache,
|
||||||
|
torch::Tensor& value_cache,
|
||||||
|
torch::Tensor& slot_mapping,
|
||||||
|
const std::string& kv_cache_dtype);
|
||||||
|
|
||||||
|
// Just for unittest
|
||||||
|
void convert_fp8_e5m2(
|
||||||
|
torch::Tensor& src_cache,
|
||||||
|
torch::Tensor& dst_cache);
|
||||||
@@ -1,13 +1,23 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|
||||||
|
#include "cuda_compat.h"
|
||||||
#include "dispatch_utils.h"
|
#include "dispatch_utils.h"
|
||||||
|
#ifdef ENABLE_FP8_E5M2
|
||||||
|
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
|
||||||
|
#endif
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#ifdef USE_ROCM
|
||||||
|
#include <hip/hip_bf16.h>
|
||||||
|
typedef __hip_bfloat16 __nv_bfloat16;
|
||||||
|
#endif
|
||||||
|
|
||||||
void swap_blocks(
|
void swap_blocks(
|
||||||
torch::Tensor& src,
|
torch::Tensor& src,
|
||||||
torch::Tensor& dst,
|
torch::Tensor& dst,
|
||||||
@@ -28,10 +38,11 @@ void swap_blocks(
|
|||||||
TORCH_CHECK(false, "Invalid device combination");
|
TORCH_CHECK(false, "Invalid device combination");
|
||||||
}
|
}
|
||||||
|
|
||||||
void *src_ptr = src.data_ptr();
|
char *src_ptr = static_cast<char*>(src.data_ptr());
|
||||||
void *dst_ptr = dst.data_ptr();
|
char *dst_ptr = static_cast<char*>(dst.data_ptr());
|
||||||
|
|
||||||
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
|
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
// NOTE(woosuk): This can be slow if the number of blocks is large.
|
// NOTE(woosuk): This can be slow if the number of blocks is large.
|
||||||
for (const auto& pair : block_mapping) {
|
for (const auto& pair : block_mapping) {
|
||||||
@@ -55,26 +66,26 @@ template<typename scalar_t>
|
|||||||
__global__ void copy_blocks_kernel(
|
__global__ void copy_blocks_kernel(
|
||||||
int64_t* key_cache_ptrs,
|
int64_t* key_cache_ptrs,
|
||||||
int64_t* value_cache_ptrs,
|
int64_t* value_cache_ptrs,
|
||||||
const int* __restrict__ block_mapping,
|
const int64_t* __restrict__ block_mapping,
|
||||||
const int numel_per_block) {
|
const int numel_per_block) {
|
||||||
const int layer_idx = blockIdx.x;
|
const int layer_idx = blockIdx.x;
|
||||||
const int pair_idx = blockIdx.y;
|
const int pair_idx = blockIdx.y;
|
||||||
|
|
||||||
scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
|
scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
|
||||||
scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
|
scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
|
||||||
int src_block_number = block_mapping[2 * pair_idx];
|
int64_t src_block_number = block_mapping[2 * pair_idx];
|
||||||
int dst_block_number = block_mapping[2 * pair_idx + 1];
|
int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
|
||||||
|
|
||||||
const int src_block_offset = src_block_number * numel_per_block;
|
const int64_t src_block_offset = src_block_number * numel_per_block;
|
||||||
const int dst_block_offset = dst_block_number * numel_per_block;
|
const int64_t dst_block_offset = dst_block_number * numel_per_block;
|
||||||
for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
|
for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
|
||||||
int src_offset = src_block_offset + i;
|
int64_t src_offset = src_block_offset + i;
|
||||||
int dst_offset = dst_block_offset + i;
|
int64_t dst_offset = dst_block_offset + i;
|
||||||
key_cache[dst_offset] = key_cache[src_offset];
|
key_cache[dst_offset] = key_cache[src_offset];
|
||||||
}
|
}
|
||||||
for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
|
for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
|
||||||
int src_offset = src_block_offset + i;
|
int64_t src_offset = src_block_offset + i;
|
||||||
int dst_offset = dst_block_offset + i;
|
int64_t dst_offset = dst_block_offset + i;
|
||||||
value_cache[dst_offset] = value_cache[src_offset];
|
value_cache[dst_offset] = value_cache[src_offset];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -102,15 +113,15 @@ void copy_blocks(
|
|||||||
value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
|
value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
|
||||||
}
|
}
|
||||||
// Create block mapping array.
|
// Create block mapping array.
|
||||||
std::vector<int> block_mapping_vec;
|
std::vector<int64_t> block_mapping_vec;
|
||||||
for (const auto& pair : block_mapping) {
|
for (const auto& pair : block_mapping) {
|
||||||
int src_block_number = pair.first;
|
int64_t src_block_number = pair.first;
|
||||||
for (int dst_block_number : pair.second) {
|
for (int64_t dst_block_number : pair.second) {
|
||||||
block_mapping_vec.push_back(src_block_number);
|
block_mapping_vec.push_back(src_block_number);
|
||||||
block_mapping_vec.push_back(dst_block_number);
|
block_mapping_vec.push_back(dst_block_number);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
int* block_mapping_array = block_mapping_vec.data();
|
int64_t* block_mapping_array = block_mapping_vec.data();
|
||||||
int num_pairs = block_mapping_vec.size() / 2;
|
int num_pairs = block_mapping_vec.size() / 2;
|
||||||
|
|
||||||
// Move the data structures to the GPU.
|
// Move the data structures to the GPU.
|
||||||
@@ -120,75 +131,107 @@ void copy_blocks(
|
|||||||
torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
|
torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
|
||||||
value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
|
value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
|
||||||
torch::Tensor block_mapping_tensor = torch::from_blob(
|
torch::Tensor block_mapping_tensor = torch::from_blob(
|
||||||
block_mapping_array, {2 * num_pairs}, torch::kInt).to(cache_device);
|
block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);
|
||||||
|
|
||||||
// Launch the kernel.
|
// Launch the kernel.
|
||||||
const int numel_per_block = key_caches[0][0].numel();
|
const int numel_per_block = key_caches[0][0].numel();
|
||||||
dim3 grid(num_layers, num_pairs);
|
dim3 grid(num_layers, num_pairs);
|
||||||
dim3 block(std::min(1024, numel_per_block));
|
dim3 block(std::min(1024, numel_per_block));
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(cache_device);
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(
|
VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
|
||||||
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
|
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
|
||||||
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||||
key_cache_ptrs_tensor.data_ptr<int64_t>(),
|
key_cache_ptrs_tensor.data_ptr<int64_t>(),
|
||||||
value_cache_ptrs_tensor.data_ptr<int64_t>(),
|
value_cache_ptrs_tensor.data_ptr<int64_t>(),
|
||||||
block_mapping_tensor.data_ptr<int>(),
|
block_mapping_tensor.data_ptr<int64_t>(),
|
||||||
numel_per_block);
|
numel_per_block);
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
|
||||||
template<typename scalar_t>
|
template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache>
|
||||||
__global__ void reshape_and_cache_kernel(
|
__global__ void reshape_and_cache_kernel(
|
||||||
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||||
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
|
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
|
||||||
scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||||
scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
|
cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||||
const int* __restrict__ slot_mapping, // [num_tokens]
|
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||||
const int key_stride,
|
const int key_stride,
|
||||||
const int value_stride,
|
const int value_stride,
|
||||||
const int num_heads,
|
const int num_heads,
|
||||||
const int head_size,
|
const int head_size,
|
||||||
const int block_size,
|
const int block_size,
|
||||||
const int x) {
|
const int x) {
|
||||||
const int token_idx = blockIdx.x;
|
const int64_t token_idx = blockIdx.x;
|
||||||
const int slot_idx = slot_mapping[token_idx];
|
const int64_t slot_idx = slot_mapping[token_idx];
|
||||||
const int block_idx = slot_idx / block_size;
|
if (slot_idx < 0) {
|
||||||
const int block_offset = slot_idx % block_size;
|
// Padding token that should be ignored.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t block_idx = slot_idx / block_size;
|
||||||
|
const int64_t block_offset = slot_idx % block_size;
|
||||||
|
|
||||||
const int n = num_heads * head_size;
|
const int n = num_heads * head_size;
|
||||||
for (int i = threadIdx.x; i < n; i += blockDim.x) {
|
for (int i = threadIdx.x; i < n; i += blockDim.x) {
|
||||||
const int src_key_idx = token_idx * key_stride + i;
|
const int64_t src_key_idx = token_idx * key_stride + i;
|
||||||
const int src_value_idx = token_idx * value_stride + i;
|
const int64_t src_value_idx = token_idx * value_stride + i;
|
||||||
|
|
||||||
const int head_idx = i / head_size;
|
const int head_idx = i / head_size;
|
||||||
const int head_offset = i % head_size;
|
const int head_offset = i % head_size;
|
||||||
const int x_idx = head_offset / x;
|
const int x_idx = head_offset / x;
|
||||||
const int x_offset = head_offset % x;
|
const int x_offset = head_offset % x;
|
||||||
|
|
||||||
const int tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
|
const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
|
||||||
+ head_idx * (head_size / x) * block_size * x
|
+ head_idx * (head_size / x) * block_size * x
|
||||||
+ x_idx * block_size * x
|
+ x_idx * block_size * x
|
||||||
+ block_offset * x
|
+ block_offset * x
|
||||||
+ x_offset;
|
+ x_offset;
|
||||||
const int tgt_value_idx = block_idx * num_heads * head_size * block_size
|
const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
|
||||||
+ head_idx * head_size * block_size
|
+ head_idx * head_size * block_size
|
||||||
+ head_offset * block_size
|
+ head_offset * block_size
|
||||||
+ block_offset;
|
+ block_offset;
|
||||||
key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]);
|
scalar_t tgt_key = key[src_key_idx];
|
||||||
value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]);
|
scalar_t tgt_value = value[src_value_idx];
|
||||||
|
if constexpr (is_fp8_e5m2_kv_cache) {
|
||||||
|
#ifdef ENABLE_FP8_E5M2
|
||||||
|
key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
|
||||||
|
value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
|
||||||
|
#else
|
||||||
|
assert(false);
|
||||||
|
#endif
|
||||||
|
} else {
|
||||||
|
key_cache[tgt_key_idx] = tgt_key;
|
||||||
|
value_cache[tgt_value_idx] = tgt_value;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
|
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
|
||||||
|
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE><<<grid, block, 0, stream>>>( \
|
||||||
|
reinterpret_cast<KV_T*>(key.data_ptr()), \
|
||||||
|
reinterpret_cast<KV_T*>(value.data_ptr()), \
|
||||||
|
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||||
|
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
||||||
|
slot_mapping.data_ptr<int64_t>(), \
|
||||||
|
key_stride, \
|
||||||
|
value_stride, \
|
||||||
|
num_heads, \
|
||||||
|
head_size, \
|
||||||
|
block_size, \
|
||||||
|
x);
|
||||||
|
|
||||||
void reshape_and_cache(
|
void reshape_and_cache(
|
||||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||||
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
||||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||||
torch::Tensor& slot_mapping) // [num_tokens]
|
torch::Tensor& slot_mapping, // [num_tokens]
|
||||||
|
const std::string& kv_cache_dtype)
|
||||||
{
|
{
|
||||||
int num_tokens = key.size(0);
|
int num_tokens = key.size(0);
|
||||||
int num_heads = key.size(1);
|
int num_heads = key.size(1);
|
||||||
@@ -201,182 +244,77 @@ void reshape_and_cache(
|
|||||||
|
|
||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
dim3 block(std::min(num_heads * head_size, 512));
|
dim3 block(std::min(num_heads * head_size, 512));
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(
|
if (kv_cache_dtype == "auto") {
|
||||||
key.scalar_type(),
|
if (key.dtype() == at::ScalarType::Float) {
|
||||||
"reshape_and_cache_kernel",
|
CALL_RESHAPE_AND_CACHE(float, float, false);
|
||||||
[&] {
|
} else if (key.dtype() == at::ScalarType::Half) {
|
||||||
vllm::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
|
||||||
key.data_ptr<scalar_t>(),
|
} else if (key.dtype() == at::ScalarType::BFloat16) {
|
||||||
value.data_ptr<scalar_t>(),
|
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
|
||||||
key_cache.data_ptr<scalar_t>(),
|
}
|
||||||
value_cache.data_ptr<scalar_t>(),
|
} else if (kv_cache_dtype == "fp8_e5m2") {
|
||||||
slot_mapping.data_ptr<int>(),
|
if (key.dtype() == at::ScalarType::Float) {
|
||||||
key_stride,
|
CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
|
||||||
value_stride,
|
} else if (key.dtype() == at::ScalarType::Half) {
|
||||||
num_heads,
|
CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
|
||||||
head_size,
|
} else if (key.dtype() == at::ScalarType::BFloat16) {
|
||||||
block_size,
|
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
|
||||||
x);
|
}
|
||||||
});
|
} else {
|
||||||
|
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
|
||||||
// Grid: (num_blocks, block_size).
|
template<typename Tout, typename Tin>
|
||||||
template<typename scalar_t>
|
__global__ void convert_fp8_e5m2_kernel(
|
||||||
__global__ void gather_cached_kv_kernel(
|
const Tin* __restrict__ src_cache,
|
||||||
scalar_t* __restrict__ key, // [num_tokens, [stride], num_heads, head_size]
|
Tout* __restrict__ dst_cache,
|
||||||
scalar_t* __restrict__ value, // [num_tokens, [stride], num_heads, head_size]
|
const int64_t block_stride) {
|
||||||
const scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
const int64_t block_idx = blockIdx.x;
|
||||||
const scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
|
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
|
||||||
const int* __restrict__ slot_mapping, // [num_tokens]
|
int64_t idx = block_idx * block_stride + i;
|
||||||
const int key_stride,
|
#ifdef ENABLE_FP8_E5M2
|
||||||
const int value_stride,
|
dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
|
||||||
const int num_heads,
|
#else
|
||||||
const int head_size,
|
assert(false);
|
||||||
const int block_size,
|
#endif
|
||||||
const int x) {
|
}
|
||||||
const int token_idx = blockIdx.x;
|
|
||||||
const int slot_idx = slot_mapping[token_idx];
|
|
||||||
const int block_idx = slot_idx / block_size;
|
|
||||||
const int block_offset = slot_idx % block_size;
|
|
||||||
|
|
||||||
const int num_tokens = num_heads * head_size;
|
|
||||||
for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) {
|
|
||||||
const int tgt_key_idx = token_idx * key_stride + i;
|
|
||||||
const int tgt_value_idx = token_idx * value_stride + i;
|
|
||||||
|
|
||||||
const int head_idx = i / head_size;
|
|
||||||
const int head_offset = i % head_size;
|
|
||||||
const int x_idx = head_offset / x; // the offset of the [head_size/x] dimension
|
|
||||||
const int x_offset = head_offset % x;
|
|
||||||
|
|
||||||
const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
|
|
||||||
+ head_idx * (head_size / x) * block_size * x
|
|
||||||
+ x_idx * block_size * x
|
|
||||||
+ block_offset * x
|
|
||||||
+ x_offset;
|
|
||||||
const int src_value_idx = block_idx * num_heads * head_size * block_size
|
|
||||||
+ head_idx * head_size * block_size
|
|
||||||
+ head_offset * block_size
|
|
||||||
+ block_offset;
|
|
||||||
|
|
||||||
key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]);
|
|
||||||
value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename scalar_t>
|
|
||||||
__global__ void gather_cached_kv_kernel_optimized(
|
|
||||||
scalar_t *__restrict__ key, // [num_tokens, [stride], num_heads, head_size]
|
|
||||||
scalar_t *__restrict__ value, // [num_tokens, [stride], num_heads, head_size]
|
|
||||||
const scalar_t *__restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
|
||||||
const scalar_t *__restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
|
|
||||||
const int *__restrict__ slot_mapping, // [num_tokens]
|
|
||||||
const int key_stride,
|
|
||||||
const int value_stride,
|
|
||||||
const int num_heads,
|
|
||||||
const int head_size,
|
|
||||||
const int block_size,
|
|
||||||
const int x)
|
|
||||||
{
|
|
||||||
const int token_idx = blockIdx.x;
|
|
||||||
const int slot_idx = slot_mapping[token_idx];
|
|
||||||
const int block_idx = slot_idx / block_size;
|
|
||||||
const int block_offset = slot_idx % block_size;
|
|
||||||
|
|
||||||
const int dim = num_heads * head_size;
|
|
||||||
assert(dim % 4 == 0); // this is true for known use cases
|
|
||||||
const int unroll_factor = 4;
|
|
||||||
const int unrolled_dim = dim / unroll_factor;
|
|
||||||
|
|
||||||
for (int i = threadIdx.x; i < unrolled_dim; i += blockDim.x)
|
|
||||||
{
|
|
||||||
int tgt_key_indices[unroll_factor];
|
|
||||||
int tgt_value_indices[unroll_factor];
|
|
||||||
int src_key_indices[unroll_factor];
|
|
||||||
int src_value_indices[unroll_factor];
|
|
||||||
scalar_t keys_to_store[unroll_factor];
|
|
||||||
scalar_t values_to_store[unroll_factor];
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int j = 0; j < unroll_factor; ++j)
|
|
||||||
{
|
|
||||||
int index = i + j * unrolled_dim;
|
|
||||||
|
|
||||||
const int tgt_key_idx = token_idx * key_stride + index;
|
|
||||||
const int tgt_value_idx = token_idx * value_stride + index;
|
|
||||||
|
|
||||||
const int head_idx = index / head_size;
|
|
||||||
const int head_offset = index % head_size;
|
|
||||||
const int x_idx = head_offset / x;
|
|
||||||
const int x_offset = head_offset % x;
|
|
||||||
|
|
||||||
const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
|
|
||||||
+ head_idx * (head_size / x) * block_size * x
|
|
||||||
+ x_idx * block_size * x
|
|
||||||
+ block_offset * x
|
|
||||||
+ x_offset;
|
|
||||||
const int src_value_idx = block_idx * num_heads * head_size * block_size
|
|
||||||
+ head_idx * head_size * block_size
|
|
||||||
+ head_offset * block_size
|
|
||||||
+ block_offset;
|
|
||||||
|
|
||||||
tgt_key_indices[j] = tgt_key_idx;
|
|
||||||
tgt_value_indices[j] = tgt_value_idx;
|
|
||||||
src_key_indices[j] = src_key_idx;
|
|
||||||
src_value_indices[j] = src_value_idx;
|
|
||||||
|
|
||||||
keys_to_store[j] = __ldg(&key_cache[src_key_idx]);
|
|
||||||
values_to_store[j] = __ldg(&value_cache[src_value_idx]);
|
|
||||||
}
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int j = 0; j < unroll_factor; ++j)
|
|
||||||
{
|
|
||||||
key[tgt_key_indices[j]] = keys_to_store[j];
|
|
||||||
value[tgt_value_indices[j]] = values_to_store[j];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
void gather_cached_kv(
|
#define CALL_CONVERT_FP8_E5M2(Tout, Tin) \
|
||||||
torch::Tensor& key, // [out] [num_tokens, num_heads, head_size]
|
vllm::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
|
||||||
torch::Tensor& value, // [out] [num_tokens, num_heads, head_size]
|
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
|
||||||
torch::Tensor& key_cache, // [in] [num_blocks, num_heads, head_size/x, block_size, x]
|
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
|
||||||
torch::Tensor& value_cache, // [in] [num_blocks, num_heads, head_size, block_size]
|
block_stride);
|
||||||
torch::Tensor& slot_mapping) // [in] [num_tokens]
|
|
||||||
|
void convert_fp8_e5m2(
|
||||||
|
torch::Tensor& src_cache,
|
||||||
|
torch::Tensor& dst_cache)
|
||||||
{
|
{
|
||||||
int num_tokens = key.size(0);
|
int64_t num_blocks = src_cache.size(0);
|
||||||
int num_heads = key.size(1);
|
int64_t block_stride = src_cache.stride(0);
|
||||||
int head_size = key.size(2);
|
|
||||||
int block_size = key_cache.size(3);
|
|
||||||
int x = key_cache.size(4);
|
|
||||||
|
|
||||||
int key_stride = key.stride(0);
|
dim3 grid(num_blocks);
|
||||||
int value_stride = value.stride(0);
|
dim3 block(std::min(block_stride, int64_t(512)));
|
||||||
|
|
||||||
dim3 grid(num_tokens);
|
|
||||||
dim3 block(std::min(num_heads * head_size, 512));
|
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(
|
|
||||||
key.scalar_type(),
|
if (src_cache.dtype() == at::ScalarType::Float) {
|
||||||
"gather_cached_kv_kernel_optimized",
|
CALL_CONVERT_FP8_E5M2(uint8_t, float);
|
||||||
[&] {
|
} else if (src_cache.dtype() == at::ScalarType::Half) {
|
||||||
vllm::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
|
CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
|
||||||
key.data_ptr<scalar_t>(),
|
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
|
||||||
value.data_ptr<scalar_t>(),
|
CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
|
||||||
key_cache.data_ptr<scalar_t>(),
|
} else if (dst_cache.dtype() == at::ScalarType::Float) {
|
||||||
value_cache.data_ptr<scalar_t>(),
|
CALL_CONVERT_FP8_E5M2(float, uint8_t);
|
||||||
slot_mapping.data_ptr<int>(),
|
} else if (dst_cache.dtype() == at::ScalarType::Half) {
|
||||||
key_stride,
|
CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
|
||||||
value_stride,
|
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
|
||||||
num_heads,
|
CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
|
||||||
head_size,
|
}
|
||||||
block_size,
|
|
||||||
x);
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|||||||
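The gather kernels removed above address the key cache through its `[num_blocks, num_heads, head_size/x, block_size, x]` layout. The following host-side sketch restates that index arithmetic in plain C++ so it can be checked without a GPU. `KeyCacheShape` and `key_cache_index` are hypothetical names introduced for illustration only; they do not appear in the diff.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not part of the diff): shape of the paged key cache,
// laid out as [num_blocks, num_heads, head_size / x, block_size, x].
struct KeyCacheShape {
  int num_heads;
  int head_size;
  int block_size;
  int x;
};

// Flat offset into the key cache of element `i` (0 <= i < num_heads * head_size)
// for the token stored in slot `slot_idx`. Mirrors src_key_idx in the kernels.
inline int64_t key_cache_index(const KeyCacheShape& s, int slot_idx, int i) {
  assert(s.head_size % s.x == 0);
  const int block_idx = slot_idx / s.block_size;     // which cache block
  const int block_offset = slot_idx % s.block_size;  // position inside the block
  const int head_idx = i / s.head_size;
  const int head_offset = i % s.head_size;
  const int x_idx = head_offset / s.x;     // index along the head_size / x dimension
  const int x_offset = head_offset % s.x;  // index along the innermost x dimension
  const int64_t chunks = s.head_size / s.x;
  // Horner form of: block_idx * H * chunks * B * x + head_idx * chunks * B * x
  //                 + x_idx * B * x + block_offset * x + x_offset
  return (((static_cast<int64_t>(block_idx) * s.num_heads + head_idx) * chunks +
           x_idx) * s.block_size + block_offset) * s.x + x_offset;
}
```

With 2 heads, a head size of 8, a block size of 4 and `x = 4`, slot 5 and element 13 map to block 1, block offset 1, head 1, `x_idx` 1 and `x_offset` 1, which gives offset 64 + 32 + 16 + 4 + 1 = 117.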
28
csrc/cuda_compat.h
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
#define VLLM_LDG(arg) __ldg(arg)
|
||||||
|
#else
|
||||||
|
#define VLLM_LDG(arg) *(arg)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
|
||||||
|
#else
|
||||||
|
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
#define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
|
||||||
|
#else
|
||||||
|
#define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
|
||||||
|
cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
|
||||||
|
#else
|
||||||
|
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
|
||||||
|
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
|
||||||
|
#endif
|
||||||
|
|
||||||
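`VLLM_SHFL_XOR_SYNC` wraps `__shfl_xor_sync` on CUDA and `__shfl_xor` on ROCm. The diff does not show a call site, so the sketch below is only an assumption about typical usage: a butterfly sum across a warp, modeled on the host so it runs without a GPU. The names `kWarpSize`, `shfl_xor_all` and `warp_all_sum` are hypothetical, and the 32-lane width is assumed (ROCm wavefronts can be wider).

```cpp
#include <array>
#include <cassert>

constexpr int kWarpSize = 32;  // assumption: 32 lanes per warp

// Host model of one XOR shuffle step: every lane reads the value
// currently held by lane (lane ^ lane_mask).
inline std::array<float, kWarpSize> shfl_xor_all(
    const std::array<float, kWarpSize>& v, int lane_mask) {
  std::array<float, kWarpSize> out{};
  for (int lane = 0; lane < kWarpSize; ++lane) {
    out[lane] = v[lane ^ lane_mask];
  }
  return out;
}

// Butterfly reduction: after log2(kWarpSize) steps every lane holds the sum.
// On the device, the inner exchange would be a single
// VLLM_SHFL_XOR_SYNC(val, mask) call per lane.
inline std::array<float, kWarpSize> warp_all_sum(
    std::array<float, kWarpSize> v) {
  for (int mask = kWarpSize / 2; mask >= 1; mask /= 2) {
    const auto other = shfl_xor_all(v, mask);
    for (int lane = 0; lane < kWarpSize; ++lane) {
      v[lane] += other[lane];
    }
  }
  return v;
}
```

Each step halves the distance between exchanging lanes, so five steps suffice for 32 lanes and no shared memory is involved.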
@@ -1,13 +0,0 @@
|
|||||||
#include <torch/extension.h>
|
|
||||||
|
|
||||||
int get_device_attribute(
|
|
||||||
int attribute,
|
|
||||||
int device_id);
|
|
||||||
|
|
||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|
||||||
m.def(
|
|
||||||
"get_device_attribute",
|
|
||||||
&get_device_attribute,
|
|
||||||
"Gets the specified device attribute.");
|
|
||||||
}
|
|
||||||
|
|
||||||
10
csrc/cuda_utils.h
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
int get_device_attribute(
|
||||||
|
int attribute,
|
||||||
|
int device_id);
|
||||||
|
|
||||||
|
int get_max_shared_memory_per_block_device_attribute(
|
||||||
|
int device_id);
|
||||||
@@ -1,3 +1,7 @@
|
|||||||
|
#ifdef USE_ROCM
|
||||||
|
#include <hip/hip_runtime.h>
|
||||||
|
#include <hip/hip_runtime_api.h>
|
||||||
|
#endif
|
||||||
int get_device_attribute(
|
int get_device_attribute(
|
||||||
int attribute,
|
int attribute,
|
||||||
int device_id)
|
int device_id)
|
||||||
@@ -12,3 +16,20 @@ int get_device_attribute(
|
|||||||
cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
|
cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
int get_max_shared_memory_per_block_device_attribute(
|
||||||
|
int device_id)
|
||||||
|
{
|
||||||
|
int attribute;
|
||||||
|
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
|
||||||
|
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
|
||||||
|
|
||||||
|
#ifdef USE_ROCM
|
||||||
|
attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
|
||||||
|
#else
|
||||||
|
attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
return get_device_attribute(attribute, device_id);
|
||||||
|
}
|
||||||
|
|||||||
148
csrc/custom_all_reduce.cu
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
#include <ATen/cuda/Exceptions.h>
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
#include <c10/cuda/CUDAStream.h>
|
||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
#include "custom_all_reduce.cuh"
|
||||||
|
|
||||||
|
// fake pointer type
|
||||||
|
using fptr_t = uint64_t;
|
||||||
|
static_assert(sizeof(void *) == sizeof(fptr_t));
|
||||||
|
|
||||||
|
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
|
||||||
|
const std::vector<std::string> &handles,
|
||||||
|
const std::vector<int64_t> &offsets, int rank,
|
||||||
|
bool full_nvlink) {
|
||||||
|
int world_size = offsets.size();
|
||||||
|
if (world_size > 8)
|
||||||
|
throw std::invalid_argument("world size > 8 is not supported");
|
||||||
|
if (world_size % 2 != 0)
|
||||||
|
throw std::invalid_argument("Odd num gpus is not supported for now");
|
||||||
|
if (world_size != handles.size())
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"handles length should equal to offsets length");
|
||||||
|
if (rank < 0 || rank >= world_size)
|
||||||
|
throw std::invalid_argument("invalid rank passed in");
|
||||||
|
|
||||||
|
cudaIpcMemHandle_t ipc_handles[8];
|
||||||
|
for (int i = 0; i < world_size; i++) {
|
||||||
|
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
|
||||||
|
}
|
||||||
|
return (fptr_t) new vllm::CustomAllreduce(
|
||||||
|
reinterpret_cast<vllm::Metadata *>(meta.data_ptr()), rank_data.data_ptr(),
|
||||||
|
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Make sure tensor t's data lies completely within ((char)t.data_ptr()) +
|
||||||
|
* t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
|
||||||
|
* because it allows transpose of contiguous slice (i.e. slicing the first
|
||||||
|
* dimension). Currently, we require this because stride information is not
|
||||||
|
* passed into the kernels and we treat input tensors as flat.
|
||||||
|
*
|
||||||
|
* Examples
|
||||||
|
* A = torch.zeros(3, 3, 3)
|
||||||
|
* 1. A: OK
|
||||||
|
* 2. A[1:]: OK
|
||||||
|
* 3. A.permute(2, 0, 1): OK
|
||||||
|
* 4. A[1:].permute(2, 0, 1): OK
|
||||||
|
* 5. A[None].expand(2, -1, -1, -1): Not OK
|
||||||
|
* 6. A[:, 1:, 1:]: Not OK
|
||||||
|
*/
|
||||||
|
bool _is_weak_contiguous(torch::Tensor &t) {
|
||||||
|
return t.is_contiguous() ||
|
||||||
|
(t.storage().nbytes() - t.storage_offset() * t.element_size() ==
|
||||||
|
t.numel() * t.element_size());
|
||||||
|
}
|
||||||
|
|
||||||
|
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
|
||||||
|
bool full_nvlink) {
|
||||||
|
auto inp_size = inp.numel() * inp.element_size();
|
||||||
|
// custom allreduce requires input byte size to be multiples of 16
|
||||||
|
if (inp_size % 16 != 0) return false;
|
||||||
|
if (!_is_weak_contiguous(inp)) return false;
|
||||||
|
if (world_size == 2 || full_nvlink) return inp_size <= max_size;
|
||||||
|
// 4 PCIE GPUs use 2 stage allreduce, and is only faster than NCCL when size
|
||||||
|
// <= 512k
|
||||||
|
return world_size <= 4 && inp_size <= 512 * 1024;
|
||||||
|
}
|
||||||
|
|
||||||
|
void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
|
||||||
|
cudaStream_t stream) {
|
||||||
|
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||||
|
TORCH_CHECK(_is_weak_contiguous(out));
|
||||||
|
switch (out.scalar_type()) {
|
||||||
|
case at::ScalarType::Float: {
|
||||||
|
fa->allreduce<float>(stream, reinterpret_cast<float *>(inp.data_ptr()),
|
||||||
|
reinterpret_cast<float *>(out.data_ptr()),
|
||||||
|
out.numel());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case at::ScalarType::Half: {
|
||||||
|
fa->allreduce<half>(stream, reinterpret_cast<half *>(inp.data_ptr()),
|
||||||
|
reinterpret_cast<half *>(out.data_ptr()),
|
||||||
|
out.numel());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||||
|
case at::ScalarType::BFloat16: {
|
||||||
|
fa->allreduce<nv_bfloat16>(
|
||||||
|
stream, reinterpret_cast<nv_bfloat16 *>(inp.data_ptr()),
|
||||||
|
reinterpret_cast<nv_bfloat16 *>(out.data_ptr()), out.numel());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
default:
|
||||||
|
throw std::runtime_error(
|
||||||
|
"custom allreduce only supports float32, float16 and bfloat16");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
|
||||||
|
auto stream = c10::cuda::getCurrentCUDAStream().stream();
|
||||||
|
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
|
||||||
|
TORCH_CHECK_EQ(inp.numel(), out.numel());
|
||||||
|
_all_reduce(_fa, inp, out, stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
|
||||||
|
torch::Tensor &out) {
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
|
||||||
|
auto stream = c10::cuda::getCurrentCUDAStream().stream();
|
||||||
|
|
||||||
|
auto input_size = inp.numel() * inp.element_size();
|
||||||
|
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
|
||||||
|
TORCH_CHECK_EQ(inp.numel(), out.numel());
|
||||||
|
TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
|
||||||
|
"registered buffer is too small to contain the input");
|
||||||
|
AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
|
||||||
|
input_size, cudaMemcpyDeviceToDevice, stream));
|
||||||
|
_all_reduce(_fa, reg_buffer, out, stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
void dispose(fptr_t _fa) {
|
||||||
|
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||||
|
delete fa;
|
||||||
|
}
|
||||||
|
|
||||||
|
int meta_size() { return sizeof(vllm::Metadata); }
|
||||||
|
|
||||||
|
void register_buffer(fptr_t _fa, torch::Tensor &t,
|
||||||
|
const std::vector<std::string> &handles,
|
||||||
|
const std::vector<int64_t> &offsets) {
|
||||||
|
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||||
|
fa->register_buffer(handles, offsets, t.data_ptr());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
|
||||||
|
fptr_t _fa) {
|
||||||
|
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||||
|
return fa->get_graph_buffer_ipc_meta();
|
||||||
|
}
|
||||||
|
|
||||||
|
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
|
||||||
|
const std::vector<std::vector<int64_t>> &offsets) {
|
||||||
|
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||||
|
fa->register_graph_buffers(handles, offsets);
|
||||||
|
}
|
||||||
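The eligibility rules in `should_custom_ar` can be restated without any tensor types. The sketch below takes the byte size and the weak-contiguity result as plain arguments; `should_use_custom_allreduce` is a hypothetical stand-in written for illustration, not a function from the diff.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical, tensor-free restatement of should_custom_ar.
// inp_bytes       : numel * element_size of the input
// weak_contiguous : result of the _is_weak_contiguous check
inline bool should_use_custom_allreduce(int64_t inp_bytes, bool weak_contiguous,
                                        int64_t max_size, int world_size,
                                        bool full_nvlink) {
  // The kernels move 16-byte packs, so the byte size must be a multiple of 16.
  if (inp_bytes % 16 != 0) return false;
  if (!weak_contiguous) return false;
  // Two GPUs, or a fully NVLink-connected group: limited only by max_size.
  if (world_size == 2 || full_nvlink) return inp_bytes <= max_size;
  // Otherwise (PCIe): at most 4 GPUs and at most 512 KiB.
  return world_size <= 4 && inp_bytes <= 512 * 1024;
}
```

For example, a 1 MiB input on four PCIe GPUs is rejected, while the same input on a fully NVLink-connected group is accepted as long as it fits within `max_size`.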
562
csrc/custom_all_reduce.cuh
Normal file
@@ -0,0 +1,562 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_bf16.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <limits>
|
||||||
|
#include <map>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#define CUDACHECK(cmd) \
|
||||||
|
do { \
|
||||||
|
cudaError_t e = cmd; \
|
||||||
|
if (e != cudaSuccess) { \
|
||||||
|
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
|
||||||
|
cudaGetErrorString(e)); \
|
||||||
|
exit(EXIT_FAILURE); \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
|
||||||
|
struct Signal {
|
||||||
|
alignas(64) union {
|
||||||
|
uint64_t flag;
|
||||||
|
unsigned char data[8];
|
||||||
|
} start;
|
||||||
|
alignas(64) union {
|
||||||
|
uint64_t flag;
|
||||||
|
unsigned char data[8];
|
||||||
|
} end;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Metadata {
|
||||||
|
alignas(128) Signal sg;
|
||||||
|
alignas(128) int counter;
|
||||||
|
};
|
||||||
|
static_assert(offsetof(Metadata, counter) == 128);
|
||||||
|
static_assert(sizeof(Metadata) == 256);
|
||||||
|
|
||||||
|
struct __align__(16) RankData { const void *__restrict__ ptrs[8]; };
|
||||||
|
|
||||||
|
struct RankSignals {
|
||||||
|
volatile Signal *signals[8];
|
||||||
|
};
|
||||||
|
|
||||||
|
// like std::array, but aligned
|
||||||
|
template <typename T, int sz>
|
||||||
|
struct __align__(alignof(T) * sz) array_t {
|
||||||
|
T data[sz];
|
||||||
|
using type = T;
|
||||||
|
static constexpr int size = sz;
|
||||||
|
};
|
||||||
|
|
||||||
|
// use packed type to maximize memory efficiency
|
||||||
|
// goal: generate ld.128 and st.128 instructions
|
||||||
|
template <typename T>
|
||||||
|
struct packed_t {
|
||||||
|
// the (P)acked type for load/store
|
||||||
|
using P = array_t<T, 16 / sizeof(T)>;
|
||||||
|
// the (A)ccumulator type for reduction
|
||||||
|
using A = array_t<float, 16 / sizeof(T)>;
|
||||||
|
};
|
||||||
|
|
||||||
|
#define DINLINE __device__ __forceinline__
|
||||||
|
|
||||||
|
// scalar cast functions
|
||||||
|
DINLINE float upcast_s(half val) { return __half2float(val); }
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
DINLINE T downcast_s(float val);
|
||||||
|
template <>
|
||||||
|
DINLINE half downcast_s(float val) {
|
||||||
|
return __float2half(val);
|
||||||
|
}
|
||||||
|
|
||||||
|
// scalar add functions
|
||||||
|
// for some reason when compiling with Pytorch, the + operator for half and
|
||||||
|
// bfloat is disabled so we call the intrinsics directly
|
||||||
|
DINLINE half &assign_add(half &a, half b) {
|
||||||
|
a = __hadd(a, b);
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
DINLINE float &assign_add(float &a, float b) { return a += b; }
|
||||||
|
|
||||||
|
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||||
|
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
|
||||||
|
template <>
|
||||||
|
DINLINE nv_bfloat16 downcast_s(float val) {
|
||||||
|
return __float2bfloat16(val);
|
||||||
|
}
|
||||||
|
DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) {
|
||||||
|
a = __hadd(a, b);
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
template <typename T, int N>
|
||||||
|
DINLINE array_t<T, N> &packed_assign_add(array_t<T, N> &a, array_t<T, N> b) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N; i++) {
|
||||||
|
assign_add(a.data[i], b.data[i]);
|
||||||
|
}
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, int N>
|
||||||
|
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
|
||||||
|
if constexpr (std::is_same<T, float>::value) {
|
||||||
|
return val;
|
||||||
|
} else {
|
||||||
|
array_t<float, N> out;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N; i++) {
|
||||||
|
out.data[i] = upcast_s(val.data[i]);
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename O>
|
||||||
|
DINLINE O downcast(array_t<float, O::size> val) {
|
||||||
|
if constexpr (std::is_same<typename O::type, float>::value) {
|
||||||
|
return val;
|
||||||
|
} else {
|
||||||
|
O out;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < O::size; i++) {
|
||||||
|
out.data[i] = downcast_s<typename O::type>(val.data[i]);
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// compute flag at compile time
|
||||||
|
__host__ __device__ constexpr uint64_t compute_flag(int ngpus) {
|
||||||
|
auto m = std::numeric_limits<uint64_t>::max();
|
||||||
|
return m >> ((8 - ngpus) * 8);
|
||||||
|
}
|
||||||
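`compute_flag` yields the value the 8-byte signal word reaches once each of the `ngpus` ranks has written 255 into its own byte. The host-side sketch below checks that relationship. `flag_after_ranks_signal` is a hypothetical helper, and the comparison assumes a little-endian host, where byte `rank` of the array is byte `rank` of the integer counted from the least significant end.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <limits>

// Same arithmetic as compute_flag in the diff; valid for 1 <= ngpus <= 8.
constexpr uint64_t compute_flag(int ngpus) {
  return std::numeric_limits<uint64_t>::max() >> ((8 - ngpus) * 8);
}

// Hypothetical host model of the signal word: rank r sets data[r] = 255,
// then the 8 bytes are read back as a single 64-bit flag.
inline uint64_t flag_after_ranks_signal(int ngpus) {
  assert(ngpus >= 1 && ngpus <= 8);
  unsigned char data[8] = {0};
  for (int rank = 0; rank < ngpus; ++rank) {
    data[rank] = 255;
  }
  uint64_t flag = 0;
  std::memcpy(&flag, data, sizeof(flag));  // avoids union type punning
  return flag;
}
```

A waiting thread can therefore detect that every rank has arrived with a single 64-bit comparison instead of polling `ngpus` separate bytes.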
|
|
||||||
|
template <int ngpus>
|
||||||
|
DINLINE void start_sync(const RankSignals &sg, volatile Metadata *meta,
|
||||||
|
int rank) {
|
||||||
|
constexpr auto FLAG = compute_flag(ngpus);
|
||||||
|
if (blockIdx.x == 0) {
|
||||||
|
if (threadIdx.x < ngpus)
|
||||||
|
// simultaneously write to the corresponding byte to all other ranks.
|
||||||
|
// Latency = 1 p2p write
|
||||||
|
sg.signals[threadIdx.x]->start.data[rank] = 255;
|
||||||
|
else if (threadIdx.x == 32)
|
||||||
|
// reset
|
||||||
|
meta->sg.end.flag = 0;
|
||||||
|
}
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
while (meta->sg.start.flag != FLAG)
|
||||||
|
;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int ngpus, bool final_sync = false>
|
||||||
|
DINLINE void end_sync(const RankSignals &sg, volatile Metadata *meta,
|
||||||
|
int rank) {
|
||||||
|
constexpr auto FLAG = compute_flag(ngpus);
|
||||||
|
__syncthreads();
|
||||||
|
__shared__ int num;
|
||||||
|
if (threadIdx.x == 0) num = atomicAdd((int *)&meta->counter, 1);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// Only the last completing block can perform the end synchronization
|
||||||
|
// This ensures that when the final busy wait ends, all ranks must have
|
||||||
|
// finished reading each other's buffer.
|
||||||
|
if (num == gridDim.x - 1) {
|
||||||
|
if (threadIdx.x == 32) {
|
||||||
|
// reset in a different warp
|
||||||
|
meta->counter = 0;
|
||||||
|
meta->sg.start.flag = 0;
|
||||||
|
} else if (threadIdx.x < ngpus) {
|
||||||
|
// simultaneously write to the corresponding byte to all other ranks.
|
||||||
|
// Latency = 1 p2p write
|
||||||
|
sg.signals[threadIdx.x]->end.data[rank] = 255;
|
||||||
|
}
|
||||||
|
// if this is the final sync, only one block needs it
|
||||||
|
// because kernel exit can serve as sync
|
||||||
|
if constexpr (final_sync) {
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
while (meta->sg.end.flag != FLAG)
|
||||||
|
;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if constexpr (!final_sync) {
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
while (meta->sg.end.flag != FLAG)
|
||||||
|
;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename P, int ngpus, typename A>
|
||||||
|
DINLINE P packed_reduce(const P *ptrs[], int idx) {
|
||||||
|
A tmp = upcast(ptrs[0][idx]);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 1; i < ngpus; i++) {
|
||||||
|
packed_assign_add(tmp, upcast(ptrs[i][idx]));
|
||||||
|
}
|
||||||
|
return downcast<P>(tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, int ngpus>
|
||||||
|
__global__ void __launch_bounds__(512, 1)
|
||||||
|
cross_device_reduce_1stage(RankData *_dp, RankSignals sg,
|
||||||
|
volatile Metadata *meta, T *__restrict__ result,
|
||||||
|
int rank, int size) {
|
||||||
|
using P = typename packed_t<T>::P;
|
||||||
|
using A = typename packed_t<T>::A;
|
||||||
|
// note: we don't reorder the address so the accumulation order is the same
|
||||||
|
// for all ranks, ensuring bitwise identical results
|
||||||
|
auto dp = *_dp;
|
||||||
|
start_sync<ngpus>(sg, meta, rank);
|
||||||
|
// do the actual reduction
|
||||||
|
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||||
|
idx += gridDim.x * blockDim.x) {
|
||||||
|
((P *)result)[idx] =
|
||||||
|
packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
|
||||||
|
}
|
||||||
|
end_sync<ngpus, true>(sg, meta, rank);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename P>
|
||||||
|
DINLINE P *get_tmp_buf(volatile Signal *sg) {
|
||||||
|
return (P *)(((Metadata *)sg) + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, int ngpus>
|
||||||
|
__global__ void __launch_bounds__(512, 1)
|
||||||
|
cross_device_reduce_2stage(RankData *_dp, RankSignals sg,
|
||||||
|
volatile Metadata *meta, T *__restrict__ result,
|
||||||
|
int rank, int size) {
|
||||||
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
int stride = gridDim.x * blockDim.x;
|
||||||
|
using P = typename packed_t<T>::P;
|
||||||
|
using A = typename packed_t<T>::A;
|
||||||
|
int part = size / ngpus;
|
||||||
|
int start = rank * part;
|
||||||
|
int end = rank == ngpus - 1 ? size : start + part;
|
||||||
|
const P *ptrs[ngpus];
|
||||||
|
P *tmps[ngpus];
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < ngpus; i++) {
|
||||||
|
int target = (rank + i) % ngpus;
|
||||||
|
ptrs[i] = (const P *)_dp->ptrs[target];
|
||||||
|
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
|
||||||
|
}
|
||||||
|
auto tmp_out = tmps[0];
|
||||||
|
start_sync<ngpus>(sg, meta, rank);
|
||||||
|
// stage 1: reduce scatter
|
||||||
|
for (int idx = start + tid; idx < end; idx += stride) {
|
||||||
|
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
|
||||||
|
}
|
||||||
|
// Maybe TODO: replace this with per-block release-acquire
|
||||||
|
// can save about 1-2us (not a lot though)
|
||||||
|
end_sync<ngpus>(sg, meta, rank);
|
||||||
|
|
||||||
|
// stage 2: allgather
|
||||||
|
for (int idx = tid; idx < part; idx += stride) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < ngpus; i++) {
|
||||||
|
int dst_idx = ((rank + i) % ngpus) * part + idx;
|
||||||
|
((P *)result)[dst_idx] = tmps[i][idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// process the last larger partition
|
||||||
|
int remaining = size - part * ngpus;
|
||||||
|
if (tid < remaining) {
|
||||||
|
int dst_idx = tid + part * ngpus;
|
||||||
|
((P *)result)[dst_idx] = get_tmp_buf<P>(sg.signals[ngpus - 1])[part + tid];
|
||||||
|
}
|
||||||
|
|
||||||
|
// faster than this
|
||||||
|
// for (int idx = tid; idx < size; idx += stride) {
|
||||||
|
// int target_rank = idx / part;
|
||||||
|
// if (target_rank == ngpus) target_rank -= 1;
|
||||||
|
// ((P *)result)[idx] = tmps[target_rank][idx - target_rank * part];
|
||||||
|
// }
|
||||||
|
}
|
||||||
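The two-stage kernel splits the packed elements into `ngpus` partitions, and the last rank also takes the remainder when `size` is not divisible by `ngpus`. The sketch below reproduces only that range computation on the host; `reduce_scatter_range` is a hypothetical helper name, not part of the diff.

```cpp
#include <cassert>
#include <utility>

// Hypothetical helper mirroring part/start/end in cross_device_reduce_2stage.
// Returns the half-open range [start, end) of packed elements that `rank`
// reduces during stage 1 (reduce-scatter).
inline std::pair<int, int> reduce_scatter_range(int rank, int ngpus, int size) {
  assert(ngpus > 0 && rank >= 0 && rank < ngpus);
  const int part = size / ngpus;
  const int start = rank * part;
  // The last rank absorbs the remainder, so its partition may be larger.
  const int end = (rank == ngpus - 1) ? size : start + part;
  return {start, end};
}
```

With `size = 10` and `ngpus = 4`, ranks 0 to 2 each reduce two elements and rank 3 reduces elements 6 through 9, which is why stage 2 handles the last, larger partition separately.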
|
|
||||||
|
template <typename T, int ngpus>
|
||||||
|
__global__ void __launch_bounds__(512, 1)
|
||||||
|
cross_device_reduce_half_butterfly(RankData *_dp, RankSignals sg,
|
||||||
|
volatile Metadata *meta,
|
||||||
|
T *__restrict__ result, int rank,
|
||||||
|
int size) {
|
||||||
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
int stride = gridDim.x * blockDim.x;
|
||||||
|
using P = typename packed_t<T>::P;
|
||||||
|
using A = typename packed_t<T>::A;
|
||||||
|
auto tmp_out = get_tmp_buf<P>(sg.signals[rank]);
|
||||||
|
constexpr int hg = ngpus / 2;
|
||||||
|
// Actually not quite half butterfly.
|
||||||
|
// This is an all-to-all within each group containing half of the ranks
|
||||||
|
// followed by cross-group add. Equivalent to half butterfly when there
|
||||||
|
// are 4 GPUs, a common case for PCIe cards like T4 and A10.
|
||||||
|
const P *ptrs[hg];
|
||||||
|
{
|
||||||
|
int start = rank - rank % hg;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < hg; i++) {
|
||||||
|
ptrs[i] = (const P *)_dp->ptrs[i + start];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
start_sync<ngpus>(sg, meta, rank);
|
||||||
|
for (int idx = tid; idx < size; idx += stride) {
|
||||||
|
tmp_out[idx] = packed_reduce<P, hg, A>(ptrs, idx);
|
||||||
|
}
|
||||||
|
end_sync<ngpus>(sg, meta, rank);
|
||||||
|
|
||||||
|
auto src = get_tmp_buf<P>(sg.signals[(ngpus - 1) - rank % ngpus]);
|
||||||
|
// do the cross group reduction
|
||||||
|
for (int idx = tid; idx < size; idx += stride) {
|
||||||
|
auto tmp = tmp_out[idx];
|
||||||
|
packed_assign_add(tmp, src[idx]);
|
||||||
|
((P *)result)[idx] = tmp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
|
||||||
|
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
|
||||||
|
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));
|
||||||
|
|
||||||
|
class CustomAllreduce {
|
||||||
|
public:
|
||||||
|
int rank_;
|
||||||
|
int world_size_;
|
||||||
|
bool full_nvlink_;
|
||||||
|
|
||||||
|
// below are device pointers
|
||||||
|
RankSignals sg_;
|
||||||
|
std::unordered_map<void *, RankData *> buffers_;
|
||||||
|
Metadata *meta_;
|
||||||
|
|
||||||
|
// stores the registered device pointers from all ranks
|
||||||
|
RankData *d_rank_data_base_, *d_rank_data_end_;
|
||||||
|
std::vector<void *> graph_unreg_buffers_;
|
||||||
|
// a map from IPC handles to opened IPC pointers
|
||||||
|
std::map<IPC_KEY, char *> ipc_handles_;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* meta is a pointer to device metadata and temporary buffer for allreduce.
|
||||||
|
*
|
||||||
|
* There's a total of sizeof(Metadata) of prefix before the actual data,
|
||||||
|
* so meta + 1 points to the actual temporary buffer.
|
||||||
|
*
|
||||||
|
* note: this class does not own any device memory. Any required buffers
|
||||||
|
* are passed in through the constructor.
|
||||||
|
*/
|
||||||
|
CustomAllreduce(Metadata *meta, void *rank_data, size_t rank_data_sz,
|
||||||
|
const cudaIpcMemHandle_t *handles,
|
||||||
|
const std::vector<int64_t> &offsets, int rank,
|
||||||
|
bool full_nvlink = true)
|
||||||
|
: rank_(rank),
|
||||||
|
world_size_(offsets.size()),
|
||||||
|
full_nvlink_(full_nvlink),
|
||||||
|
meta_(meta),
|
||||||
|
d_rank_data_base_(reinterpret_cast<RankData *>(rank_data)),
|
||||||
|
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
|
||||||
|
for (int i = 0; i < world_size_; i++) {
|
||||||
|
Metadata *rank_meta;
|
||||||
|
if (i != rank_) {
|
||||||
|
char *handle = open_ipc_handle(&handles[i]);
|
||||||
|
handle += offsets[i];
|
||||||
|
rank_meta = (Metadata *)handle;
|
||||||
|
} else {
|
||||||
|
rank_meta = meta_;
|
||||||
|
}
|
||||||
|
sg_.signals[i] = &rank_meta->sg;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
char *open_ipc_handle(const void *ipc_handle) {
|
||||||
|
auto [it, new_handle] =
|
||||||
|
ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr});
|
||||||
|
if (new_handle) {
|
||||||
|
char *ipc_ptr;
|
||||||
|
CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr,
|
||||||
|
*((const cudaIpcMemHandle_t *)ipc_handle),
|
||||||
|
cudaIpcMemLazyEnablePeerAccess));
|
||||||
|
it->second = ipc_ptr;
|
||||||
|
}
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<std::vector<uint8_t>, std::vector<int64_t>>
|
||||||
|
get_graph_buffer_ipc_meta() {
|
||||||
|
auto num_buffers = graph_unreg_buffers_.size();
|
||||||
|
auto handle_sz = sizeof(cudaIpcMemHandle_t);
|
||||||
|
std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
|
||||||
|
std::vector<int64_t> offsets(num_buffers);
|
||||||
|
for (int i = 0; i < num_buffers; i++) {
|
||||||
|
auto ptr = graph_unreg_buffers_[i];
|
||||||
|
void *base_ptr;
|
||||||
|
// note: must share the base address of each allocation, or we get the wrong
|
||||||
|
// address
|
||||||
|
if (cuPointerGetAttribute(&base_ptr,
|
||||||
|
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
|
||||||
|
(CUdeviceptr)ptr) != CUDA_SUCCESS)
|
||||||
|
throw std::runtime_error("failed to get pointer attr");
|
||||||
|
CUDACHECK(cudaIpcGetMemHandle(
|
||||||
|
(cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr));
|
||||||
|
offsets[i] = ((char *)ptr) - ((char *)base_ptr);
|
||||||
|
}
|
||||||
|
return std::make_pair(handles, offsets);
|
||||||
|
}
|
||||||
|
|
||||||
|
void check_rank_data_capacity(size_t num = 1) {
|
||||||
|
if (d_rank_data_base_ + num > d_rank_data_end_)
|
||||||
|
throw std::runtime_error(
|
||||||
|
"Rank data buffer is overflowed by " +
|
||||||
|
std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
|
||||||
|
}
|
||||||
|
|
||||||
|
void register_buffer(const std::vector<std::string> &handles,
|
||||||
|
const std::vector<int64_t> &offsets, void *self) {
|
||||||
|
check_rank_data_capacity();
|
||||||
|
RankData data;
|
||||||
|
for (int i = 0; i < world_size_; i++) {
|
||||||
|
if (i != rank_) {
|
||||||
|
char *handle = open_ipc_handle(handles[i].data());
|
||||||
|
handle += offsets[i];
|
||||||
|
data.ptrs[i] = handle;
|
||||||
|
} else {
|
||||||
|
data.ptrs[i] = self;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto d_data = d_rank_data_base_++;
|
||||||
|
CUDACHECK(
|
||||||
|
cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
|
||||||
|
buffers_[self] = d_data;
|
||||||
|
}
|
||||||
|
|
||||||
|
// note: when registering graph buffers, we intentionally choose to not
|
||||||
|
// deduplicate the addresses. That means if the allocator reuses some
|
||||||
|
// addresses, they will be registered again. This is to account for the remote
|
||||||
|
// possibility of different allocation patterns between ranks. For example,
|
||||||
|
// rank 1 may get the same input address for the second allreduce, but rank 2
|
||||||
|
// gets a different address. IPC handles have an internal reference counting
|
||||||
|
// mechanism so overhead should be small.
|
||||||
|
void register_graph_buffers(
|
||||||
|
const std::vector<std::string> &handles,
|
||||||
|
const std::vector<std::vector<int64_t>> &offsets) {
|
||||||
|
auto num_buffers = graph_unreg_buffers_.size();
|
||||||
|
check_rank_data_capacity(num_buffers);
|
||||||
|
std::vector<RankData> rank_data(num_buffers);
|
||||||
|
for (int i = 0; i < num_buffers; i++) {
|
||||||
|
auto self_ptr = graph_unreg_buffers_[i];
|
||||||
|
auto &rd = rank_data[i];
|
||||||
|
for (int j = 0; j < world_size_; j++) {
|
||||||
|
if (j != rank_) {
|
||||||
|
char *handle =
|
||||||
|
open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
|
||||||
|
handle += offsets[j][i];
|
||||||
|
rd.ptrs[j] = handle;
|
||||||
|
} else {
|
||||||
|
rd.ptrs[j] = self_ptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(),
|
||||||
|
sizeof(RankData) * num_buffers,
|
||||||
|
cudaMemcpyHostToDevice));
|
||||||
|
d_rank_data_base_ += num_buffers;
|
||||||
|
graph_unreg_buffers_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This is the result of a careful grid search. Using 36 blocks gives the best
|
||||||
|
* or close to the best runtime on the devices I tried: A100, A10, A30, T4,
|
||||||
|
* V100. You'll notice that NCCL kernels also only take a small number of SMs.
|
||||||
|
* Not quite sure of the underlying reason, but my guess is that too many SMs
|
||||||
|
* will cause contention on the NVLink bus.
|
||||||
|
*/
|
||||||
|
template <typename T>
|
||||||
|
void allreduce(cudaStream_t stream, T *input, T *output, int size,
|
||||||
|
int threads = 512, int block_limit = 36) {
|
||||||
|
auto d = packed_t<T>::P::size;
|
||||||
|
if (size % d != 0)
|
||||||
|
throw std::runtime_error(
|
||||||
|
"custom allreduce currently requires input length to be multiple "
|
||||||
|
"of " +
|
||||||
|
std::to_string(d));
|
||||||
|
|
||||||
|
RankData *ptrs;
|
||||||
|
cudaStreamCaptureStatus status;
|
||||||
|
CUDACHECK(cudaStreamIsCapturing(stream, &status));
|
||||||
|
if (status == cudaStreamCaptureStatusActive) {
|
||||||
|
ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
|
||||||
|
graph_unreg_buffers_.push_back(input);
|
||||||
|
} else {
|
||||||
|
auto it = buffers_.find(input);
|
||||||
|
if (it == buffers_.end())
|
||||||
|
throw std::runtime_error(
|
||||||
|
"buffer address " +
|
||||||
|
std::to_string(reinterpret_cast<uint64_t>(input)) +
|
||||||
|
" is not registered!");
|
||||||
|
ptrs = it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
size /= d;
|
||||||
|
auto bytes = size * sizeof(typename packed_t<T>::P);
|
||||||
|
int blocks = std::min(block_limit, (size + threads - 1) / threads);
|
||||||
|
#define KL(ngpus, name) \
|
||||||
|
name<T, ngpus> \
|
||||||
|
<<<blocks, threads, 0, stream>>>(ptrs, sg_, meta_, output, rank_, size);
|
||||||
|
#define REDUCE_CASE(ngpus) \
|
||||||
|
case ngpus: { \
|
||||||
|
if (world_size_ == 2) { \
|
||||||
|
KL(ngpus, cross_device_reduce_1stage); \
|
||||||
|
} else if (full_nvlink_) { \
|
||||||
|
if ((world_size_ <= 4 && bytes < 512 * 1024) || \
|
||||||
|
(world_size_ <= 8 && bytes < 256 * 1024)) { \
|
||||||
|
KL(ngpus, cross_device_reduce_1stage); \
|
||||||
|
} else { \
|
||||||
|
KL(ngpus, cross_device_reduce_2stage); \
|
||||||
|
} \
|
||||||
|
} else { \
|
||||||
|
KL(ngpus, cross_device_reduce_half_butterfly); \
|
||||||
|
} \
|
||||||
|
break; \
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (world_size_) {
|
||||||
|
REDUCE_CASE(2)
|
||||||
|
REDUCE_CASE(4)
|
||||||
|
REDUCE_CASE(6)
|
||||||
|
REDUCE_CASE(8)
|
||||||
|
default:
|
||||||
|
throw std::runtime_error(
|
||||||
|
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
|
||||||
|
"gpus = " +
|
||||||
|
std::to_string(world_size_));
|
||||||
|
}
|
||||||
|
#undef REDUCE_CASE
|
||||||
|
#undef KL
|
||||||
|
}
|
||||||
|
|
||||||
|
~CustomAllreduce() {
|
||||||
|
for (auto [_, ptr] : ipc_handles_) {
|
||||||
|
CUDACHECK(cudaIpcCloseMemHandle(ptr));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
/**
|
||||||
|
* To inspect PTX/SASS, copy and paste this header file into Compiler Explorer and add
|
||||||
|
a template instantiation:
|
||||||
|
* template void CustomAllreduce::allreduce<half>(cudaStream_t, half *, half *,
|
||||||
|
int, int, int);
|
||||||
|
*/
|
||||||
|
} // namespace vllm
|
||||||
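The kernel choice made by the REDUCE_CASE macro in allreduce() can be restated as a small host-only function. This is a minimal sketch for reading the branching, not code from the PR: the `Kernel` enum and the names `pick_kernel` and `pick_blocks` are hypothetical, while the thresholds, the supported world sizes, and the 512-thread / 36-block defaults are taken from the header above.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>

// Hypothetical labels for the three kernels launched by allreduce().
enum class Kernel { OneStage, TwoStage, HalfButterfly };

// Mirrors the REDUCE_CASE branching. `bytes` is the payload size in bytes.
inline Kernel pick_kernel(int world_size, bool full_nvlink, std::size_t bytes) {
  if (world_size != 2 && world_size != 4 && world_size != 6 && world_size != 8)
    throw std::runtime_error("unsupported world size " +
                             std::to_string(world_size));
  // Two GPUs always use the one-stage kernel, NVLink or not.
  if (world_size == 2) return Kernel::OneStage;
  // Without full NVLink connectivity, fall back to the half butterfly.
  if (!full_nvlink) return Kernel::HalfButterfly;
  // Small payloads stay on the one-stage kernel; larger ones use two stages.
  const bool small = (world_size <= 4 && bytes < 512 * 1024) ||
                     (world_size <= 8 && bytes < 256 * 1024);
  return small ? Kernel::OneStage : Kernel::TwoStage;
}

// Grid size: enough blocks to cover every packed element, capped at
// block_limit (36 by default, per the grid search described in the header).
inline int pick_blocks(int packed_size, int threads = 512,
                       int block_limit = 36) {
  assert(threads > 0 && block_limit > 0);
  return std::min(block_limit, (packed_size + threads - 1) / threads);
}
```

Here `packed_size` is the element count after the `size /= d` step, that is, the number of packed vectors rather than scalar elements.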
284
csrc/custom_all_reduce_test.cu
Normal file
@@ -0,0 +1,284 @@
|
|||||||
|
/**
|
||||||
|
* This is a standalone test for custom allreduce.
|
||||||
|
* To compile, make sure you have MPI and NCCL installed on your system.
|
||||||
|
* export MPI_HOME=XXX
|
||||||
|
* nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o
|
||||||
|
* custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi
|
||||||
|
*
|
||||||
|
* Warning: this C++ test is not designed to be very readable and was used
|
||||||
|
* during the rapid prototyping process.
|
||||||
|
*
|
||||||
|
* To run:
|
||||||
|
* mpirun -np 8 ./custom_all_reduce_test
|
||||||
|
*/
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <curand_kernel.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
|
||||||
|
#include <limits>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "cuda_profiler_api.h"
|
||||||
|
#include "custom_all_reduce.cuh"
|
||||||
|
#include "mpi.h"
|
||||||
|
#include "nccl.h"
|
||||||
|
|
||||||
|
#define MPICHECK(cmd) \
|
||||||
|
do { \
|
||||||
|
int e = cmd; \
|
||||||
|
if (e != MPI_SUCCESS) { \
|
||||||
|
printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__, e); \
|
||||||
|
exit(EXIT_FAILURE); \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
#define NCCLCHECK(cmd) \
|
||||||
|
do { \
|
||||||
|
ncclResult_t r = cmd; \
|
||||||
|
if (r != ncclSuccess) { \
|
||||||
|
printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \
|
||||||
|
ncclGetErrorString(r)); \
|
||||||
|
exit(EXIT_FAILURE); \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
__global__ void dummy_kernel() {
|
||||||
|
for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void set_data(T *data, int size, int myRank) {
|
||||||
|
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||||
|
idx += gridDim.x * blockDim.x) {
|
||||||
|
data[idx] = myRank * 0.11f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void convert_data(const T *data1, const T *data2, double *fdata1,
|
||||||
|
double *fdata2, int size) {
|
||||||
|
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||||
|
idx += gridDim.x * blockDim.x) {
|
||||||
|
fdata1[idx] = data1[idx];
|
||||||
|
fdata2[idx] = data2[idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void init_rand(curandState_t *state, int size, int nRanks) {
|
||||||
|
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||||
|
idx += gridDim.x * blockDim.x) {
|
||||||
|
for (int i = 0; i < nRanks; i++) {
|
||||||
|
curand_init(i + 1, idx, 0, &state[idx * nRanks + i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void gen_data(curandState_t *state, T *data, double *ground_truth,
|
||||||
|
int myRank, int nRanks, int size) {
|
||||||
|
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||||
|
idx += gridDim.x * blockDim.x) {
|
||||||
|
double sum = 0.0;
|
||||||
|
for (int i = 0; i < nRanks; i++) {
|
||||||
|
double val = curand_uniform_double(&state[idx * nRanks + i]) * 4;
|
||||||
|
T hval = val; // downcast first
|
||||||
|
sum += static_cast<double>(hval);
|
||||||
|
if (i == myRank) data[idx] = hval;
|
||||||
|
}
|
||||||
|
ground_truth[idx] = sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
|
||||||
|
int data_size) {
|
||||||
|
T *result;
|
||||||
|
cudaStream_t stream;
|
||||||
|
CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
|
||||||
|
CUDACHECK(cudaMalloc(&result, data_size * sizeof(T)));
|
||||||
|
CUDACHECK(cudaMemset(result, 0, data_size * sizeof(T)));
|
||||||
|
|
||||||
|
cudaIpcMemHandle_t self_data_handle;
|
||||||
|
cudaIpcMemHandle_t data_handles[8];
|
||||||
|
vllm::Metadata *buffer;
|
||||||
|
T *self_data_copy;
|
||||||
|
/**
|
||||||
|
* Allocate IPC buffer
|
||||||
|
*
|
||||||
|
* The first section is a temporary buffer for storing intermediate allreduce
|
||||||
|
* results, if a particular algorithm requires it. The second section is for
|
||||||
|
* the input to the allreduce. The actual API takes the input pointer as an
|
||||||
|
* argument (that is, they can and usually should be allocated separately).
|
||||||
|
* But since the input pointers and the temporary buffer all require IPC
|
||||||
|
* registration, they are allocated and registered together in the test for
|
||||||
|
* convenience.
|
||||||
|
*/
|
||||||
|
CUDACHECK(
|
||||||
|
cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Metadata)));
|
||||||
|
CUDACHECK(cudaMemset(buffer, 0,
|
||||||
|
2 * data_size * sizeof(T) + sizeof(vllm::Metadata)));
|
||||||
|
CUDACHECK(cudaMalloc(&self_data_copy, data_size * sizeof(T)));
|
||||||
|
CUDACHECK(cudaIpcGetMemHandle(&self_data_handle, buffer));
|
||||||
|
|
||||||
|
MPICHECK(MPI_Allgather(&self_data_handle, sizeof(cudaIpcMemHandle_t),
|
||||||
|
MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t),
|
||||||
|
MPI_BYTE, MPI_COMM_WORLD));
|
||||||
|
|
||||||
|
void *rank_data;
|
||||||
|
size_t rank_data_sz = 16 * 1024 * 1024;
|
||||||
|
CUDACHECK(cudaMalloc(&rank_data, rank_data_sz));
|
||||||
|
std::vector<int64_t> offsets(nRanks, 0);
|
||||||
|
vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
|
||||||
|
offsets, myRank);
|
||||||
|
auto *self_data =
|
||||||
|
reinterpret_cast<T *>(reinterpret_cast<char *>(buffer) +
|
||||||
|
sizeof(vllm::Metadata) + data_size * sizeof(T));
|
||||||
|
// hack buffer registration
|
||||||
|
{
|
||||||
|
std::vector<std::string> handles;
|
||||||
|
handles.reserve(nRanks);
|
||||||
|
for (int i = 0; i < nRanks; i++) {
|
||||||
|
char *begin = (char *)&data_handles[i];
|
||||||
|
char *end = (char *)&data_handles[i + 1];
|
||||||
|
handles.emplace_back(begin, end);
|
||||||
|
}
|
||||||
|
std::vector<int64_t> offsets(
|
||||||
|
nRanks, sizeof(vllm::Metadata) + data_size * sizeof(T));
|
||||||
|
fa.register_buffer(handles, offsets, self_data);
|
||||||
|
}
|
||||||
|
|
||||||
|
double *ground_truth;
|
||||||
|
CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double)));
|
||||||
|
curandState_t *states;
|
||||||
|
CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size));
|
||||||
|
init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks);
|
||||||
|
gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank,
|
||||||
|
nRanks, data_size);
|
||||||
|
CUDACHECK(cudaMemcpyAsync(self_data_copy, self_data, data_size * sizeof(T),
|
||||||
|
cudaMemcpyDeviceToDevice, stream));
|
||||||
|
cudaEvent_t start, stop;
|
||||||
|
CUDACHECK(cudaEventCreate(&start));
|
||||||
|
CUDACHECK(cudaEventCreate(&stop));
|
||||||
|
|
||||||
|
ncclDataType_t ncclDtype;
|
||||||
|
if (std::is_same<T, half>::value) {
|
||||||
|
ncclDtype = ncclFloat16;
|
||||||
|
} else if (std::is_same<T, nv_bfloat16>::value) {
|
||||||
|
ncclDtype = ncclBfloat16;
|
||||||
|
} else {
|
||||||
|
ncclDtype = ncclFloat;
|
||||||
|
}
|
||||||
|
|
||||||
|
dummy_kernel<<<1, 1, 0, stream>>>();
|
||||||
|
constexpr int warmup_iters = 5;
|
||||||
|
constexpr int num_iters = 25;
|
||||||
|
// warmup
|
||||||
|
for (int i = 0; i < warmup_iters; i++) {
|
||||||
|
NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm,
|
||||||
|
stream));
|
||||||
|
}
|
||||||
|
CUDACHECK(cudaEventRecord(start, stream));
|
||||||
|
for (int i = 0; i < num_iters; i++) {
|
||||||
|
NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm,
|
||||||
|
stream));
|
||||||
|
}
|
||||||
|
CUDACHECK(cudaEventRecord(stop, stream));
|
||||||
|
CUDACHECK(cudaStreamSynchronize(stream));
|
||||||
|
float allreduce_ms = 0;
|
||||||
|
CUDACHECK(cudaEventElapsedTime(&allreduce_ms, start, stop));
|
||||||
|
|
||||||
|
// if (myRank == 1) dummy_kernel<<<1, 1, 0, stream>>>();
|
||||||
|
// set_data<T><<<16, 1024, 0, stream>>>(self_data, data_size, myRank);
|
||||||
|
|
||||||
|
dummy_kernel<<<1, 1, 0, stream>>>();
|
||||||
|
// warm up
|
||||||
|
for (int i = 0; i < warmup_iters; i++) {
|
||||||
|
fa.allreduce<T>(stream, self_data, result, data_size, threads, block_limit);
|
||||||
|
}
|
||||||
|
CUDACHECK(cudaEventRecord(start, stream));
|
||||||
|
for (int i = 0; i < num_iters; i++) {
|
||||||
|
fa.allreduce<T>(stream, self_data, result, data_size, threads, block_limit);
|
||||||
|
}
|
||||||
|
CUDACHECK(cudaEventRecord(stop, stream));
|
||||||
|
CUDACHECK(cudaStreamSynchronize(stream));
|
||||||
|
|
||||||
|
float duration_ms = 0;
|
||||||
|
CUDACHECK(cudaEventElapsedTime(&duration_ms, start, stop));
|
||||||
|
if (myRank == 0)
|
||||||
|
printf(
|
||||||
|
"Rank %d done, nGPUs:%d, sz (kb): %zu, %d, %d, my time:%.2fus, nccl "
|
||||||
|
"time:%.2fus\n",
|
||||||
|
myRank, nRanks, data_size * sizeof(T) / 1024, threads, block_limit,
|
||||||
|
duration_ms * 1e3 / num_iters, allreduce_ms * 1e3 / num_iters);
|
||||||
|
|
||||||
|
// And wait for all the queued up work to complete
|
||||||
|
CUDACHECK(cudaStreamSynchronize(stream));
|
||||||
|
|
||||||
|
NCCLCHECK(ncclAllReduce(self_data_copy, self_data, data_size, ncclDtype,
|
||||||
|
ncclSum, comm, stream));
|
||||||
|
|
||||||
|
double *nccl_result, *my_result;
|
||||||
|
CUDACHECK(cudaMallocHost(&nccl_result, data_size * sizeof(double)));
|
||||||
|
CUDACHECK(cudaMallocHost(&my_result, data_size * sizeof(double)));
|
||||||
|
|
||||||
|
convert_data<T><<<108, 1024, 0, stream>>>(self_data, result, nccl_result,
|
||||||
|
my_result, data_size);
|
||||||
|
CUDACHECK(cudaStreamSynchronize(stream));
|
||||||
|
|
||||||
|
for (unsigned long j = 0; j < data_size; j++) {
|
||||||
|
auto diff = abs(nccl_result[j] - my_result[j]);
|
||||||
|
if (diff >= 1e-2) {
|
||||||
|
printf("Rank %d: Verification mismatch at %lu: %f != (my) %f, gt=%f\n",
|
||||||
|
myRank, j, nccl_result[j], my_result[j], ground_truth[j]);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
long double nccl_diffs = 0.0;
|
||||||
|
long double my_diffs = 0.0;
|
||||||
|
for (int j = 0; j < data_size; j++) {
|
||||||
|
nccl_diffs += abs(nccl_result[j] - ground_truth[j]);
|
||||||
|
my_diffs += abs(my_result[j] - ground_truth[j]);
|
||||||
|
}
|
||||||
|
if (myRank == 0)
|
||||||
|
std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size
|
||||||
|
<< " me: " << my_diffs / data_size << std::endl;
|
||||||
|
|
||||||
|
CUDACHECK(cudaFree(result));
|
||||||
|
CUDACHECK(cudaFree(self_data_copy));
|
||||||
|
CUDACHECK(cudaFree(rank_data));
|
||||||
|
CUDACHECK(cudaFree(buffer));
|
||||||
|
CUDACHECK(cudaFree(states));
|
||||||
|
CUDACHECK(cudaFreeHost(ground_truth));
|
||||||
|
CUDACHECK(cudaFreeHost(nccl_result));
|
||||||
|
CUDACHECK(cudaFreeHost(my_result));
|
||||||
|
CUDACHECK(cudaStreamDestroy(stream));
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, char **argv) {
|
||||||
|
int nRanks, myRank;
|
||||||
|
MPICHECK(MPI_Init(&argc, &argv));
|
||||||
|
MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank));
|
||||||
|
MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &nRanks));
|
||||||
|
CUDACHECK(cudaSetDevice(myRank));
|
||||||
|
ncclUniqueId id;
|
||||||
|
ncclComm_t comm;
|
||||||
|
if (myRank == 0) NCCLCHECK(ncclGetUniqueId(&id));
|
||||||
|
MPICHECK(MPI_Bcast(static_cast<void *>(&id), sizeof(id), MPI_BYTE, 0,
|
||||||
|
MPI_COMM_WORLD));
|
||||||
|
NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank));
|
||||||
|
|
||||||
|
cudaProfilerStart();
|
||||||
|
// for (int threads : {256, 512}) {
|
||||||
|
// for (int block_limit = 16; block_limit < 112; block_limit += 4) {
|
||||||
|
// run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
for (int sz = 512; sz <= (32 << 20); sz *= 2) {
|
||||||
|
run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 50);
|
||||||
|
}
|
||||||
|
|
||||||
|
cudaProfilerStop();
|
||||||
|
return EXIT_SUCCESS;
|
||||||
|
}
|
||||||
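The "Allocate IPC buffer" comment in the test above describes one allocation holding the metadata prefix, a temporary buffer, and the allreduce input, in that order. The offsets it implies can be written out as a host-only sketch. The struct `IpcLayout` and the function `ipc_layout` are hypothetical names used for illustration; `meta_bytes` stands in for `sizeof(vllm::Metadata)`, whose actual value is not shown in this diff.

```cpp
#include <cassert>
#include <cstddef>

// Byte offsets of the sections inside the single IPC allocation.
struct IpcLayout {
  std::size_t tmp_offset;    // temporary buffer, right after the metadata
  std::size_t input_offset;  // allreduce input, after the temporary buffer
  std::size_t total_bytes;   // size passed to cudaMalloc in the test
};

// meta_bytes: size of the metadata prefix (sizeof(vllm::Metadata) in the test)
// data_size:  number of elements per section
// elem_bytes: sizeof(T)
inline IpcLayout ipc_layout(std::size_t meta_bytes, std::size_t data_size,
                            std::size_t elem_bytes) {
  assert(elem_bytes > 0);
  const std::size_t section = data_size * elem_bytes;
  return {meta_bytes, meta_bytes + section, meta_bytes + 2 * section};
}
```

In the test, `input_offset` corresponds to the value used both for `self_data` and for the per-rank `offsets` passed to `register_buffer`, which is why every rank can locate its peers' inputs from the shared base handle alone.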
@@ -2,6 +2,8 @@
|
|||||||
* Adapted from
|
* Adapted from
|
||||||
* https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
|
* https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
|
||||||
*/
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
|
|
||||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||||
@@ -12,3 +14,24 @@
|
|||||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||||
AT_DISPATCH_SWITCH( \
|
AT_DISPATCH_SWITCH( \
|
||||||
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||||
|
|
||||||
|
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
|
||||||
|
|
||||||
|
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
|
||||||
|
AT_DISPATCH_SWITCH( \
|
||||||
|
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
|
||||||
|
|
||||||
|
#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
|
||||||
|
|
||||||
|
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
|
||||||
|
AT_DISPATCH_SWITCH( \
|
||||||
|
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
|
||||||
|
|||||||
@@ -1,14 +0,0 @@
|
|||||||
#include <torch/extension.h>
|
|
||||||
|
|
||||||
void rms_norm(
|
|
||||||
torch::Tensor& out,
|
|
||||||
torch::Tensor& input,
|
|
||||||
torch::Tensor& weight,
|
|
||||||
float epsilon);
|
|
||||||
|
|
||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|
||||||
m.def(
|
|
||||||
"rms_norm",
|
|
||||||
&rms_norm,
|
|
||||||
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
|
|
||||||
}
|
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|
||||||
#include "dispatch_utils.h"
|
#include "dispatch_utils.h"
|
||||||
#include "reduction_utils.cuh"
|
#include "reduction_utils.cuh"
|
||||||
@@ -9,8 +10,8 @@ namespace vllm {
|
|||||||
// TODO(woosuk): Further optimize this kernel.
|
// TODO(woosuk): Further optimize this kernel.
|
||||||
template<typename scalar_t>
|
template<typename scalar_t>
|
||||||
__global__ void rms_norm_kernel(
|
__global__ void rms_norm_kernel(
|
||||||
scalar_t* __restrict__ out, // [num_tokens, hidden_size]
|
scalar_t* __restrict__ out, // [..., hidden_size]
|
||||||
const scalar_t* __restrict__ input, // [num_tokens, hidden_size]
|
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||||
const float epsilon,
|
const float epsilon,
|
||||||
const int num_tokens,
|
const int num_tokens,
|
||||||
@@ -34,18 +35,49 @@ __global__ void rms_norm_kernel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Further optimize this kernel.
|
||||||
|
template<typename scalar_t>
|
||||||
|
__global__ void fused_add_rms_norm_kernel(
|
||||||
|
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||||
|
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||||
|
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||||
|
const float epsilon,
|
||||||
|
const int num_tokens,
|
||||||
|
const int hidden_size) {
|
||||||
|
__shared__ float s_variance;
|
||||||
|
float variance = 0.0f;
|
||||||
|
|
||||||
|
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||||
|
float x = (float) input[blockIdx.x * hidden_size + idx];
|
||||||
|
x += (float) residual[blockIdx.x * hidden_size + idx];
|
||||||
|
variance += x * x;
|
||||||
|
residual[blockIdx.x * hidden_size + idx] = (scalar_t) x;
|
||||||
|
}
|
||||||
|
variance = blockReduceSum<float>(variance);
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||||
|
float x = (float) residual[blockIdx.x * hidden_size + idx];
|
||||||
|
input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
void rms_norm(
|
void rms_norm(
|
||||||
torch::Tensor& out, // [num_tokens, hidden_size]
|
torch::Tensor& out, // [..., hidden_size]
|
||||||
torch::Tensor& input, // [num_tokens, hidden_size]
|
torch::Tensor& input, // [..., hidden_size]
|
||||||
torch::Tensor& weight, // [hidden_size]
|
torch::Tensor& weight, // [hidden_size]
|
||||||
float epsilon) {
|
float epsilon) {
|
||||||
int num_tokens = input.size(0);
|
int hidden_size = input.size(-1);
|
||||||
int hidden_size = input.size(1);
|
int num_tokens = input.numel() / hidden_size;
|
||||||
|
|
||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
dim3 block(std::min(hidden_size, 1024));
|
dim3 block(std::min(hidden_size, 1024));
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
input.scalar_type(),
|
input.scalar_type(),
|
||||||
@@ -60,3 +92,29 @@ void rms_norm(
|
|||||||
hidden_size);
|
hidden_size);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void fused_add_rms_norm(
|
||||||
|
torch::Tensor& input, // [..., hidden_size]
|
||||||
|
torch::Tensor& residual, // [..., hidden_size]
|
||||||
|
torch::Tensor& weight, // [hidden_size]
|
||||||
|
float epsilon) {
|
||||||
|
int hidden_size = input.size(-1);
|
||||||
|
int num_tokens = input.numel() / hidden_size;
|
||||||
|
|
||||||
|
dim3 grid(num_tokens);
|
||||||
|
dim3 block(std::min(hidden_size, 1024));
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
|
input.scalar_type(),
|
||||||
|
"fused_add_rms_norm_kernel",
|
||||||
|
[&] {
|
||||||
|
vllm::fused_add_rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||||
|
input.data_ptr<scalar_t>(),
|
||||||
|
residual.data_ptr<scalar_t>(),
|
||||||
|
weight.data_ptr<scalar_t>(),
|
||||||
|
epsilon,
|
||||||
|
num_tokens,
|
||||||
|
hidden_size);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|||||||
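As a reading aid for fused_add_rms_norm_kernel, the same arithmetic can be sketched on the CPU. This is an assumption-laden reference, not part of the PR: the name `fused_add_rms_norm_ref` is hypothetical, it works on `float` only, and it therefore skips the cast to `scalar_t` that the kernel applies before multiplying by the weight, so results for half or bfloat16 inputs may differ in the last bits.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// In-place reference: residual <- input + residual,
//                     input    <- residual * rsqrt(mean(residual^2) + eps) * weight
// Tensors are flattened to [num_tokens * hidden_size]; weight is [hidden_size].
inline void fused_add_rms_norm_ref(std::vector<float>& input,
                                   std::vector<float>& residual,
                                   const std::vector<float>& weight,
                                   float epsilon) {
  const std::size_t hidden_size = weight.size();
  assert(hidden_size > 0 && input.size() == residual.size() &&
         input.size() % hidden_size == 0);
  const std::size_t num_tokens = input.size() / hidden_size;
  for (std::size_t t = 0; t < num_tokens; ++t) {
    float* in = input.data() + t * hidden_size;
    float* res = residual.data() + t * hidden_size;
    // First pass: add the residual and accumulate the sum of squares.
    float sum_sq = 0.0f;
    for (std::size_t i = 0; i < hidden_size; ++i) {
      const float x = in[i] + res[i];
      sum_sq += x * x;
      res[i] = x;
    }
    // Same quantity the kernel stores in s_variance.
    const float inv_rms = 1.0f / std::sqrt(sum_sq / hidden_size + epsilon);
    // Second pass: normalize and scale by the weight.
    for (std::size_t i = 0; i < hidden_size; ++i) {
      in[i] = res[i] * inv_rms * weight[i];
    }
  }
}
```

Each outer-loop iteration corresponds to one CUDA block in the kernel, and the two inner loops correspond to its two strided passes over the hidden dimension.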
7
csrc/moe/moe_ops.cpp
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
#include "moe_ops.h"
|
||||||
|
|
||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
|
m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs.");
|
||||||
|
}
|
||||||
9
csrc/moe/moe_ops.h
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
void topk_softmax(
|
||||||
|
torch::Tensor& topk_weights,
|
||||||
|
torch::Tensor& topk_indices,
|
||||||
|
torch::Tensor& token_expert_indices,
|
||||||
|
torch::Tensor& gating_output);
|
||||||
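For orientation, the operation declared above can be sketched for a single token row on the CPU: a max-subtracted softmax over the gating logits followed by k rounds of argmax. This is a hedged illustration rather than the kernel's implementation. The name `topk_softmax_row` is hypothetical, ties are broken toward the lower expert index as an assumption, and the `token_expert_indices` output is left out.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

// Returns k (probability, expert index) pairs for one row of gating logits,
// ordered from the most to the least probable selected expert.
inline std::vector<std::pair<float, int>> topk_softmax_row(
    const std::vector<float>& gating, int k) {
  assert(!gating.empty() && k >= 0 &&
         static_cast<std::size_t>(k) <= gating.size());
  // Numerically stable softmax: subtract the row maximum before exp.
  const float max_val = *std::max_element(gating.begin(), gating.end());
  std::vector<float> probs(gating.size());
  float z = 0.0f;
  for (std::size_t i = 0; i < gating.size(); ++i) {
    probs[i] = std::exp(gating[i] - max_val);
    z += probs[i];
  }
  for (float& p : probs) p /= z;

  // k rounds of argmax, skipping experts that were already selected.
  std::vector<bool> taken(gating.size(), false);
  std::vector<std::pair<float, int>> out;
  for (int round = 0; round < k; ++round) {
    int best = -1;
    for (std::size_t e = 0; e < probs.size(); ++e) {
      if (taken[e]) continue;
      if (best < 0 || probs[e] > probs[best]) best = static_cast<int>(e);
    }
    taken[best] = true;
    out.emplace_back(probs[best], best);
  }
  return out;
}
```

The GPU version performs the same reductions per row with block-wide primitives instead of sequential loops.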
499
csrc/moe/topk_softmax_kernels.cu
Normal file
@@ -0,0 +1,499 @@
|
|||||||
|
/*
|
||||||
|
* Adapted from https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
|
||||||
|
* Copyright (c) 2024, The vLLM team.
|
||||||
|
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include <torch/extension.h>
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|
||||||
|
#include <cub/cub.cuh>
|
||||||
|
#include <cub/util_type.cuh>
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
namespace moe {
|
||||||
|
|
||||||
|
static constexpr int WARP_SIZE = 32;
|
||||||
|
|
||||||
|
/// Aligned array type
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
/// Number of elements in the array
|
||||||
|
int N,
|
||||||
|
/// Alignment requirement in bytes
|
||||||
|
int Alignment = sizeof(T) * N
|
||||||
|
>
|
||||||
|
class alignas(Alignment) AlignedArray {
|
||||||
|
float data[N];
|
||||||
|
};
|
||||||
|
|
||||||
|
// ====================== Softmax things ===============================
|
||||||
|
// We have our own implementation of softmax here so we can support transposing the output
|
||||||
|
// in the softmax kernel when we extend this module to support expert-choice routing.
|
||||||
|
template <int TPB>
|
||||||
|
__launch_bounds__(TPB) __global__
|
||||||
|
void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols)
|
||||||
|
{
|
||||||
|
using BlockReduce = cub::BlockReduce<float, TPB>;
|
||||||
|
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||||
|
|
||||||
|
__shared__ float normalizing_factor;
|
||||||
|
__shared__ float float_max;
|
||||||
|
|
||||||
|
const int thread_row_offset = blockIdx.x * num_cols;
|
||||||
|
|
||||||
|
cub::Sum sum;
|
||||||
|
float threadData(-FLT_MAX);
|
||||||
|
|
||||||
|
// Don't touch finished rows.
|
||||||
|
if ((finished != nullptr) && finished[blockIdx.x])
|
||||||
|
{
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
|
||||||
|
{
|
||||||
|
const int idx = thread_row_offset + ii;
|
||||||
|
threadData = max(static_cast<float>(input[idx]), threadData);
|
||||||
|
}
|
||||||
|
|
||||||
|
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
|
||||||
|
if (threadIdx.x == 0)
|
||||||
|
{
|
||||||
|
float_max = maxElem;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
threadData = 0;
|
||||||
|
|
||||||
|
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
|
||||||
|
{
|
||||||
|
const int idx = thread_row_offset + ii;
|
||||||
|
threadData += exp((static_cast<float>(input[idx]) - float_max));
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
|
||||||
|
|
||||||
|
if (threadIdx.x == 0)
|
||||||
|
{
|
||||||
|
normalizing_factor = 1.f / Z;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
|
||||||
|
{
|
||||||
|
const int idx = thread_row_offset + ii;
|
||||||
|
const float val = exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
|
||||||
|
output[idx] = val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int TPB>
|
||||||
|
__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
|
||||||
|
int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
|
||||||
|
{
|
||||||
|
|
||||||
|
using cub_kvp = cub::KeyValuePair<int, float>;
|
||||||
|
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
|
||||||
|
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||||
|
|
||||||
|
cub_kvp thread_kvp;
|
||||||
|
cub::ArgMax arg_max;
|
||||||
|
|
||||||
|
const int num_rows = gridDim.x;
|
||||||
|
const int block_row = blockIdx.x;
|
||||||
|
|
||||||
|
const bool row_is_active = finished ? !finished[block_row] : true;
|
||||||
|
const int thread_read_offset = blockIdx.x * num_experts;
|
||||||
|
for (int k_idx = 0; k_idx < k; ++k_idx)
|
||||||
|
{
|
||||||
|
thread_kvp.key = 0;
|
||||||
|
thread_kvp.value = -1.f; // This is OK because inputs are probabilities
|
||||||
|
|
||||||
|
cub_kvp inp_kvp;
|
||||||
|
for (int expert = threadIdx.x; expert < num_experts; expert += TPB)
|
||||||
|
{
|
||||||
|
const int idx = thread_read_offset + expert;
|
||||||
|
inp_kvp.key = expert;
|
||||||
|
inp_kvp.value = inputs_after_softmax[idx];
|
||||||
|
|
||||||
|
for (int prior_k = 0; prior_k < k_idx; ++prior_k)
|
||||||
|
{
|
||||||
|
const int prior_winning_expert = indices[k * block_row + prior_k];
|
||||||
|
|
||||||
|
if (prior_winning_expert == expert)
|
||||||
|
{
|
||||||
|
inp_kvp = thread_kvp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
thread_kvp = arg_max(inp_kvp, thread_kvp);
|
||||||
|
}
|
||||||
|
|
||||||
|
const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
|
||||||
|
if (threadIdx.x == 0)
|
||||||
|
{
|
||||||
|
// Ignore experts the node isn't responsible for with expert parallelism
|
||||||
|
const int expert = result_kvp.key;
|
||||||
|
const bool node_uses_expert = expert >= start_expert && expert < end_expert;
|
||||||
|
const bool should_process_row = row_is_active && node_uses_expert;
|
||||||
|
|
||||||
|
const int idx = k * block_row + k_idx;
|
||||||
|
output[idx] = result_kvp.value;
|
||||||
|
indices[idx] = should_process_row ? (expert - start_expert) : num_experts;
|
||||||
|
assert(indices[idx] >= 0);
|
||||||
|
source_rows[idx] = k_idx * num_rows + block_row;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ====================== TopK softmax things ===============================
|
||||||
|
|
||||||
|
/*
|
||||||
|
A Top-K gating softmax written to exploit when the number of experts in the MoE layers
|
||||||
|
is a small power of 2. This allows us to cleanly share the rows among the threads in
|
||||||
|
a single warp and eliminate communication between warps (so no need to use shared mem).
|
||||||
|
|
||||||
|
It fuses the softmax, max and argmax into a single kernel.
|
||||||
|
|
||||||
|
Limitations:
|
||||||
|
1) This implementation is intended for when the number of experts is a small power of 2.
|
||||||
|
2) This implementation assumes k is small, but will work for any k.
|
||||||
|
*/
|
||||||
|
|
||||||
|
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
|
||||||
|
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
|
||||||
|
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
|
||||||
|
int* source_rows, const int k, const int start_expert, const int end_expert)
|
||||||
|
{
|
||||||
|
// We begin by enforcing compile time assertions and setting up compile time constants.
|
||||||
|
static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
|
||||||
|
static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
|
||||||
|
static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
|
||||||
|
static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
|
||||||
|
|
||||||
|
// Number of elements each thread pulls in per load
|
||||||
|
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
|
||||||
|
static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
|
||||||
|
static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
|
||||||
|
static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;
|
||||||
|
|
||||||
|
// Restrictions based on previous section.
|
||||||
|
static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
|
||||||
|
static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
|
||||||
|
static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2");
|
||||||
|
static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size");
|
||||||
|
|
||||||
|
// We have NUM_EXPERTS elements per row. We specialize for small #experts
|
||||||
|
static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
|
||||||
|
static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
|
||||||
|
static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
|
||||||
|
|
||||||
|
// Restrictions for previous section.
|
||||||
|
static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp");
|
||||||
|
|
||||||
|
// ===================== From this point, we finally start computing run-time variables. ========================
|
||||||
|
|
||||||
|
// Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps.
|
||||||
|
// Thus, each block processes a chunk of rows. We start by computing the start row for each block.
|
||||||
|
const int cta_base_row = blockIdx.x * ROWS_PER_CTA;
|
||||||
|
|
||||||
|
// Now, using the base row per thread block, we compute the base row per warp.
|
||||||
|
const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;
|
||||||
|
|
||||||
|
// The threads in a warp are split into sub-groups that will work on a row.
|
||||||
|
// We compute row offset for each thread sub-group
|
||||||
|
const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
|
||||||
|
const int thread_row = warp_base_row + thread_row_in_warp;
|
||||||
|
|
||||||
|
// Threads with indices out of bounds should early exit here.
|
||||||
|
if (thread_row >= num_rows)
|
||||||
|
{
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const bool row_is_active = finished ? !finished[thread_row] : true;
|
||||||
|
|
||||||
|
// We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
|
||||||
|
// row it will read.
|
||||||
|
const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
|
||||||
|
|
||||||
|
// Now, we compute the group each thread belongs to in order to determine the first column to start loads.
|
||||||
|
const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
|
||||||
|
const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
|
||||||
|
const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
|
||||||
|
|
||||||
|
// Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,
|
||||||
|
// this can support all powers of 2 up to 16.
|
||||||
|
// NOTE(woosuk): The original implementation uses CUTLASS aligned array here.
|
||||||
|
// We defined our own aligned array and use it here to avoid the dependency on CUTLASS.
|
||||||
|
using AccessType = AlignedArray<float, ELTS_PER_LDG>;
|
||||||
|
|
||||||
|
// Finally, we pull in the data from global mem
|
||||||
|
float row_chunk[VPT];
|
||||||
|
AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk);
|
||||||
|
const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
|
||||||
|
#pragma unroll
|
||||||
|
for (int ii = 0; ii < LDG_PER_THREAD; ++ii)
|
||||||
|
{
|
||||||
|
row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
|
||||||
|
}
|
||||||
|
|
||||||
|
// First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
|
||||||
|
// convert to float afterwards for the exp + sum reduction.
|
||||||
|
float thread_max = row_chunk[0];
|
||||||
|
#pragma unroll
|
||||||
|
for (int ii = 1; ii < VPT; ++ii)
|
||||||
|
{
|
||||||
|
thread_max = max(thread_max, row_chunk[ii]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce.
|
||||||
|
#pragma unroll
|
||||||
|
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
|
||||||
|
{
|
||||||
|
thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW));
|
||||||
|
}
|
||||||
|
|
||||||
|
// From this point, thread_max in every thread holds the max within the row.
|
||||||
|
// Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum.
|
||||||
|
float row_sum = 0;
|
||||||
|
#pragma unroll
|
||||||
|
for (int ii = 0; ii < VPT; ++ii)
|
||||||
|
{
|
||||||
|
row_chunk[ii] = expf(row_chunk[ii] - thread_max);
|
||||||
|
row_sum += row_chunk[ii];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a butterfly pattern.
|
||||||
|
#pragma unroll
|
||||||
|
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
|
||||||
|
{
|
||||||
|
row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW);
|
||||||
|
}
|
||||||
|
|
||||||
|
// From this point, all threads have the max and the sum for their rows in the thread_max and row_sum variables
|
||||||
|
// respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to
|
||||||
|
// compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row.
|
||||||
|
// However, this kernel will likely not be a bottleneck and it seems better to more closely match torch and find the
|
||||||
|
// argmax after computing the softmax.
|
||||||
|
const float reciprocal_row_sum = 1.f / row_sum;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int ii = 0; ii < VPT; ++ii)
|
||||||
|
{
|
||||||
|
row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now, row_chunk contains the softmax of the row chunk. Next, we find the top-k elements in each row, along
|
||||||
|
// with the max index.
|
||||||
|
int start_col = first_elt_read_by_thread;
|
||||||
|
static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
|
||||||
|
|
||||||
|
for (int k_idx = 0; k_idx < k; ++k_idx)
|
||||||
|
{
|
||||||
|
// First, each thread does the local argmax
|
||||||
|
float max_val = row_chunk[0];
|
||||||
|
int expert = start_col;
|
||||||
|
#pragma unroll
|
||||||
|
for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG)
|
||||||
|
{
|
||||||
|
#pragma unroll
|
||||||
|
for (int ii = 0; ii < ELTS_PER_LDG; ++ii)
|
||||||
|
{
|
||||||
|
float val = row_chunk[ldg * ELTS_PER_LDG + ii];
|
||||||
|
|
||||||
|
// No check on the experts here since columns with the smallest index are processed first and only
|
||||||
|
// updated if > (not >=)
|
||||||
|
if (val > max_val)
|
||||||
|
{
|
||||||
|
max_val = val;
|
||||||
|
expert = col + ii;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max.
|
||||||
|
// This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can
|
||||||
|
// then blank out their max with a negative sentinel and the warp can run more iterations...
|
||||||
|
#pragma unroll
|
||||||
|
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
|
||||||
|
{
|
||||||
|
float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW);
|
||||||
|
int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW);
|
||||||
|
|
||||||
|
// We want lower indices to "win" in every thread so we break ties this way
|
||||||
|
if (other_max > max_val || (other_max == max_val && other_expert < expert))
|
||||||
|
{
|
||||||
|
max_val = other_max;
|
||||||
|
expert = other_expert;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write the max for this k iteration to global memory.
|
||||||
|
if (thread_group_idx == 0)
|
||||||
|
{
|
||||||
|
// Add a guard to ignore experts not included by this node
|
||||||
|
const bool node_uses_expert = expert >= start_expert && expert < end_expert;
|
||||||
|
const bool should_process_row = row_is_active && node_uses_expert;
|
||||||
|
|
||||||
|
// The lead thread from each sub-group will write out the final results to global memory. (This will be a
|
||||||
|
// single thread per row of the input/output matrices.)
|
||||||
|
const int idx = k * thread_row + k_idx;
|
||||||
|
output[idx] = max_val;
|
||||||
|
indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
|
||||||
|
source_rows[idx] = k_idx * num_rows + thread_row;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, we clear the value in the thread with the current max if there is another iteration to run.
|
||||||
|
if (k_idx + 1 < k)
|
||||||
|
{
|
||||||
|
const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
|
||||||
|
const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW;
|
||||||
|
|
||||||
|
// Only the thread in the group which produced the max will reset the "winning" value to a negative sentinel.
|
||||||
|
if (thread_group_idx == thread_to_clear_in_group)
|
||||||
|
{
|
||||||
|
const int offset_for_expert = expert % ELTS_PER_LDG;
|
||||||
|
// Safe to set to any negative value since row_chunk values must be between 0 and 1.
|
||||||
|
row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace detail
|
||||||
|
{
|
||||||
|
// Constructs some constants needed to partition the work across threads at compile time.
|
||||||
|
template <int EXPERTS, int BYTES_PER_LDG>
|
||||||
|
struct TopkConstants
|
||||||
|
{
|
||||||
|
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
|
||||||
|
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
|
||||||
|
static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
|
||||||
|
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
|
||||||
|
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
|
||||||
|
static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
|
||||||
|
};
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
template <int EXPERTS, int WARPS_PER_TB>
|
||||||
|
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
|
||||||
|
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
|
||||||
|
{
|
||||||
|
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
|
||||||
|
|
||||||
|
static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
|
||||||
|
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
|
||||||
|
static constexpr int VPT = Constants::VPT;
|
||||||
|
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
|
||||||
|
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
|
||||||
|
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
|
||||||
|
|
||||||
|
dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
|
||||||
|
topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
|
||||||
|
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
|
||||||
|
}
|
||||||
|
|
||||||
|
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \
|
||||||
|
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>( \
|
||||||
|
gating_output, nullptr, topk_weights, topk_indicies, \
|
||||||
|
token_expert_indices, num_tokens, topk, 0, num_experts, \
|
||||||
|
stream);
|
||||||
|
|
||||||
|
void topkGatingSoftmaxKernelLauncher(
|
||||||
|
const float* gating_output,
|
||||||
|
float* topk_weights,
|
||||||
|
int* topk_indicies,
|
||||||
|
int* token_expert_indices,
|
||||||
|
float* softmax_workspace,
|
||||||
|
const int num_tokens,
|
||||||
|
const int num_experts,
|
||||||
|
const int topk,
|
||||||
|
cudaStream_t stream) {
|
||||||
|
static constexpr int WARPS_PER_TB = 4;
|
||||||
|
switch (num_experts) {
|
||||||
|
case 1:
|
||||||
|
LAUNCH_SOFTMAX(1, WARPS_PER_TB);
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
LAUNCH_SOFTMAX(2, WARPS_PER_TB);
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
LAUNCH_SOFTMAX(4, WARPS_PER_TB);
|
||||||
|
break;
|
||||||
|
case 8:
|
||||||
|
LAUNCH_SOFTMAX(8, WARPS_PER_TB);
|
||||||
|
break;
|
||||||
|
case 16:
|
||||||
|
LAUNCH_SOFTMAX(16, WARPS_PER_TB);
|
||||||
|
break;
|
||||||
|
case 32:
|
||||||
|
LAUNCH_SOFTMAX(32, WARPS_PER_TB);
|
||||||
|
break;
|
||||||
|
case 64:
|
||||||
|
LAUNCH_SOFTMAX(64, WARPS_PER_TB);
|
||||||
|
break;
|
||||||
|
case 128:
|
||||||
|
LAUNCH_SOFTMAX(128, WARPS_PER_TB);
|
||||||
|
break;
|
||||||
|
case 256:
|
||||||
|
LAUNCH_SOFTMAX(256, WARPS_PER_TB);
|
||||||
|
break;
|
||||||
|
default: {
|
||||||
|
TORCH_CHECK(softmax_workspace != nullptr,
|
||||||
|
"softmax_workspace must be provided for num_experts that are not a power of 2 or are larger than 256.");
|
||||||
|
static constexpr int TPB = 256;
|
||||||
|
moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(
|
||||||
|
gating_output, nullptr, softmax_workspace, num_experts);
|
||||||
|
moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
|
||||||
|
softmax_workspace, nullptr, topk_weights, topk_indicies, token_expert_indices,
|
||||||
|
num_experts, topk, 0, num_experts);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace moe
|
||||||
|
} // namespace vllm
|
||||||
|
|
||||||
|
void topk_softmax(
|
||||||
|
torch::Tensor& topk_weights, // [num_tokens, topk]
|
||||||
|
torch::Tensor& topk_indices, // [num_tokens, topk]
|
||||||
|
torch::Tensor& token_expert_indices, // [num_tokens, topk]
|
||||||
|
torch::Tensor& gating_output) // [num_tokens, num_experts]
|
||||||
|
{
|
||||||
|
const int num_experts = gating_output.size(-1);
|
||||||
|
const int num_tokens = gating_output.numel() / num_experts;
|
||||||
|
const int topk = topk_weights.size(-1);
|
||||||
|
|
||||||
|
const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
||||||
|
const bool needs_workspace = !is_pow_2 || num_experts > 256;
|
||||||
|
const int64_t workspace_size = needs_workspace ? num_tokens * num_experts : 0;
|
||||||
|
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
|
||||||
|
vllm::moe::topkGatingSoftmaxKernelLauncher(
|
||||||
|
gating_output.data_ptr<float>(),
|
||||||
|
topk_weights.data_ptr<float>(),
|
||||||
|
topk_indices.data_ptr<int>(),
|
||||||
|
token_expert_indices.data_ptr<int>(),
|
||||||
|
softmax_workspace.data_ptr<float>(),
|
||||||
|
num_tokens,
|
||||||
|
num_experts,
|
||||||
|
topk,
|
||||||
|
stream);
|
||||||
|
}
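For checking the kernels above against a known-good result, a host-side reference can help. The following is a minimal sketch, not part of the diff: the names `TopKResult` and `topk_softmax_reference` are hypothetical, and it assumes `k <= num_experts`. It mirrors the semantics described in the kernel comments: a max-subtracted softmax per row, `k` rounds of argmax with ties going to the lower expert index, and masking of each winner with a negative value.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical container for the two outputs the kernels write per token.
struct TopKResult {
  std::vector<float> weights;  // [num_tokens * k] softmax probabilities
  std::vector<int> indices;    // [num_tokens * k] selected expert per slot
};

// CPU reference: row-wise softmax followed by k rounds of argmax.
// Assumption of this sketch: k <= num_experts.
TopKResult topk_softmax_reference(const std::vector<float>& gating,
                                  int num_tokens, int num_experts, int k) {
  assert(k <= num_experts);
  TopKResult out;
  out.weights.resize(static_cast<std::size_t>(num_tokens) * k);
  out.indices.resize(static_cast<std::size_t>(num_tokens) * k);
  std::vector<float> probs(num_experts);

  for (int t = 0; t < num_tokens; ++t) {
    const float* row = gating.data() + static_cast<std::size_t>(t) * num_experts;

    // Subtract the row max before exponentiating for numerical stability.
    float row_max = row[0];
    for (int e = 1; e < num_experts; ++e) row_max = std::max(row_max, row[e]);

    float row_sum = 0.f;
    for (int e = 0; e < num_experts; ++e) {
      probs[e] = std::exp(row[e] - row_max);
      row_sum += probs[e];
    }
    const float reciprocal_row_sum = 1.f / row_sum;
    for (int e = 0; e < num_experts; ++e) probs[e] *= reciprocal_row_sum;

    for (int j = 0; j < k; ++j) {
      // Strict '>' keeps the lowest index when several experts tie.
      int best = 0;
      for (int e = 1; e < num_experts; ++e) {
        if (probs[e] > probs[best]) best = e;
      }
      const std::size_t idx = static_cast<std::size_t>(t) * k + j;
      out.weights[idx] = probs[best];
      out.indices[idx] = best;
      // Any negative value works as a mask because probabilities lie in [0, 1].
      probs[best] = -1.f;
    }
  }
  return out;
}
```

Calling it with two rows of three logits, `{1, 3, 2}` and `{0, 0, 0}`, and `k = 2` selects experts 1 then 2 for the first row and experts 0 then 1 for the second, which exercises the tie-breaking rule.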
|
||||||
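The compile-time partitioning in `detail::TopkConstants` and the grid-size arithmetic in the launcher helper are easier to follow with concrete numbers. The sketch below repeats that arithmetic on the host; `TopkConstantsSketch`, `bytes_per_ldg`, `num_blocks_for`, and `kWarpSize` are illustrative names, not identifiers from the diff. For 8 experts it yields 4 values per thread, 2 threads per row, and 16 rows per warp; for 256 experts a full warp handles a single row.

```cpp
#include <algorithm>
#include <cassert>

// Assumed warp width, matching WARP_SIZE = 32 in the source.
constexpr int kWarpSize = 32;

// Host-side mirror of the arithmetic in detail::TopkConstants.
template <int EXPERTS, int BYTES_PER_LDG>
struct TopkConstantsSketch {
  static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / static_cast<int>(sizeof(float));
  static constexpr int VECS_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * kWarpSize));
  static constexpr int VPT = VECS_PER_THREAD * ELTS_PER_LDG;
  static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
  static constexpr int ROWS_PER_WARP = kWarpSize / THREADS_PER_ROW;
};

// Load width chosen by the launcher: at most 16 bytes, less for very short rows.
constexpr int bytes_per_ldg(int experts) {
  return std::min<int>(16, static_cast<int>(sizeof(float)) * experts);
}

// Grid size: rows are packed into warps, and warps into thread blocks.
int num_blocks_for(int num_rows, int rows_per_warp, int warps_per_tb) {
  assert(rows_per_warp > 0 && warps_per_tb > 0);
  const int num_warps = (num_rows + rows_per_warp - 1) / rows_per_warp;
  return (num_warps + warps_per_tb - 1) / warps_per_tb;
}

using Experts8 = TopkConstantsSketch<8, bytes_per_ldg(8)>;
using Experts256 = TopkConstantsSketch<256, bytes_per_ldg(256)>;

static_assert(Experts8::VPT == 4 && Experts8::THREADS_PER_ROW == 2, "8 experts: 2 threads per row");
static_assert(Experts8::ROWS_PER_WARP == 16, "8 experts: 16 rows per warp");
static_assert(Experts256::VPT == 8 && Experts256::THREADS_PER_ROW == 32, "256 experts: one row per warp");
```

With `WARPS_PER_TB = 4` and 100 tokens, this gives 2 thread blocks for 8 experts and 25 thread blocks for 256 experts.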
108
csrc/moe_align_block_size_kernels.cu
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
#include <torch/extension.h>
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
|
||||||
|
#include <ATen/ATen.h>
|
||||||
|
#include <THC/THCAtomics.cuh>
|
||||||
|
|
||||||
|
#include "cuda_compat.h"
|
||||||
|
#include "dispatch_utils.h"
|
||||||
|
|
||||||
|
const static size_t NUM_MAX_EXPERTS = 64;
|
||||||
|
#define CEILDIV(x,y) (((x) + (y) - 1) / (y))
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
template <typename scalar_t>
|
||||||
|
__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
|
||||||
|
int32_t *sorted_token_ids,
|
||||||
|
int32_t *expert_ids,
|
||||||
|
int32_t *total_tokens_post_pad,
|
||||||
|
int32_t num_experts,
|
||||||
|
int32_t block_size,
|
||||||
|
size_t numel) {
|
||||||
|
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
|
||||||
|
const size_t start_idx = threadIdx.x * tokens_per_thread;
|
||||||
|
__shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1][NUM_MAX_EXPERTS];
|
||||||
|
__shared__ int32_t cumsum[NUM_MAX_EXPERTS + 1];
|
||||||
|
for (int i = 0; i < num_experts; ++i) {
|
||||||
|
tokens_cnts[threadIdx.x + 1][i] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* In the first step we compute tokens_cnts[thread_index + 1][expert_index],
|
||||||
|
* which counts how many tokens in the token shard of thread_index are assigned
|
||||||
|
* to expert expert_index.
|
||||||
|
*/
|
||||||
|
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
|
||||||
|
++tokens_cnts[threadIdx.x + 1][topk_ids[i]];
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// For each expert we accumulate the token counts from the different threads.
|
||||||
|
tokens_cnts[0][threadIdx.x] = 0;
|
||||||
|
for (int i = 1; i <= blockDim.x; ++i) {
|
||||||
|
tokens_cnts[i][threadIdx.x] += tokens_cnts[i-1][threadIdx.x];
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// We accumulate the token counts of all experts in thread 0.
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
cumsum[0] = 0;
|
||||||
|
for (int i = 1; i <= num_experts; ++i) {
|
||||||
|
cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[blockDim.x][i - 1], block_size) * block_size;
|
||||||
|
}
|
||||||
|
*total_tokens_post_pad = cumsum[num_experts];
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* For each expert, each thread processes the tokens of the corresponding blocks
|
||||||
|
* and stores the corresponding expert_id for each block.
|
||||||
|
*/
|
||||||
|
for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) {
|
||||||
|
expert_ids[i / block_size] = threadIdx.x;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Each thread processes a token shard, calculating the index of each token after
|
||||||
|
* sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and
|
||||||
|
* block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
|
||||||
|
* where * represents a padding value (preset in Python).
|
||||||
|
*/
|
||||||
|
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
|
||||||
|
int32_t expert_id = topk_ids[i];
|
||||||
|
/** The cumsum[expert_id] stores the starting index of the tokens that the
|
||||||
|
* expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id]
|
||||||
|
* stores the indices of the tokens processed by the expert with expert_id within
|
||||||
|
* the current thread's token shard.
|
||||||
|
*/
|
||||||
|
int32_t rank_post_pad = tokens_cnts[threadIdx.x][expert_id] + cumsum[expert_id];
|
||||||
|
sorted_token_ids[rank_post_pad] = i;
|
||||||
|
++tokens_cnts[threadIdx.x][expert_id];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void moe_align_block_size(
|
||||||
|
torch::Tensor topk_ids,
|
||||||
|
int num_experts,
|
||||||
|
int block_size,
|
||||||
|
torch::Tensor sorted_token_ids,
|
||||||
|
torch::Tensor experts_ids,
|
||||||
|
torch::Tensor num_tokens_post_pad) {
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
assert(num_experts <= NUM_MAX_EXPERTS);
|
||||||
|
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||||
|
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||||
|
vllm::moe_align_block_size_kernel<scalar_t><<<1, num_experts, 0, stream>>>(
|
||||||
|
topk_ids.data_ptr<scalar_t>(),
|
||||||
|
sorted_token_ids.data_ptr<int32_t>(),
|
||||||
|
experts_ids.data_ptr<int32_t>(),
|
||||||
|
num_tokens_post_pad.data_ptr<int32_t>(),
|
||||||
|
num_experts,
|
||||||
|
block_size,
|
||||||
|
topk_ids.numel());
|
||||||
|
});
|
||||||
|
}
|
||||||
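The example in the kernel comment (`topk_ids = [0,1,2,1,2,3,0,3,4]`, `block_size = 4`) can be reproduced with a sequential host-side reference. This is a sketch under stated assumptions, not the kernel's implementation: `AlignedRouting`, `moe_align_block_size_reference`, and the `pad_value` parameter are hypothetical names, and the padding value is passed in explicitly because the source only says it is preset in Python.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical container for the three outputs of the alignment step.
struct AlignedRouting {
  std::vector<int32_t> sorted_token_ids;  // token indices grouped by expert, padded
  std::vector<int32_t> expert_ids;        // expert id for each block of block_size slots
  int32_t total_tokens_post_pad = 0;
};

AlignedRouting moe_align_block_size_reference(const std::vector<int32_t>& topk_ids,
                                              int num_experts, int block_size,
                                              int32_t pad_value) {
  assert(block_size > 0);

  // Count how many tokens are routed to each expert.
  std::vector<int32_t> counts(num_experts, 0);
  for (int32_t e : topk_ids) {
    assert(e >= 0 && e < num_experts);
    ++counts[e];
  }

  // cumsum[e] is the first slot of expert e once every expert's count
  // has been rounded up to a multiple of block_size.
  std::vector<int32_t> cumsum(num_experts + 1, 0);
  for (int e = 0; e < num_experts; ++e) {
    const int32_t padded = (counts[e] + block_size - 1) / block_size * block_size;
    cumsum[e + 1] = cumsum[e] + padded;
  }

  AlignedRouting out;
  out.total_tokens_post_pad = cumsum[num_experts];
  out.sorted_token_ids.assign(cumsum[num_experts], pad_value);
  out.expert_ids.resize(cumsum[num_experts] / block_size);

  // Every block of block_size slots belongs to exactly one expert.
  for (int e = 0; e < num_experts; ++e) {
    for (int32_t i = cumsum[e]; i < cumsum[e + 1]; i += block_size) {
      out.expert_ids[i / block_size] = e;
    }
  }

  // Place each token index at the next free slot of its expert.
  std::vector<int32_t> next(cumsum.begin(), cumsum.end() - 1);
  for (std::size_t i = 0; i < topk_ids.size(); ++i) {
    out.sorted_token_ids[next[topk_ids[i]]++] = static_cast<int32_t>(i);
  }
  return out;
}
```

For the comment's input with five experts, the result has 20 padded slots, `sorted_token_ids` equal to `[0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *]` with `*` standing for `pad_value`, and `expert_ids` equal to `[0, 1, 2, 3, 4]`.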
145
csrc/ops.h
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
void paged_attention_v1(
|
||||||
|
torch::Tensor& out,
|
||||||
|
torch::Tensor& query,
|
||||||
|
torch::Tensor& key_cache,
|
||||||
|
torch::Tensor& value_cache,
|
||||||
|
int num_kv_heads,
|
||||||
|
float scale,
|
||||||
|
torch::Tensor& block_tables,
|
||||||
|
torch::Tensor& context_lens,
|
||||||
|
int block_size,
|
||||||
|
int max_context_len,
|
||||||
|
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
|
const std::string& kv_cache_dtype);
|
||||||
|
|
||||||
|
void paged_attention_v2(
|
||||||
|
torch::Tensor& out,
|
||||||
|
torch::Tensor& exp_sums,
|
||||||
|
torch::Tensor& max_logits,
|
||||||
|
torch::Tensor& tmp_out,
|
||||||
|
torch::Tensor& query,
|
||||||
|
torch::Tensor& key_cache,
|
||||||
|
torch::Tensor& value_cache,
|
||||||
|
int num_kv_heads,
|
||||||
|
float scale,
|
||||||
|
torch::Tensor& block_tables,
|
||||||
|
torch::Tensor& context_lens,
|
||||||
|
int block_size,
|
||||||
|
int max_context_len,
|
||||||
|
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
|
const std::string& kv_cache_dtype);
|
||||||
|
|
||||||
|
void rms_norm(
|
||||||
|
torch::Tensor& out,
|
||||||
|
torch::Tensor& input,
|
||||||
|
torch::Tensor& weight,
|
||||||
|
float epsilon);
|
||||||
|
|
||||||
|
void fused_add_rms_norm(
|
||||||
|
torch::Tensor& input,
|
||||||
|
torch::Tensor& residual,
|
||||||
|
torch::Tensor& weight,
|
||||||
|
float epsilon);
|
||||||
|
|
||||||
|
void rotary_embedding(
|
||||||
|
torch::Tensor& positions,
|
||||||
|
torch::Tensor& query,
|
||||||
|
torch::Tensor& key,
|
||||||
|
int head_size,
|
||||||
|
torch::Tensor& cos_sin_cache,
|
||||||
|
bool is_neox);
|
||||||
|
|
||||||
|
void silu_and_mul(
|
||||||
|
torch::Tensor& out,
|
||||||
|
torch::Tensor& input);
|
||||||
|
|
||||||
|
void gelu_and_mul(
|
||||||
|
torch::Tensor& out,
|
||||||
|
torch::Tensor& input);
|
||||||
|
|
||||||
|
void gelu_new(
|
||||||
|
torch::Tensor& out,
|
||||||
|
torch::Tensor& input);
|
||||||
|
|
||||||
|
void gelu_fast(
|
||||||
|
torch::Tensor& out,
|
||||||
|
torch::Tensor& input);
|
||||||
|
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
torch::Tensor awq_gemm(
|
||||||
|
torch::Tensor _in_feats,
|
||||||
|
torch::Tensor _kernel,
|
||||||
|
torch::Tensor _scaling_factors,
|
||||||
|
torch::Tensor _zeros,
|
||||||
|
int split_k_iters);
|
||||||
|
|
||||||
|
torch::Tensor awq_dequantize(
|
||||||
|
torch::Tensor _kernel,
|
||||||
|
torch::Tensor _scaling_factors,
|
||||||
|
torch::Tensor _zeros,
|
||||||
|
int split_k_iters,
|
||||||
|
int thx,
|
||||||
|
int thy);
|
||||||
|
|
||||||
|
torch::Tensor marlin_gemm(
|
||||||
|
torch::Tensor& a,
|
||||||
|
torch::Tensor& b_q_weight,
|
||||||
|
torch::Tensor& b_scales,
|
||||||
|
torch::Tensor& workspace,
|
||||||
|
int64_t size_m,
|
||||||
|
int64_t size_n,
|
||||||
|
int64_t size_k);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
void squeezellm_gemm(
|
||||||
|
torch::Tensor vec,
|
||||||
|
torch::Tensor mat,
|
||||||
|
torch::Tensor mul,
|
||||||
|
torch::Tensor lookup_table);
|
||||||
|
|
||||||
|
torch::Tensor gptq_gemm(
|
||||||
|
torch::Tensor a,
|
||||||
|
torch::Tensor b_q_weight,
|
||||||
|
torch::Tensor b_gptq_qzeros,
|
||||||
|
torch::Tensor b_gptq_scales,
|
||||||
|
torch::Tensor b_g_idx,
|
||||||
|
bool use_exllama,
|
||||||
|
int bit);
|
||||||
|
|
||||||
|
void gptq_shuffle(
|
||||||
|
torch::Tensor q_weight,
|
||||||
|
torch::Tensor q_perm,
|
||||||
|
int bit);
|
||||||
|
|
||||||
|
void moe_align_block_size(
|
||||||
|
torch::Tensor topk_ids,
|
||||||
|
int num_experts,
|
||||||
|
int block_size,
|
||||||
|
torch::Tensor sorted_token_ids,
|
||||||
|
torch::Tensor experts_ids,
|
||||||
|
torch::Tensor num_tokens_post_pad);
|
||||||
|
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
using fptr_t = uint64_t;
|
||||||
|
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
|
||||||
|
const std::vector<std::string> &handles,
|
||||||
|
const std::vector<int64_t> &offsets, int rank,
|
||||||
|
bool full_nvlink);
|
||||||
|
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
|
||||||
|
bool full_nvlink);
|
||||||
|
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out);
|
||||||
|
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
|
||||||
|
torch::Tensor &out);
|
||||||
|
void dispose(fptr_t _fa);
|
||||||
|
int meta_size();
|
||||||
|
void register_buffer(fptr_t _fa, torch::Tensor &t,
|
||||||
|
const std::vector<std::string> &handles,
|
||||||
|
const std::vector<int64_t> &offsets);
|
||||||
|
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
|
||||||
|
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
|
||||||
|
const std::vector<std::vector<int64_t>> &offsets);
|
||||||
|
#endif
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
#include <torch/extension.h>
|
|
||||||
|
|
||||||
void rotary_embedding(
|
|
||||||
torch::Tensor& positions,
|
|
||||||
torch::Tensor& query,
|
|
||||||
torch::Tensor& key,
|
|
||||||
int head_size,
|
|
||||||
torch::Tensor& cos_sin_cache,
|
|
||||||
bool is_neox);
|
|
||||||
|
|
||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|
||||||
m.def(
|
|
||||||
"rotary_embedding",
|
|
||||||
&rotary_embedding,
|
|
||||||
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
|
|
||||||
}
|
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|
||||||
|
#include "cuda_compat.h"
|
||||||
#include "dispatch_utils.h"
|
#include "dispatch_utils.h"
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
@@ -19,14 +21,14 @@ inline __device__ void apply_rotary_embedding(
|
|||||||
// GPT-NeoX style rotary embedding.
|
// GPT-NeoX style rotary embedding.
|
||||||
x_index = rot_offset;
|
x_index = rot_offset;
|
||||||
y_index = embed_dim + rot_offset;
|
y_index = embed_dim + rot_offset;
|
||||||
cos = __ldg(cos_ptr + x_index);
|
cos = VLLM_LDG(cos_ptr + x_index);
|
||||||
sin = __ldg(sin_ptr + x_index);
|
sin = VLLM_LDG(sin_ptr + x_index);
|
||||||
} else {
|
} else {
|
||||||
// GPT-J style rotary embedding.
|
// GPT-J style rotary embedding.
|
||||||
x_index = 2 * rot_offset;
|
x_index = 2 * rot_offset;
|
||||||
y_index = 2 * rot_offset + 1;
|
y_index = 2 * rot_offset + 1;
|
||||||
cos = __ldg(cos_ptr + x_index / 2);
|
cos = VLLM_LDG(cos_ptr + x_index / 2);
|
||||||
sin = __ldg(sin_ptr + x_index / 2);
|
sin = VLLM_LDG(sin_ptr + x_index / 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
const scalar_t x = arr[x_index];
|
const scalar_t x = arr[x_index];
|
||||||
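The two branches in the hunk above pair elements differently. The GPT-NeoX layout rotates element `rot_offset` with the element `embed_dim` positions later, while the GPT-J layout rotates adjacent even/odd elements. Both read the same cos/sin entry. Below is a minimal host-side C++ sketch of that index arithmetic only. `RotaryIndices` and `rotary_indices` are hypothetical names introduced for illustration and are not part of the kernel.

```cpp
#include <cassert>

// Hypothetical helper mirroring the index arithmetic shown in
// apply_rotary_embedding; it does not touch any tensor data.
struct RotaryIndices {
    int x_index;   // first element of the rotated pair
    int y_index;   // second element of the rotated pair
    int cs_index;  // offset used for both the cos and the sin lookup
};

inline RotaryIndices rotary_indices(bool is_neox, int rot_offset, int embed_dim) {
    assert(rot_offset >= 0 && rot_offset < embed_dim);
    if (is_neox) {
        // GPT-NeoX style: pair element i with element i + embed_dim.
        return {rot_offset, embed_dim + rot_offset, rot_offset};
    }
    // GPT-J style: pair adjacent elements 2i and 2i + 1.
    const int x_index = 2 * rot_offset;
    return {x_index, x_index + 1, x_index / 2};
}
```

For `rot_offset = 3` and `embed_dim = 64`, the NeoX branch yields the pair (3, 67) and the GPT-J branch yields (6, 7); both use cos/sin offset 3.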
@@ -37,13 +39,13 @@ inline __device__ void apply_rotary_embedding(
|
|||||||
|
|
||||||
template<typename scalar_t, bool IS_NEOX>
|
template<typename scalar_t, bool IS_NEOX>
|
||||||
__global__ void rotary_embedding_kernel(
|
__global__ void rotary_embedding_kernel(
|
||||||
const int64_t* __restrict__ positions, // [num_tokens]
|
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens]
|
||||||
scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
|
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
|
||||||
scalar_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size]
|
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
|
||||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
||||||
const int rot_dim,
|
const int rot_dim,
|
||||||
const int query_stride,
|
const int64_t query_stride,
|
||||||
const int key_stride,
|
const int64_t key_stride,
|
||||||
const int num_heads,
|
const int num_heads,
|
||||||
const int num_kv_heads,
|
const int num_kv_heads,
|
||||||
const int head_size) {
|
const int head_size) {
|
||||||
@@ -59,7 +61,7 @@ __global__ void rotary_embedding_kernel(
|
|||||||
const int nq = num_heads * embed_dim;
|
const int nq = num_heads * embed_dim;
|
||||||
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
|
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
|
||||||
const int head_idx = i / embed_dim;
|
const int head_idx = i / embed_dim;
|
||||||
const int token_head = token_idx * query_stride + head_idx * head_size;
|
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
|
||||||
const int rot_offset = i % embed_dim;
|
const int rot_offset = i % embed_dim;
|
||||||
apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
|
apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
|
||||||
sin_ptr, rot_offset, embed_dim);
|
sin_ptr, rot_offset, embed_dim);
|
||||||
@@ -68,7 +70,7 @@ __global__ void rotary_embedding_kernel(
|
|||||||
const int nk = num_kv_heads * embed_dim;
|
const int nk = num_kv_heads * embed_dim;
|
||||||
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
|
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
|
||||||
const int head_idx = i / embed_dim;
|
const int head_idx = i / embed_dim;
|
||||||
const int token_head = token_idx * key_stride + head_idx * head_size;
|
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
|
||||||
const int rot_offset = i % embed_dim;
|
const int rot_offset = i % embed_dim;
|
||||||
apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
|
apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
|
||||||
sin_ptr, rot_offset, embed_dim);
|
sin_ptr, rot_offset, embed_dim);
|
||||||
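The switch from `int` to `int64_t` for `query_stride`, `key_stride`, and `token_head` matters once `token_idx * stride` no longer fits in 32 bits. The following host-side C++ sketch illustrates the magnitude involved, assuming the multiplication is carried out in 64-bit arithmetic. The helper names and the example sizes are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Hypothetical helper: element offset of one head of one token,
// computed in 64-bit arithmetic as in the updated kernel lines.
inline int64_t token_head_offset(int64_t token_idx, int64_t stride,
                                 int head_idx, int head_size) {
    assert(token_idx >= 0 && stride >= 0);
    return token_idx * stride + static_cast<int64_t>(head_idx) * head_size;
}

// True when the offset cannot be represented in a signed 32-bit int,
// i.e. when the old `const int token_head` would have overflowed.
inline bool exceeds_int32(int64_t offset) {
    return offset > std::numeric_limits<int32_t>::max();
}
```

With a stride of 4096 elements, token index 600000 already gives an offset of 2,457,600,000, which is above the 32-bit signed maximum of 2,147,483,647.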
@@ -78,21 +80,22 @@ __global__ void rotary_embedding_kernel(
|
|||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
void rotary_embedding(
|
void rotary_embedding(
|
||||||
torch::Tensor& positions, // [num_tokens]
|
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
|
||||||
torch::Tensor& query, // [num_tokens, num_heads * head_size]
|
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
|
||||||
torch::Tensor& key, // [num_tokens, num_kv_heads * head_size]
|
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
|
||||||
int head_size,
|
int head_size,
|
||||||
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||||
bool is_neox) {
|
bool is_neox) {
|
||||||
int num_tokens = query.size(0);
|
int64_t num_tokens = query.numel() / query.size(-1);
|
||||||
int rot_dim = cos_sin_cache.size(1);
|
int rot_dim = cos_sin_cache.size(1);
|
||||||
int num_heads = query.size(1) / head_size;
|
int num_heads = query.size(-1) / head_size;
|
||||||
int num_kv_heads = key.size(1) / head_size;
|
int num_kv_heads = key.size(-1) / head_size;
|
||||||
int query_stride = query.stride(0);
|
int64_t query_stride = query.stride(-2);
|
||||||
int key_stride = key.stride(0);
|
int64_t key_stride = key.stride(-2);
|
||||||
|
|
||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
dim3 block(std::min(num_heads * rot_dim / 2, 512));
|
dim3 block(std::min(num_heads * rot_dim / 2, 512));
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
query.scalar_type(),
|
query.scalar_type(),
|
||||||
|
|||||||
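The host function in this diff now derives `num_tokens` from `query.numel() / query.size(-1)` and reads head counts from the last dimension, so both the `[num_tokens, hidden]` and the `[batch_size, seq_len, hidden]` layouts produce the same launch grid. Here is a small sketch of that arithmetic in plain C++, with shapes held in a `std::vector` instead of a tensor. `count_tokens` and `heads_from_hidden` are hypothetical stand-ins, not PyTorch calls.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for query.numel() / query.size(-1):
// every dimension except the last one counts tokens.
inline int64_t count_tokens(const std::vector<int64_t>& shape) {
    assert(!shape.empty() && shape.back() > 0);
    int64_t numel = 1;
    for (int64_t dim : shape) {
        numel *= dim;
    }
    return numel / shape.back();
}

// Hypothetical stand-in for query.size(-1) / head_size.
inline int heads_from_hidden(const std::vector<int64_t>& shape, int head_size) {
    assert(!shape.empty() && head_size > 0);
    return static_cast<int>(shape.back() / head_size);
}
```

A `[2, 5, 4096]` query and a `[10, 4096]` query both give 10 tokens, and with `head_size = 128` both give 32 heads.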
217
csrc/punica/LICENSE
Normal file
@@ -0,0 +1,217 @@
|
|||||||
|
Contains code from https://github.com/punica-ai/punica
|
||||||
|
|
||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "{}"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright {yyyy} {name of copyright owner}
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
|
||||||
|
------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
This product bundles various third-party components under other open source licenses.
|
||||||
|
This section summarizes those components and their licenses. See licenses/
|
||||||
|
for text of these licenses.
|
||||||
|
|
||||||
|
|
||||||
|
Apache-2.0
|
||||||
|
* third_party/nvbench (with LLVM exception)
|
||||||
|
* third_party/flashinfer
|
||||||
|
|
||||||
|
BSD-3-Clause:
|
||||||
|
* third_party/cutlass
|
||||||
4
csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
|
||||||
4
csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half)
|
||||||
4
csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16)
|
||||||
4
csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half)
|
||||||
4
csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
|
||||||
4
csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half)
|
||||||
61
csrc/punica/bgmv/bgmv_config.h
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
template <int feat_in, int feat_out, typename in_T, typename out_T,
|
||||||
|
typename W_T>
|
||||||
|
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||||
|
const W_T *__restrict__ W,
|
||||||
|
const int64_t *__restrict__ indicies, int64_t y_offset,
|
||||||
|
int64_t full_y_size, int64_t batch_size, int64_t num_layers,
|
||||||
|
int64_t layer_idx, float scale);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
|
||||||
|
#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 128) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 256) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 512) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 1024) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 1280) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 1728) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 1792) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 2048) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 2560) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 2752) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 3072) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 3456) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 3584) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 4096) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 5120) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 5504) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 5632) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 6144) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 6912) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 7168) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 8192) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 9216) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 10240) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 11008) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 12288) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 13824) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 14336) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 16384) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 20480) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 24576) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 28672) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 32000) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 32256) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 32512) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 32768) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 33024) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 36864) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 49152) \
|
||||||
|
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
|
||||||
|
|
||||||
|
// Keep this in sync with vllm/config::LoRAConfig
|
||||||
|
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
|
||||||
|
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
|
||||||
|
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \
|
||||||
|
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
|
||||||
|
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
|
||||||
|
|
||||||
|
// clang-format on
|
||||||
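`bgmv_config.h` uses the X-macro pattern: `FOR_BGMV_WIDE` applies a caller-supplied macro `f` to each wide dimension, and `FOR_BGMV_WIDE_NARROW` repeats that for each narrow (rank) dimension. Below is a reduced C++ sketch of the same expansion mechanism with a much shorter size list. The `DEMO` macros and `demo_pairs` are hypothetical, and the assumption that the "two-side" macro covers both (narrow, wide) and (wide, narrow) is inferred from the name `INST_BGMV_TWOSIDE`, whose definition is not shown in this part of the diff.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Reduced, hypothetical version of FOR_BGMV_WIDE: apply f to each wide size.
#define FOR_DEMO_WIDE(f, narrow) \
    f(narrow, 128)               \
    f(narrow, 256)               \
    f(narrow, 512)

// Reduced, hypothetical version of FOR_BGMV_WIDE_NARROW: repeat per narrow size.
#define FOR_DEMO_WIDE_NARROW(f) \
    FOR_DEMO_WIDE(f, 8)         \
    FOR_DEMO_WIDE(f, 16)

// Stand-in for a "two-side" instantiation macro: instead of instantiating a
// template, record both orientations of the (narrow, wide) pair.
#define RECORD_TWOSIDE(narrow, wide)  \
    pairs.emplace_back(narrow, wide); \
    pairs.emplace_back(wide, narrow);

inline std::vector<std::pair<int, int>> demo_pairs() {
    std::vector<std::pair<int, int>> pairs;
    FOR_DEMO_WIDE_NARROW(RECORD_TWOSIDE)
    assert(!pairs.empty());
    return pairs;
}
```

In this sketch, two narrow sizes times three wide sizes times two orientations expand to 12 recorded pairs. Under the same assumption, the list in the header would expand to one instantiation per orientation for every listed wide size and every narrow size, which is why each dtype combination gets its own `.cu` file.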
4
csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16)
|
||||||
4
csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half)
|
||||||
4
csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16)
|
||||||
4
csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
|
||||||
4
csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16)
|
||||||
4
csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
|
||||||
4
csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
|
||||||
4
csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half)
|
||||||
4
csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16)
|
||||||
4
csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
|
||||||
4
csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16)
|
||||||
4
csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
#include "bgmv_config.h"
|
||||||
|
#include "bgmv_impl.cuh"
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half)
|
||||||
294
csrc/punica/bgmv/bgmv_impl.cuh
Normal file
@@ -0,0 +1,294 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
#include <cuda/pipeline>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <iostream>
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#include "vec_dtypes.cuh"
|
||||||
|
|
||||||
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
|
// nthrs = (32, 4)
|
||||||
|
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
|
||||||
|
size_t W_copy_size, int tx, int ty, int tz, typename in_T,
|
||||||
|
typename out_T, typename W_T>
|
||||||
|
__global__ void
|
||||||
|
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||||
|
const W_T *__restrict__ W,
|
||||||
|
const int64_t *__restrict__ indicies, int64_t y_offset,
|
||||||
|
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
|
||||||
|
float scale) {
|
||||||
|
size_t batch_idx = blockIdx.y;
|
||||||
|
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
|
||||||
|
if (idx < 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto block = cg::this_thread_block();
|
||||||
|
size_t j = blockIdx.x;
|
||||||
|
constexpr size_t num_pipeline_stages = 2;
|
||||||
|
constexpr size_t tile_size = tx * ty * vec_size;
|
||||||
|
__shared__ W_T W_shared[num_pipeline_stages * tile_size];
|
||||||
|
__shared__ in_T X_shared[num_pipeline_stages * tile_size];
|
||||||
|
__shared__ float y_warpwise[ty];
|
||||||
|
|
||||||
|
size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
|
||||||
|
size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
|
||||||
|
auto pipe = cuda::make_pipeline();
|
||||||
|
|
||||||
|
// pipeline load W/X and compute WX;
|
||||||
|
pipe.producer_acquire();
|
||||||
|
cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||||
|
W + (idx * feat_out + j) * feat_in +
|
||||||
|
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||||
|
cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
|
||||||
|
cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||||
|
X + (batch_idx * feat_in) +
|
||||||
|
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||||
|
cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
|
||||||
|
pipe.producer_commit();
|
||||||
|
size_t copy_idx, compute_idx;
|
||||||
|
float y = 0.f;
|
||||||
|
vec_t<in_T, vec_size> x_vec;
|
||||||
|
vec_t<W_T, vec_size> w_vec;
|
||||||
|
size_t tile_idx;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size;
|
||||||
|
++tile_idx) {
|
||||||
|
copy_idx = tile_idx % num_pipeline_stages;
|
||||||
|
// pipeline stage: async copy W fragment
|
||||||
|
pipe.producer_acquire();
|
||||||
|
if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) {
|
||||||
|
cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] +
|
||||||
|
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||||
|
W + (idx * feat_out + j) * feat_in +
|
||||||
|
tile_idx * tile_size +
|
||||||
|
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||||
|
cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
|
||||||
|
cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] +
|
||||||
|
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||||
|
X + (batch_idx * feat_in) + tile_idx * tile_size +
|
||||||
|
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||||
|
cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
|
||||||
|
}
|
||||||
|
pipe.producer_commit();
|
||||||
|
|
||||||
|
compute_idx = (tile_idx - 1) % num_pipeline_stages;
|
||||||
|
// pipeline stage: compute WX
|
||||||
|
pipe.consumer_wait();
|
||||||
|
block.sync();
|
||||||
|
x_vec.load(X_shared + X_shared_offset[compute_idx] +
|
||||||
|
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||||
|
w_vec.load(W_shared + W_shared_offset[compute_idx] +
|
||||||
|
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||||
|
float sum = 0.f;
|
||||||
|
#pragma unroll
|
||||||
|
for (size_t i = 0; i < vec_size; ++i) {
|
||||||
|
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
|
||||||
|
}
|
||||||
|
#pragma unroll
|
||||||
|
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
|
||||||
|
sum += __shfl_down_sync(0xffffffff, sum, offset);
|
||||||
|
}
|
||||||
|
y_warpwise[threadIdx.y] = sum;
|
||||||
|
block.sync();
|
||||||
|
#pragma unroll
|
||||||
|
for (size_t i = 0; i < ty; ++i) {
|
||||||
|
y += y_warpwise[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
block.sync();
|
||||||
|
pipe.consumer_release();
|
||||||
|
}
|
||||||
|
|
||||||
|
compute_idx = (tile_idx - 1) % num_pipeline_stages;
|
||||||
|
// final pipeline stage
|
||||||
|
pipe.consumer_wait();
|
||||||
|
block.sync();
|
||||||
|
x_vec.load(X_shared + X_shared_offset[compute_idx] +
|
||||||
|
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||||
|
w_vec.load(W_shared + W_shared_offset[compute_idx] +
|
||||||
|
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||||
|
float sum = 0.f;
|
||||||
|
#pragma unroll
|
||||||
|
for (size_t i = 0; i < vec_size; ++i) {
|
||||||
|
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
|
||||||
|
}
|
||||||
|
#pragma unroll
|
||||||
|
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
|
||||||
|
sum += __shfl_down_sync(0xffffffff, sum, offset);
|
||||||
|
}
|
||||||
|
y_warpwise[threadIdx.y] =
|
||||||
|
((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in)
|
||||||
|
? sum
|
||||||
|
: 0.f;
|
||||||
|
block.sync();
|
||||||
|
#pragma unroll
|
||||||
|
for (size_t i = 0; i < ty; ++i) {
|
||||||
|
y += y_warpwise[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
block.sync();
|
||||||
|
pipe.consumer_release();
|
||||||
|
|
||||||
|
// write Y;
|
||||||
|
if (block.thread_rank() == 0) {
|
||||||
|
Y[batch_idx * full_y_size + y_offset + j] += static_cast<out_T>(y);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// nthrs = (2, 16, 4)
|
||||||
|
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
|
||||||
|
typename in_T, typename out_T, typename W_T>
|
||||||
|
__global__ void
|
||||||
|
bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||||
|
const W_T *__restrict__ W,
|
||||||
|
const int64_t *__restrict__ indicies, int64_t y_offset,
|
||||||
|
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
|
||||||
|
float scale) {
|
||||||
|
size_t batch_idx = blockIdx.y;
|
||||||
|
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
|
||||||
|
|
||||||
|
if (idx < 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto block = cg::this_thread_block();
|
||||||
|
size_t tile_idx = blockIdx.x;
|
||||||
|
|
||||||
|
// load X;
|
||||||
|
vec_t<in_T, vec_size> x_vec;
|
||||||
|
x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size);
|
||||||
|
|
||||||
|
// load W;
|
||||||
|
vec_t<W_T, vec_size> w_vec;
|
||||||
|
w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in +
|
||||||
|
block.thread_rank() * vec_size);
|
||||||
|
|
||||||
|
float sum = 0.f;
|
||||||
|
#pragma unroll
|
||||||
|
for (size_t i = 0; i < vec_size; ++i) {
|
||||||
|
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
|
||||||
|
}
|
||||||
|
|
||||||
|
cg::thread_block_tile g = cg::tiled_partition<tx>(block);
|
||||||
|
#pragma unroll
|
||||||
|
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
|
||||||
|
sum += g.shfl_down(sum, offset);
|
||||||
|
}
|
||||||
|
sum = g.shfl(sum, 0);
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
|
||||||
|
threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum);
|
||||||
|
}
|
||||||
|
}
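For readers less familiar with warp shuffles, the reduction loop in `bgmv_expand_kernel` can be modeled on the host. The sketch below is a plain-Python simulation, not CUDA. `simulate_shfl_down_reduce` is a hypothetical helper introduced only for illustration, and it assumes that a lane whose source index falls outside the tile reads back its own value.

```python
def simulate_shfl_down_reduce(lane_values):
    """Model the `offset = tx / 2; offset > 0; offset /= 2` loop on a list of lanes."""
    tx = len(lane_values)
    assert tx > 0 and tx & (tx - 1) == 0, "tile width must be a power of two"
    lanes = list(lane_values)
    offset = tx // 2
    while offset > 0:
        # Every lane adds the value held by the lane `offset` positions above it.
        # Assumption: an out-of-range source returns the lane's own value.
        lanes = [
            lanes[i] + (lanes[i + offset] if i + offset < tx else lanes[i])
            for i in range(tx)
        ]
        offset //= 2
    # Lane 0 now holds the full sum; the kernel broadcasts it with g.shfl(sum, 0).
    return lanes[0]


partials = [1.0, 2.0, 3.0, 4.0]
total = simulate_shfl_down_reduce(partials)
```

Only lane 0 is guaranteed to hold the complete sum after the loop, which is presumably why the kernel follows it with `g.shfl(sum, 0)` before the write guarded by `threadIdx.x == 0`.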
|
||||||
|
|
||||||
|
template <int feat_in, int feat_out, typename in_T, typename out_T,
|
||||||
|
typename W_T>
|
||||||
|
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||||
|
const W_T *__restrict__ W,
|
||||||
|
const int64_t *__restrict__ indicies, int64_t y_offset,
|
||||||
|
int64_t full_y_size, int64_t batch_size, int64_t num_layers,
|
||||||
|
int64_t layer_idx, float scale) {
|
||||||
|
constexpr size_t vec_size = 8;
|
||||||
|
constexpr int tz = 4;
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
|
if constexpr (feat_in < feat_out) {
|
||||||
|
static_assert(feat_in % vec_size == 0);
|
||||||
|
constexpr int tx = feat_in / vec_size;
|
||||||
|
|
||||||
|
static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) ||
|
||||||
|
(16 % tx == 0 && feat_out % (16 / tx * tz) == 0) ||
|
||||||
|
(8 % tx == 0 && feat_out % (8 / tx * tz) == 0));
|
||||||
|
|
||||||
|
if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) {
|
||||||
|
constexpr int ty = 32 / tx;
|
||||||
|
dim3 nblks(feat_out / (ty * tz), batch_size);
|
||||||
|
dim3 nthrs(tx, ty, tz);
|
||||||
|
|
||||||
|
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
|
||||||
|
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||||
|
full_y_size, num_layers, layer_idx,
|
||||||
|
scale);
|
||||||
|
} else if constexpr (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) {
|
||||||
|
constexpr int ty = 16 / tx;
|
||||||
|
dim3 nblks(feat_out / (ty * tz), batch_size);
|
||||||
|
dim3 nthrs(tx, ty, tz);
|
||||||
|
|
||||||
|
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
|
||||||
|
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||||
|
full_y_size, num_layers, layer_idx,
|
||||||
|
scale);
|
||||||
|
} else {
|
||||||
|
constexpr int ty = 8 / tx;
|
||||||
|
dim3 nblks(feat_out / (ty * tz), batch_size);
|
||||||
|
dim3 nthrs(tx, ty, tz);
|
||||||
|
|
||||||
|
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
|
||||||
|
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||||
|
full_y_size, num_layers, layer_idx,
|
||||||
|
scale);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
static_assert(feat_in % (vec_size * 32) == 0 ||
|
||||||
|
feat_in % (vec_size * 16) == 0 ||
|
||||||
|
feat_in % (vec_size * 8) == 0);
|
||||||
|
|
||||||
|
if constexpr (feat_in % (vec_size * 32) == 0) {
|
||||||
|
constexpr int tx = 32;
|
||||||
|
constexpr int ty = 4;
|
||||||
|
|
||||||
|
dim3 nblks(feat_out, batch_size);
|
||||||
|
dim3 nthrs(tx, ty);
|
||||||
|
|
||||||
|
bgmv_shrink_kernel<feat_in, feat_out, vec_size, vec_size * sizeof(in_T),
|
||||||
|
vec_size * sizeof(W_T), tx, ty, tz>
|
||||||
|
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||||
|
full_y_size, num_layers, layer_idx,
|
||||||
|
scale);
|
||||||
|
} else if constexpr (feat_in % (vec_size / 2 * 32) == 0) {
|
||||||
|
constexpr int tx = 32;
|
||||||
|
constexpr int ty = 4;
|
||||||
|
|
||||||
|
dim3 nblks(feat_out, batch_size);
|
||||||
|
dim3 nthrs(tx, ty);
|
||||||
|
|
||||||
|
bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
|
||||||
|
vec_size * sizeof(in_T) / 2,
|
||||||
|
vec_size * sizeof(W_T) / 2, tx, ty, tz>
|
||||||
|
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||||
|
full_y_size, num_layers, layer_idx,
|
||||||
|
scale);
|
||||||
|
} else if constexpr (feat_in % (vec_size / 2 * 16) == 0) {
|
||||||
|
constexpr int tx = 16;
|
||||||
|
constexpr int ty = 4;
|
||||||
|
|
||||||
|
dim3 nblks(feat_out, batch_size);
|
||||||
|
dim3 nthrs(tx, ty);
|
||||||
|
|
||||||
|
bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
|
||||||
|
vec_size * sizeof(in_T) / 2,
|
||||||
|
vec_size * sizeof(W_T) / 2, tx, ty, tz>
|
||||||
|
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||||
|
full_y_size, num_layers, layer_idx,
|
||||||
|
scale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
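The launch-shape selection for the `feat_in < feat_out` path is easier to check with concrete numbers. The following is a rough host-side model of that `if constexpr` chain, written in Python for convenience. `expand_launch_config` is a hypothetical name, and the model covers only the expand path (the grid's y dimension, `batch_size`, is omitted).

```python
VEC_SIZE = 8  # mirrors `constexpr size_t vec_size = 8` in bgmv_kernel
TZ = 4        # mirrors `constexpr int tz = 4`


def expand_launch_config(feat_in, feat_out):
    """Return (grid_x, (tx, ty, tz)) for the feat_in < feat_out path."""
    assert 0 < feat_in < feat_out and feat_in % VEC_SIZE == 0
    tx = feat_in // VEC_SIZE
    for width in (32, 16, 8):
        # Same test as the if-constexpr chain: tx must divide the group width
        # and feat_out must split evenly into blocks of ty * tz rows.
        if width % tx == 0 and feat_out % (width // tx * TZ) == 0:
            ty = width // tx
            return feat_out // (ty * TZ), (tx, ty, TZ)
    raise ValueError("no launch shape satisfies the static_assert")


grid_x, block = expand_launch_config(16, 4096)
```

With `feat_in = 16` the block shape comes out as `(2, 16, 4)`, which agrees with the `// nthrs = (2, 16, 4)` comment above `bgmv_expand_kernel`.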
|
||||||
|
|
||||||
|
#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) \
|
||||||
|
template void bgmv_kernel<feat_in, feat_out>( \
|
||||||
|
out_T * __restrict__ Y, const in_T *__restrict__ X, \
|
||||||
|
const W_T *__restrict__ W, const int64_t *__restrict__ indicies, \
|
||||||
|
int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
|
||||||
|
int64_t num_layers, int64_t layer_idx, float scale);
|
||||||
|
|
||||||
|
#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
|
||||||
|
INST_BGMV(narrow, wide, in_T, out_T, W_T) \
|
||||||
|
INST_BGMV(wide, narrow, in_T, out_T, W_T)
|
||||||
27
csrc/punica/bgmv/generator.py
Normal file
@@ -0,0 +1,27 @@
DTYPES = ["fp16", "bf16", "fp32"]
DTYPE_MAP = {
    "fp16": "nv_half",
    "bf16": "nv_bfloat16",
    "fp32": "float",
}

TEMPLATE = """
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
""".lstrip()

for input_dtype in DTYPES:
    for output_dtype in DTYPES:
        for weight_dtype in DTYPES:
            if weight_dtype == "fp32":
                # FP32 weights are not supported.
                continue
            kernel_definition = TEMPLATE.format(
                input_dtype=DTYPE_MAP[input_dtype],
                output_dtype=DTYPE_MAP[output_dtype],
                weight_dtype=DTYPE_MAP[weight_dtype])
            filename = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu"
            with open(filename, "w") as f:
                f.write(kernel_definition)
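As a quick sanity check on what `generator.py` produces, the loop can be replayed without writing any files. This is only a sketch: `generated_filenames` is a hypothetical helper, and it reproduces just the filename logic, not the template contents.

```python
from itertools import product

DTYPES = ["fp16", "bf16", "fp32"]


def generated_filenames():
    """List the .cu files the generator loop writes, without touching disk."""
    return [
        f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu"
        for input_dtype, output_dtype, weight_dtype in product(DTYPES, repeat=3)
        if weight_dtype != "fp32"  # FP32 weights are skipped by the generator
    ]


names = generated_filenames()
```

Three input dtypes times three output dtypes times two weight dtypes gives 18 translation units, one per combination that the nested switch in `punica_ops.cc` can dispatch to.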
1324
csrc/punica/bgmv/vec_dtypes.cuh
Normal file
File diff suppressed because it is too large
563
csrc/punica/punica_ops.cc
Normal file
@@ -0,0 +1,563 @@
|
|||||||
|
#include <cuda_bf16.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
|
#include "bgmv/bgmv_config.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
//====== utils ======
|
||||||
|
|
||||||
|
inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
|
||||||
|
const char *a_name, const char *b_name) {
|
||||||
|
TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ",
|
||||||
|
a.dim(), " vs ", b.dim());
|
||||||
|
for (int i = 0; i < a.dim(); ++i) {
|
||||||
|
TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name,
|
||||||
|
".size(", i, ")");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
|
||||||
|
return (uint32_t(a) << 16) | uint32_t(b);
|
||||||
|
}
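The role of `pack_u16` becomes clearer with a worked value: it folds the `(in_features, out_features)` pair into one 32-bit key so that a single `switch` can select the kernel instantiation. The Python model below is a sketch under that reading. `unpack_u16` is a hypothetical inverse added for illustration and does not exist in the source.

```python
def pack_u16(a, b):
    """Python model of the C++ pack_u16: a in the high 16 bits, b in the low 16."""
    assert 0 <= a < 65536 and 0 <= b < 65536, "each value must fit in 16 bits"
    return (a << 16) | b


def unpack_u16(key):
    """Hypothetical inverse, for illustration only (not in the source)."""
    return key >> 16, key & 0xFFFF


key = pack_u16(16, 4096)
```

Because each half only has 16 bits, the key is unambiguous only for sizes below 65536, which matches the `h_in < 65536 && h_out < 65536` guard in `dispatch_bgmv`. The packing is also order-sensitive, so the narrow-to-wide and wide-to-narrow cases get distinct keys.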
|
||||||
|
|
||||||
|
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
|
||||||
|
|
||||||
|
#define CHECK_CONTIGUOUS(x) \
|
||||||
|
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||||
|
|
||||||
|
#define CHECK_INPUT(x) \
|
||||||
|
CHECK_CUDA(x); \
|
||||||
|
CHECK_CONTIGUOUS(x)
|
||||||
|
|
||||||
|
#define CHECK_DIM(d, x) \
|
||||||
|
TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
|
||||||
|
|
||||||
|
#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)
|
||||||
|
|
||||||
|
#define CHECK_EQ(a, b) \
|
||||||
|
TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
|
||||||
|
|
||||||
|
//====== bgmv ======
|
||||||
|
|
||||||
|
template <typename in_T, typename out_T, typename W_T>
|
||||||
|
inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
|
||||||
|
const int64_t *lora_indices,
|
||||||
|
uint16_t in_features, uint16_t out_features,
|
||||||
|
int64_t y_offset, int64_t full_y_size,
|
||||||
|
int64_t batch_size, int64_t num_layers,
|
||||||
|
int64_t layer_idx, float scale) {
|
||||||
|
switch (pack_u16(in_features, out_features)) {
|
||||||
|
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
|
||||||
|
case pack_u16(feat_in, feat_out): \
|
||||||
|
bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \
|
||||||
|
full_y_size, batch_size, num_layers, \
|
||||||
|
layer_idx, scale); \
|
||||||
|
break;
|
||||||
|
#define CASE(_in_T, _out_T, _W_T, narrow, wide) \
|
||||||
|
CASE_ONESIDE(in_T, out_T, W_T, narrow, wide) \
|
||||||
|
CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)
|
||||||
|
|
||||||
|
FOR_BGMV_WIDE_NARROW(CASE, _, _, _)
|
||||||
|
#undef CASE
|
||||||
|
#undef CASE_ONESIDE
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
||||||
|
torch::Tensor indicies, int64_t layer_idx, float scale) {
|
||||||
|
CHECK_INPUT(y);
|
||||||
|
CHECK_INPUT(x);
|
||||||
|
CHECK_INPUT(w);
|
||||||
|
CHECK_INPUT(indicies);
|
||||||
|
|
||||||
|
CHECK_DIM(2, y);
|
||||||
|
CHECK_DIM(2, x);
|
||||||
|
CHECK_DIM(4, w);
|
||||||
|
CHECK_DIM(1, indicies);
|
||||||
|
|
||||||
|
int64_t B = x.size(0);
|
||||||
|
int64_t h_in = x.size(1);
|
||||||
|
int64_t h_out = y.size(1);
|
||||||
|
int64_t num_layers = w.size(1);
|
||||||
|
CHECK_EQ(w.size(3), h_in);
|
||||||
|
CHECK_EQ(w.size(2), h_out);
|
||||||
|
CHECK_EQ(indicies.size(0), x.size(0));
|
||||||
|
CHECK_EQ(y.size(0), x.size(0));
|
||||||
|
bool ok = false;
|
||||||
|
if (h_in < 65536 && h_out < 65536) {
|
||||||
|
// TODO: See if we can get rid of this massive nested switch
|
||||||
|
switch (x.scalar_type()) {
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
switch (y.scalar_type()) {
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
switch (w.scalar_type()) {
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||||
|
static_cast<nv_half *>(x.data_ptr()),
|
||||||
|
static_cast<nv_half *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||||
|
h_out, B, num_layers, layer_idx, scale);
|
||||||
|
break;
|
||||||
|
case at::ScalarType::BFloat16:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||||
|
static_cast<nv_half *>(x.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||||
|
h_out, B, num_layers, layer_idx, scale);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case at::ScalarType::BFloat16:
|
||||||
|
switch (w.scalar_type()) {
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||||
|
static_cast<nv_half *>(x.data_ptr()),
|
||||||
|
static_cast<nv_half *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||||
|
h_out, B, num_layers, layer_idx, scale);
|
||||||
|
break;
|
||||||
|
case at::ScalarType::BFloat16:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||||
|
static_cast<nv_half *>(x.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||||
|
h_out, B, num_layers, layer_idx, scale);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case at::ScalarType::Float:
|
||||||
|
switch (w.scalar_type()) {
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||||
|
static_cast<nv_half *>(x.data_ptr()),
|
||||||
|
static_cast<nv_half *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||||
|
h_out, B, num_layers, layer_idx, scale);
|
||||||
|
break;
|
||||||
|
case at::ScalarType::BFloat16:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||||
|
static_cast<nv_half *>(x.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||||
|
h_out, B, num_layers, layer_idx, scale);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case at::ScalarType::BFloat16:
|
||||||
|
switch (y.scalar_type()) {
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
switch (w.scalar_type()) {
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||||
|
static_cast<nv_half *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||||
|
h_out, B, num_layers, layer_idx, scale);
|
||||||
|
break;
|
||||||
|
case at::ScalarType::BFloat16:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||||
|
h_out, B, num_layers, layer_idx, scale);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case at::ScalarType::BFloat16:
|
||||||
|
switch (w.scalar_type()) {
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||||
|
static_cast<nv_half *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||||
|
h_out, B, num_layers, layer_idx, scale);
|
||||||
|
break;
|
||||||
|
case at::ScalarType::BFloat16:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||||
|
h_out, B, num_layers, layer_idx, scale);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case at::ScalarType::Float:
|
||||||
|
switch (w.scalar_type()) {
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||||
|
static_cast<nv_half *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||||
|
h_out, B, num_layers, layer_idx, scale);
|
||||||
|
break;
|
||||||
|
case at::ScalarType::BFloat16:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||||
|
h_out, B, num_layers, layer_idx, scale);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case at::ScalarType::Float:
|
||||||
|
switch (y.scalar_type()) {
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
switch (w.scalar_type()) {
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||||
|
static_cast<float *>(x.data_ptr()),
|
||||||
|
static_cast<nv_half *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||||
|
h_out, B, num_layers, layer_idx, scale);
|
||||||
|
break;
|
||||||
|
case at::ScalarType::BFloat16:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||||
|
static_cast<float *>(x.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||||
|
h_out, B, num_layers, layer_idx, scale);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case at::ScalarType::BFloat16:
|
||||||
|
switch (w.scalar_type()) {
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||||
|
static_cast<float *>(x.data_ptr()),
|
||||||
|
static_cast<nv_half *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||||
|
h_out, B, num_layers, layer_idx, scale);
|
||||||
|
break;
|
||||||
|
case at::ScalarType::BFloat16:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||||
|
static_cast<float *>(x.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||||
|
h_out, B, num_layers, layer_idx, scale);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case at::ScalarType::Float:
|
||||||
|
switch (w.scalar_type()) {
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||||
|
static_cast<float *>(x.data_ptr()),
|
||||||
|
static_cast<nv_half *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||||
|
h_out, B, num_layers, layer_idx, scale);
|
||||||
|
break;
|
||||||
|
case at::ScalarType::BFloat16:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||||
|
static_cast<float *>(x.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||||
|
h_out, B, num_layers, layer_idx, scale);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
|
||||||
|
" dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
|
||||||
|
}
|
||||||
|
|
||||||
|
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
||||||
|
torch::Tensor indicies, int64_t layer_idx,
|
||||||
|
float scale, int64_t h_in, int64_t h_out,
|
||||||
|
int64_t y_offset) {
|
||||||
|
CHECK_INPUT(y);
|
||||||
|
CHECK_INPUT(x);
|
||||||
|
CHECK_INPUT(w);
|
||||||
|
CHECK_INPUT(indicies);
|
||||||
|
|
||||||
|
CHECK_DIM(2, y);
|
||||||
|
CHECK_DIM(2, x);
|
||||||
|
CHECK_DIM(4, w);
|
||||||
|
CHECK_DIM(1, indicies);
|
||||||
|
|
||||||
|
int64_t B = x.size(0);
|
||||||
|
int64_t num_layers = w.size(1);
|
||||||
|
int64_t full_y_size = y.size(1);
|
||||||
|
CHECK_EQ(w.size(3), h_in);
|
||||||
|
CHECK_EQ(w.size(2), h_out);
|
||||||
|
CHECK_EQ(indicies.size(0), x.size(0));
|
||||||
|
CHECK_EQ(y.size(0), x.size(0));
|
||||||
|
bool ok = false;
|
||||||
|
if (h_in < 65536 && h_out < 65536) {
|
||||||
|
// TODO: See if we can get rid of this massive nested switch
|
||||||
|
switch (x.scalar_type()) {
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
switch (y.scalar_type()) {
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
switch (w.scalar_type()) {
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||||
|
static_cast<nv_half *>(x.data_ptr()),
|
||||||
|
static_cast<nv_half *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||||
|
y_offset, full_y_size, B, num_layers,
|
||||||
|
layer_idx, scale);
|
||||||
|
break;
|
||||||
|
case at::ScalarType::BFloat16:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||||
|
static_cast<nv_half *>(x.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||||
|
y_offset, full_y_size, B, num_layers,
|
||||||
|
layer_idx, scale);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case at::ScalarType::BFloat16:
|
||||||
|
switch (w.scalar_type()) {
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||||
|
static_cast<nv_half *>(x.data_ptr()),
|
||||||
|
static_cast<nv_half *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||||
|
y_offset, full_y_size, B, num_layers,
|
||||||
|
layer_idx, scale);
|
||||||
|
break;
|
||||||
|
case at::ScalarType::BFloat16:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||||
|
static_cast<nv_half *>(x.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||||
|
y_offset, full_y_size, B, num_layers,
|
||||||
|
layer_idx, scale);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case at::ScalarType::Float:
|
||||||
|
switch (w.scalar_type()) {
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||||
|
static_cast<nv_half *>(x.data_ptr()),
|
||||||
|
static_cast<nv_half *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||||
|
y_offset, full_y_size, B, num_layers,
|
||||||
|
layer_idx, scale);
|
||||||
|
break;
|
||||||
|
case at::ScalarType::BFloat16:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||||
|
static_cast<nv_half *>(x.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||||
|
y_offset, full_y_size, B, num_layers,
|
||||||
|
layer_idx, scale);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case at::ScalarType::BFloat16:
|
||||||
|
switch (y.scalar_type()) {
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
switch (w.scalar_type()) {
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||||
|
static_cast<nv_half *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||||
|
y_offset, full_y_size, B, num_layers,
|
||||||
|
layer_idx, scale);
|
||||||
|
break;
|
||||||
|
case at::ScalarType::BFloat16:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||||
|
y_offset, full_y_size, B, num_layers,
|
||||||
|
layer_idx, scale);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case at::ScalarType::BFloat16:
|
||||||
|
switch (w.scalar_type()) {
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||||
|
static_cast<nv_half *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||||
|
y_offset, full_y_size, B, num_layers,
|
||||||
|
layer_idx, scale);
|
||||||
|
break;
|
||||||
|
case at::ScalarType::BFloat16:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||||
|
y_offset, full_y_size, B, num_layers,
|
||||||
|
layer_idx, scale);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case at::ScalarType::Float:
|
||||||
|
switch (w.scalar_type()) {
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||||
|
static_cast<nv_half *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||||
|
y_offset, full_y_size, B, num_layers,
|
||||||
|
layer_idx, scale);
|
||||||
|
break;
|
||||||
|
case at::ScalarType::BFloat16:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||||
|
y_offset, full_y_size, B, num_layers,
|
||||||
|
layer_idx, scale);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case at::ScalarType::Float:
|
||||||
|
switch (y.scalar_type()) {
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
switch (w.scalar_type()) {
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||||
|
static_cast<float *>(x.data_ptr()),
|
||||||
|
static_cast<nv_half *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||||
|
y_offset, full_y_size, B, num_layers,
|
||||||
|
layer_idx, scale);
|
||||||
|
break;
|
||||||
|
case at::ScalarType::BFloat16:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||||
|
static_cast<float *>(x.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||||
|
y_offset, full_y_size, B, num_layers,
|
||||||
|
layer_idx, scale);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case at::ScalarType::BFloat16:
|
||||||
|
switch (w.scalar_type()) {
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||||
|
static_cast<float *>(x.data_ptr()),
|
||||||
|
static_cast<nv_half *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||||
|
y_offset, full_y_size, B, num_layers,
|
||||||
|
layer_idx, scale);
|
||||||
|
break;
|
||||||
|
case at::ScalarType::BFloat16:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||||
|
static_cast<float *>(x.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||||
|
y_offset, full_y_size, B, num_layers,
|
||||||
|
layer_idx, scale);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case at::ScalarType::Float:
|
||||||
|
switch (w.scalar_type()) {
|
||||||
|
case at::ScalarType::Half:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||||
|
static_cast<float *>(x.data_ptr()),
|
||||||
|
static_cast<nv_half *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||||
|
y_offset, full_y_size, B, num_layers,
|
||||||
|
layer_idx, scale);
|
||||||
|
break;
|
||||||
|
case at::ScalarType::BFloat16:
|
||||||
|
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||||
|
static_cast<float *>(x.data_ptr()),
|
||||||
|
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||||
|
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||||
|
y_offset, full_y_size, B, num_layers,
|
||||||
|
layer_idx, scale);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
|
||||||
|
" dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
//====== pybind ======
|
||||||
|
|
||||||
|
#define DEFINE_pybind(name) m.def(#name, &name, #name);
|
||||||
|
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
|
m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
|
||||||
|
m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level,
|
||||||
|
"dispatch_bgmv_low_level");
|
||||||
|
}
|
||||||
117
csrc/pybind.cpp
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
#include "cache.h"
|
||||||
|
#include "cuda_utils.h"
|
||||||
|
#include "ops.h"
|
||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
|
// vLLM custom ops
|
||||||
|
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
|
||||||
|
|
||||||
|
// Attention ops
|
||||||
|
ops.def(
|
||||||
|
"paged_attention_v1",
|
||||||
|
&paged_attention_v1,
|
||||||
|
"Compute the attention between an input query and the cached keys/values using PagedAttention.");
|
||||||
|
ops.def(
|
||||||
|
"paged_attention_v2",
|
||||||
|
&paged_attention_v2,
|
||||||
|
"PagedAttention V2.");
|
||||||
|
|
||||||
|
// Activation ops
|
||||||
|
ops.def(
|
||||||
|
"silu_and_mul",
|
||||||
|
&silu_and_mul,
|
||||||
|
"Activation function used in SwiGLU.");
|
||||||
|
ops.def(
|
||||||
|
"gelu_and_mul",
|
||||||
|
&gelu_and_mul,
|
||||||
|
"Activation function used in GeGLU.");
|
||||||
|
ops.def(
|
||||||
|
"gelu_new",
|
||||||
|
&gelu_new,
|
||||||
|
"GELU implementation used in GPT-2.");
|
||||||
|
ops.def(
|
||||||
|
"gelu_fast",
|
||||||
|
&gelu_fast,
|
||||||
|
"Approximate GELU implementation.");
|
||||||
|
|
||||||
|
// Layernorm
|
||||||
|
ops.def(
|
||||||
|
"rms_norm",
|
||||||
|
&rms_norm,
|
||||||
|
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
|
||||||
|
|
||||||
|
ops.def(
|
||||||
|
"fused_add_rms_norm",
|
||||||
|
&fused_add_rms_norm,
|
||||||
|
"In-place fused Add and RMS Normalization");
|
||||||
|
|
||||||
|
// Rotary embedding
|
||||||
|
ops.def(
|
||||||
|
"rotary_embedding",
|
||||||
|
&rotary_embedding,
|
||||||
|
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
|
||||||
|
|
||||||
|
// Quantization ops
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
|
||||||
|
ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
|
||||||
|
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
|
||||||
|
#endif
|
||||||
|
|
||||||
|
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
|
||||||
|
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
|
||||||
|
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
|
||||||
|
ops.def(
|
||||||
|
"moe_align_block_size",
|
||||||
|
&moe_align_block_size,
|
||||||
|
"Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");
|
||||||
|
|
||||||
|
// Cache ops
|
||||||
|
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
|
||||||
|
cache_ops.def(
|
||||||
|
"swap_blocks",
|
||||||
|
&swap_blocks,
|
||||||
|
"Swap in (out) the cache blocks from src to dst");
|
||||||
|
cache_ops.def(
|
||||||
|
"copy_blocks",
|
||||||
|
&copy_blocks,
|
||||||
|
"Copy the cache blocks from src to dst");
|
||||||
|
cache_ops.def(
|
||||||
|
"reshape_and_cache",
|
||||||
|
&reshape_and_cache,
|
||||||
|
"Reshape the key and value tensors and cache them");
|
||||||
|
cache_ops.def(
|
||||||
|
"convert_fp8_e5m2",
|
||||||
|
&convert_fp8_e5m2,
|
||||||
|
"Convert the key and value cache to fp8_e5m2 data type");
|
||||||
|
|
||||||
|
// Cuda utils
|
||||||
|
pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");
|
||||||
|
cuda_utils.def(
|
||||||
|
"get_device_attribute",
|
||||||
|
&get_device_attribute,
|
||||||
|
"Gets the specified device attribute.");
|
||||||
|
|
||||||
|
cuda_utils.def(
|
||||||
|
"get_max_shared_memory_per_block_device_attribute",
|
||||||
|
&get_max_shared_memory_per_block_device_attribute,
|
||||||
|
"Gets the maximum shared memory per block device attribute.");
|
||||||
|
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
// Custom all-reduce kernels
|
||||||
|
pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce");
|
||||||
|
custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar");
|
||||||
|
custom_ar.def("should_custom_ar", &should_custom_ar, "should_custom_ar");
|
||||||
|
custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg");
|
||||||
|
custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg");
|
||||||
|
custom_ar.def("dispose", &dispose, "dispose");
|
||||||
|
custom_ar.def("meta_size", &meta_size, "meta_size");
|
||||||
|
custom_ar.def("register_buffer", &register_buffer, "register_buffer");
|
||||||
|
custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta,
|
||||||
|
"get_graph_buffer_ipc_meta");
|
||||||
|
custom_ar.def("register_graph_buffers", &register_graph_buffers,
|
||||||
|
"register_graph_buffers");
|
||||||
|
#endif
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
#include <torch/extension.h>
|
|
||||||
|
|
||||||
torch::Tensor awq_gemm(
|
|
||||||
torch::Tensor _in_feats,
|
|
||||||
torch::Tensor _kernel,
|
|
||||||
torch::Tensor _scaling_factors,
|
|
||||||
torch::Tensor _zeros,
|
|
||||||
int split_k_iters);
|
|
||||||
|
|
||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|
||||||
m.def(
|
|
||||||
"awq_gemm",
|
|
||||||
&awq_gemm,
|
|
||||||
"Quantized GEMM for AWQ");
|
|
||||||
}
|
|
||||||
@@ -27,35 +27,48 @@ __pack_half2(const half x, const half y) {
|
|||||||
return (v1 << 16) | v0;
|
return (v1 << 16) | v0;
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
|
template<int N>
|
||||||
|
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
|
||||||
|
int G,
|
||||||
|
int split_k_iters,
|
||||||
|
half* __restrict__ A,
|
||||||
|
int* __restrict__ B,
|
||||||
|
half* __restrict__ scaling_factors,
|
||||||
|
int* __restrict__ zeros,
|
||||||
|
int M,
|
||||||
|
int IC,
|
||||||
|
int OC,
|
||||||
|
half* __restrict__ C)
|
||||||
{
|
{
|
||||||
|
// Only support matrix n = 64 or 128
|
||||||
|
assert(N == 64 || N == 128);
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||||
assert(false);
|
assert(false);
|
||||||
#else
|
#else
|
||||||
static constexpr uint32_t ZERO = 0x0;
|
static constexpr uint32_t ZERO = 0x0;
|
||||||
float C_warp[32];
|
float C_warp[32];
|
||||||
__shared__ half A_shared[16 * (32 + 8)];
|
__shared__ half A_shared[16 * (32 + 8)];
|
||||||
__shared__ half B_shared[32 * (128 + 8)];
|
__shared__ half B_shared[32 * (N + 8)];
|
||||||
|
|
||||||
__shared__ half scaling_factors_shared[128];
|
__shared__ half scaling_factors_shared[N];
|
||||||
__shared__ half zeros_shared[128];
|
__shared__ half zeros_shared[N];
|
||||||
|
|
||||||
int j_factors1 = ((OC + 128 - 1) / 128);
|
int j_factors1 = ((OC + N - 1) / N);
|
||||||
int blockIdx_x = 0;
|
int blockIdx_x = 0;
|
||||||
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
|
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
|
||||||
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
|
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
|
||||||
|
|
||||||
half A_shared_warp[8];
|
half A_shared_warp[8];
|
||||||
half B_shared_warp[32];
|
half B_shared_warp[N / 4];
|
||||||
for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) {
|
for (int j_0_4_init = 0; j_0_4_init < N / 32; ++j_0_4_init) {
|
||||||
for (int i = 0; i < 8; ++i) {
|
for (int i = 0; i < 8; ++i) {
|
||||||
C_warp[(j_0_4_init * 8) + i] = 0.0;
|
C_warp[(j_0_4_init * 8) + i] = 0.0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static constexpr int row_stride_warp = 32 * 8 / 32;
|
static constexpr int row_stride_warp = 32 * 8 / 32;
|
||||||
static constexpr int row_stride = 2 * 32 * 8 / 128;
|
static constexpr int row_stride = 2 * 32 * 8 / N;
|
||||||
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128;
|
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N;
|
||||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||||
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
|
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
|
||||||
// bool wb_C_flag = (threadIdx.x / 4) < M;
|
// bool wb_C_flag = (threadIdx.x / 4) < M;
|
||||||
@@ -65,10 +78,10 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
|
|||||||
+ (((int)threadIdx.x) % (32 / 8)) * 8;
|
+ (((int)threadIdx.x) % (32 / 8)) * 8;
|
||||||
|
|
||||||
int* B_ptr = B
|
int* B_ptr = B
|
||||||
+ ((int)threadIdx.y) * (OC / 8) * 2
|
+ ((int)threadIdx.y) * (OC / 8) * (256 / N)
|
||||||
+ (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
|
+ (((int)threadIdx.x) / (N / 8)) * (OC / 8)
|
||||||
+ (((int)blockIdx_y) % j_factors1) * (128 / 8)
|
+ (((int)blockIdx_y) % j_factors1) * (N / 8)
|
||||||
+ (((int)threadIdx.x) % (128 / 8)) * 1;
|
+ (((int)threadIdx.x) % (N / 8)) * 1;
|
||||||
// Why * 1 in the above line?
|
// Why * 1 in the above line?
|
||||||
|
|
||||||
half* A_shared_ptr = A_shared
|
half* A_shared_ptr = A_shared
|
||||||
@@ -77,22 +90,22 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
|
|||||||
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
|
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
|
||||||
|
|
||||||
half* B_shared_ptr = B_shared
|
half* B_shared_ptr = B_shared
|
||||||
+ ((int)threadIdx.y) * (row_stride / 2) * (128 + 8)
|
+ ((int)threadIdx.y) * (row_stride / 2) * (N + 8)
|
||||||
+ (((int)threadIdx.x) / (128 / 8)) * (128 + 8)
|
+ (((int)threadIdx.x) / (N / 8)) * (N + 8)
|
||||||
+ (((int)threadIdx.x) % (128 / 8)) * 8;
|
+ (((int)threadIdx.x) % (N / 8)) * 8;
|
||||||
|
|
||||||
int* zeros_ptr = zeros
|
int* zeros_ptr = zeros
|
||||||
+ (((int)blockIdx_y) % j_factors1) * (128 / 8)
|
+ (((int)blockIdx_y) % j_factors1) * (N / 8)
|
||||||
+ ((int)threadIdx.x) % (128 / 8);
|
+ ((int)threadIdx.x) % (N / 8);
|
||||||
|
|
||||||
half* scaling_factors_ptr = scaling_factors
|
half* scaling_factors_ptr = scaling_factors
|
||||||
+ (((int)blockIdx_y) % j_factors1) * (128)
|
+ (((int)blockIdx_y) % j_factors1) * N
|
||||||
+ (((int)threadIdx.x) % (128 / 8)) * 8;
|
+ (((int)threadIdx.x) % (N / 8)) * 8;
|
||||||
|
|
||||||
half* C_ptr = C
|
half* C_ptr = C
|
||||||
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
|
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
|
||||||
+ (((int)blockIdx_y) % j_factors1) * 128
|
+ (((int)blockIdx_y) % j_factors1) * N
|
||||||
+ ((int)threadIdx.y) * 64
|
+ ((int)threadIdx.y) * (N / 2)
|
||||||
+ (((int)threadIdx.x) % 4) * 2;
|
+ (((int)threadIdx.x) % 4) * 2;
|
||||||
|
|
||||||
// preload s.f. and zeros
|
// preload s.f. and zeros
|
||||||
@@ -123,7 +136,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
|
|||||||
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
|
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
|
||||||
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
|
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
|
||||||
|
|
||||||
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
|
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) {
|
||||||
|
|
||||||
// B: 32 x 136 (128+8) float16
|
// B: 32 x 136 (128+8) float16
|
||||||
// each warp: 32 x 4
|
// each warp: 32 x 4
|
||||||
@@ -152,7 +165,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
// write back
|
// write back
|
||||||
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_fp16;
|
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16;
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
@@ -174,13 +187,13 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) {
|
for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
|
||||||
{
|
{
|
||||||
unsigned int addr;
|
unsigned int addr;
|
||||||
__asm__ __volatile__(
|
__asm__ __volatile__(
|
||||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||||
: "=r"(addr)
|
: "=r"(addr)
|
||||||
: "l"((void *)((&(B_shared[(((k_0_1 * 2176) + (((int)threadIdx.y) * 64)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8))))
|
: "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||||
);
|
);
|
||||||
__asm__ __volatile__(
|
__asm__ __volatile__(
|
||||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
|
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
|
||||||
@@ -190,7 +203,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) {
|
for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||||
{
|
{
|
||||||
__asm__ __volatile__(
|
__asm__ __volatile__(
|
||||||
@@ -258,244 +271,115 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__global__ void __launch_bounds__(64) dequantize_weights(
|
||||||
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
|
int* __restrict__ B,
|
||||||
|
half* __restrict__ scaling_factors,
|
||||||
|
int* __restrict__ zeros,
|
||||||
|
half* __restrict__ C,
|
||||||
|
int G
|
||||||
|
)
|
||||||
{
|
{
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
int j_factors1 = 4;
|
||||||
assert(false);
|
int row_stride2 = 4;
|
||||||
#else
|
int split_k_iters = 1;
|
||||||
static constexpr uint32_t ZERO = 0x0;
|
static constexpr uint32_t ZERO = 0x0;
|
||||||
float C_warp[32];
|
half B_shared[32 * (128 + 8)];
|
||||||
__shared__ half A_shared[16 * (32 + 8)];
|
|
||||||
__shared__ half B_shared[32 * (64 + 8)];
|
|
||||||
|
|
||||||
__shared__ half scaling_factors_shared[64];
|
half* B_shared_ptr2 = B_shared;
|
||||||
__shared__ half zeros_shared[64];
|
|
||||||
|
|
||||||
int j_factors1 = ((OC + 64 - 1) / 64);
|
half B_shared_warp[32];
|
||||||
|
int OC = 512;
|
||||||
|
|
||||||
int blockIdx_x = 0;
|
int N = blockDim.x * gridDim.x; // 2
|
||||||
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
|
int col = (blockIdx.x * blockDim.x + threadIdx.x);
|
||||||
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
|
int row = blockIdx.y * blockDim.y + threadIdx.y;
|
||||||
|
int index1 = 8 * col + 8 * row * N;
|
||||||
|
half* C_ptr2 = C + index1;
|
||||||
|
|
||||||
half A_shared_warp[8];
|
int index2 = col + row * N;
|
||||||
half B_shared_warp[16];
|
int* B_ptr2 = B + index2;
|
||||||
for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
|
|
||||||
for (int i = 0; i < 8; ++i) {
|
int index3 = col + (int)(row / G) * N;
|
||||||
C_warp[(j_0_4_init * 8) + i] = 0.0;
|
int* zeros_ptr2 = zeros + index3;
|
||||||
}
|
int index4 = 8 * col + (int)(row / G) * N * 8;
|
||||||
|
half* scaling_factors_ptr2 = scaling_factors + index4;
|
||||||
|
|
||||||
|
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2);
|
||||||
|
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
|
||||||
|
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2);
|
||||||
|
|
||||||
|
uint32_t B_loaded = *(uint32_t*)B_ptr2;
|
||||||
|
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
||||||
|
|
||||||
|
*(uint4*)B_shared_ptr2 = B_loaded_fp16;
|
||||||
|
|
||||||
|
for (int i = 0; i < 8; ++i) {
|
||||||
|
*(C_ptr2 + i) = B_shared[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
static constexpr int row_stride_warp = 32 * 8 / 32;
|
|
||||||
static constexpr int row_stride = 2 * 32 * 8 / 64;
|
|
||||||
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64;
|
|
||||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
|
||||||
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
|
|
||||||
// bool wb_C_flag = (threadIdx.x / 4) < M;
|
|
||||||
|
|
||||||
half* A_ptr = A
|
|
||||||
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
|
|
||||||
+ (((int)threadIdx.x) % (32 / 8)) * 8;
|
|
||||||
|
|
||||||
int* B_ptr = B
|
|
||||||
+ ((int)threadIdx.y) * (OC / 8) * 4
|
|
||||||
+ (((int)threadIdx.x) / (64 / 8)) * (OC / 8)
|
|
||||||
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
|
|
||||||
+ (((int)threadIdx.x) % (64 / 8)) * 1;
|
|
||||||
// Why * 1 in the above line?
|
|
||||||
|
|
||||||
half* A_shared_ptr = A_shared
|
|
||||||
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
|
|
||||||
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
|
|
||||||
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
|
|
||||||
|
|
||||||
half* B_shared_ptr = B_shared
|
|
||||||
+ ((int)threadIdx.y) * (row_stride / 2) * (64 + 8)
|
|
||||||
+ (((int)threadIdx.x) / (64 / 8)) * (64 + 8)
|
|
||||||
+ (((int)threadIdx.x) % (64 / 8)) * 8;
|
|
||||||
|
|
||||||
int* zeros_ptr = zeros
|
|
||||||
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
|
|
||||||
+ ((int)threadIdx.x) % (64 / 8);
|
|
||||||
|
|
||||||
half* scaling_factors_ptr = scaling_factors
|
|
||||||
+ (((int)blockIdx_y) % j_factors1) * (64)
|
|
||||||
+ (((int)threadIdx.x) % (64 / 8)) * 8;
|
|
||||||
|
|
||||||
half* C_ptr = C
|
|
||||||
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
|
|
||||||
+ (((int)blockIdx_y) % j_factors1) * 64
|
|
||||||
+ ((int)threadIdx.y) * 32
|
|
||||||
+ (((int)threadIdx.x) % 4) * 2;
|
|
||||||
|
|
||||||
// preload s.f. and zeros
|
|
||||||
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
|
|
||||||
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
|
|
||||||
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
|
|
||||||
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
|
|
||||||
__syncthreads();
|
|
||||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
|
||||||
if (ld_A_flag)
|
|
||||||
{
|
|
||||||
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
|
|
||||||
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
|
|
||||||
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
|
|
||||||
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
|
|
||||||
/*
|
|
||||||
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
|
|
||||||
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
|
|
||||||
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
|
|
||||||
|
|
||||||
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) {
|
|
||||||
|
|
||||||
// B: 32 x 136 (128+8) float16
|
|
||||||
// each warp: 32 x 4
|
|
||||||
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
|
|
||||||
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
|
|
||||||
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
|
|
||||||
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
|
|
||||||
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
|
||||||
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
|
||||||
|
|
||||||
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
|
||||||
// - zero and * scale
|
|
||||||
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
|
|
||||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
|
||||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
|
||||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
|
||||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
|
||||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
|
||||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
|
||||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
|
||||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
|
||||||
/*
|
|
||||||
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
|
|
||||||
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
// write back
|
|
||||||
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16;
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1)
|
|
||||||
{
|
|
||||||
{
|
|
||||||
unsigned int addr;
|
|
||||||
__asm__ __volatile__(
|
|
||||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
|
||||||
: "=r"(addr)
|
|
||||||
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
|
|
||||||
);
|
|
||||||
__asm__ __volatile__(
|
|
||||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
|
|
||||||
"{%0, %1, %2, %3}, [%4];\n"
|
|
||||||
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
|
|
||||||
: "r"(addr)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0)
|
|
||||||
{
|
|
||||||
{
|
|
||||||
unsigned int addr;
|
|
||||||
__asm__ __volatile__(
|
|
||||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
|
||||||
: "=r"(addr)
|
|
||||||
: "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8))))
|
|
||||||
);
|
|
||||||
__asm__ __volatile__(
|
|
||||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
|
|
||||||
"{%0, %1, %2, %3}, [%4];\n"
|
|
||||||
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
|
|
||||||
: "r"(addr)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4)
|
|
||||||
{
|
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
|
||||||
{
|
|
||||||
__asm__ __volatile__(
|
|
||||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
|
||||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
|
||||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
|
||||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
__asm__ __volatile__(
|
|
||||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
|
||||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
|
||||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
|
||||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
__asm__ __volatile__(
|
|
||||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
|
||||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
|
||||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
|
||||||
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
__asm__ __volatile__(
|
|
||||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
|
||||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
|
||||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
|
||||||
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
{
|
|
||||||
__asm__ __volatile__(
|
|
||||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
|
||||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
|
||||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
|
||||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
__asm__ __volatile__(
|
|
||||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
|
||||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
|
||||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
|
||||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Shang: Hoist loop invariance.
|
|
||||||
for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) {
|
|
||||||
for (int local_id = 0; local_id < 8; ++local_id) {
|
|
||||||
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
|
|
||||||
if (row_offset < M)
|
|
||||||
{
|
|
||||||
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace awq
|
} // namespace awq
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
|
torch::Tensor awq_dequantize(
|
||||||
|
torch::Tensor _kernel,
|
||||||
|
torch::Tensor _scaling_factors,
|
||||||
|
torch::Tensor _zeros,
|
||||||
|
int split_k_iters,
|
||||||
|
int thx,
|
||||||
|
int thy)
|
||||||
|
{
|
||||||
|
int in_c = _kernel.size(0);
|
||||||
|
int qout_c = _kernel.size(1);
|
||||||
|
int out_c = qout_c * 8;
|
||||||
|
int G = in_c / _scaling_factors.size(0);
|
||||||
|
|
||||||
|
int x_thread = thx;
|
||||||
|
int y_thread = thy;
|
||||||
|
|
||||||
|
int x_blocks = 1;
|
||||||
|
int y_blocks = 1;
|
||||||
|
if (thx==0) {
|
||||||
|
x_thread = qout_c;
|
||||||
|
}
|
||||||
|
if (thy==0) {
|
||||||
|
y_thread = in_c;
|
||||||
|
}
|
||||||
|
if (thx==0 && thy==0) {
|
||||||
|
x_thread = 8;
|
||||||
|
y_thread = 8;
|
||||||
|
x_blocks = (int)(qout_c / 8);
|
||||||
|
y_blocks = (int)(in_c / 8);
|
||||||
|
}
|
||||||
|
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));
|
||||||
|
|
||||||
|
auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device());
|
||||||
|
at::Tensor _de_kernel = torch::empty({in_c, out_c}, options);
|
||||||
|
|
||||||
|
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
|
||||||
|
auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
|
||||||
|
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
|
||||||
|
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
|
||||||
|
|
||||||
|
dim3 num_blocks(x_blocks, y_blocks);
|
||||||
|
dim3 threads_per_block(x_thread, y_thread);
|
||||||
|
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
vllm::awq::dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
|
||||||
|
kernel, scaling_factors, zeros, de_kernel, G);
|
||||||
|
|
||||||
|
return _de_kernel;
|
||||||
|
}
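The arithmetic performed by `dequantize_weights` above (subtract the zero point, then multiply by the scale, with zeros and scales looked up per group of `G` input rows) can be sketched on the host as follows. This is only an illustrative reference, not the kernel's code path: the names `dequantize_word` and `group_index` are hypothetical, the math is done in `float` rather than `half`, and the sequential nibble order is an assumption, since the ordering used by `dequantize_s4_to_fp16x2` is not shown in this diff.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical host-side reference for one packed 32-bit word (8 x 4-bit values).
// ASSUMPTION: nibble i of the word holds element i (plain sequential order).
// The real dequantize_s4_to_fp16x2 may use a different, interleaved order.
inline std::array<float, 8> dequantize_word(uint32_t packed_weights,
                                            uint32_t packed_zeros,
                                            const std::array<float, 8>& scales) {
  std::array<float, 8> out{};
  for (int i = 0; i < 8; ++i) {
    const float q = static_cast<float>((packed_weights >> (4 * i)) & 0xF);
    const float z = static_cast<float>((packed_zeros >> (4 * i)) & 0xF);
    // Same order of operations as the sub.f16x2 / fma.rn.f16x2 pair:
    // (q - zero) * scale + 0
    out[i] = (q - z) * scales[i];
  }
  return out;
}

// Zeros and scales are shared by every G consecutive input rows,
// which is what the (row / G) term in index3 and index4 expresses.
inline int group_index(int row, int group_size) {
  assert(group_size > 0);
  return row / group_size;
}
```

A caller would unpack one weight word and the matching zero word for the same column block, pass the eight scales for that block, and write the eight results to consecutive output columns, mirroring the `8 * col` indexing in the kernel.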
|
||||||
|
|
||||||
// in_feats: M, IC [float16]
|
// in_feats: M, IC [float16]
|
||||||
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
|
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
|
||||||
// scaling_factors: IC // G, OC [float16]
|
// scaling_factors: IC // G, OC [float16]
|
||||||
@@ -542,8 +426,9 @@ torch::Tensor awq_gemm(
|
|||||||
// threadIdx.x: 32
|
// threadIdx.x: 32
|
||||||
// threadIdx.y: i_factors[2] * j_factors[2]
|
// threadIdx.y: i_factors[2] * j_factors[2]
|
||||||
dim3 threads_per_block(32, 2);
|
dim3 threads_per_block(32, 2);
|
||||||
vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block, 0, stream>>>(
|
vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128><<<num_blocks, threads_per_block, 0, stream>>>(
|
||||||
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
|
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels,
|
||||||
|
num_out_channels, out_feats);
|
||||||
}
|
}
|
||||||
else if (num_out_channels % 64 == 0)
|
else if (num_out_channels % 64 == 0)
|
||||||
{
|
{
|
||||||
@@ -553,8 +438,9 @@ torch::Tensor awq_gemm(
|
|||||||
// threadIdx.x: 32
|
// threadIdx.x: 32
|
||||||
// threadIdx.y: i_factors[2] * j_factors[2]
|
// threadIdx.y: i_factors[2] * j_factors[2]
|
||||||
dim3 threads_per_block(32, 2);
|
dim3 threads_per_block(32, 2);
|
||||||
vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block, 0, stream>>>(
|
vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64><<<num_blocks, threads_per_block, 0, stream>>>(
|
||||||
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
|
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels,
|
||||||
|
num_out_channels, out_feats);
|
||||||
}
|
}
|
||||||
return _out_feats.sum(0);
|
return _out_feats.sum(0);
|
||||||
}
|
}
|
||||||
|
|||||||
277
csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh
Normal file
@@ -0,0 +1,277 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <float.h>
|
||||||
|
#include <type_traits>
|
||||||
|
#include "../../attention/attention_dtypes.h"
|
||||||
|
#include "../../attention/dtype_float32.cuh"
|
||||||
|
#include "../../attention/dtype_float16.cuh"
|
||||||
|
#include "../../attention/dtype_bfloat16.cuh"
|
||||||
|
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
#ifdef ENABLE_FP8_E5M2
|
||||||
|
namespace fp8_e5m2_unscaled {
|
||||||
|
|
||||||
|
template<typename Tout, typename Tin>
|
||||||
|
__inline__ __device__ Tout vec_conversion(const Tin& x)
|
||||||
|
{
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8 -> half
|
||||||
|
template<>
|
||||||
|
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a)
|
||||||
|
{
|
||||||
|
__half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
|
||||||
|
return res.x;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x2 -> half2
|
||||||
|
template<>
|
||||||
|
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a)
|
||||||
|
{
|
||||||
|
union {
|
||||||
|
uint16_t u16[2];
|
||||||
|
uint32_t u32;
|
||||||
|
} tmp;
|
||||||
|
__half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, __NV_E5M2);
|
||||||
|
tmp.u16[0] = res.x;
|
||||||
|
tmp.u16[1] = res.y;
|
||||||
|
return tmp.u32;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x4 -> half2x2
|
||||||
|
template<>
|
||||||
|
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
|
||||||
|
{
|
||||||
|
union {
|
||||||
|
uint2 u32x2;
|
||||||
|
uint32_t u32[2];
|
||||||
|
} tmp;
|
||||||
|
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
|
||||||
|
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
|
||||||
|
return tmp.u32x2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x8 -> half2x4
|
||||||
|
template<>
|
||||||
|
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a)
|
||||||
|
{
|
||||||
|
union {
|
||||||
|
uint4 u64x2;
|
||||||
|
uint2 u64[2];
|
||||||
|
} tmp;
|
||||||
|
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
|
||||||
|
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
|
||||||
|
return tmp.u64x2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8 -> __nv_bfloat16
|
||||||
|
template<>
|
||||||
|
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a)
|
||||||
|
{
|
||||||
|
// Note there is no direct convert function from fp8 to bf16.
|
||||||
|
// fp8 -> half
|
||||||
|
__half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
|
||||||
|
// half -> float -> bf16
|
||||||
|
float tmp = half_to_float(res.x);
|
||||||
|
return __float2bfloat16(tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x2 -> __nv_bfloat162
|
||||||
|
template<>
|
||||||
|
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a)
|
||||||
|
{
|
||||||
|
__nv_bfloat162 res;
|
||||||
|
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
|
||||||
|
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x4 -> bf16_4_t
|
||||||
|
template<>
|
||||||
|
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a)
|
||||||
|
{
|
||||||
|
bf16_4_t res;
|
||||||
|
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
|
||||||
|
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x8 -> bf16_8_t
|
||||||
|
template<>
|
||||||
|
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
|
||||||
|
{
|
||||||
|
bf16_4_t tmp1, tmp2;
|
||||||
|
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
|
||||||
|
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
|
||||||
|
bf16_8_t res;
|
||||||
|
res.x = tmp1.x;
|
||||||
|
res.y = tmp1.y;
|
||||||
|
res.z = tmp2.x;
|
||||||
|
res.w = tmp2.y;
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8 -> float
|
||||||
|
template<>
|
||||||
|
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a)
|
||||||
|
{
|
||||||
|
// fp8 -> half
|
||||||
|
uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a);
|
||||||
|
// half -> float
|
||||||
|
return half_to_float(tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x2 -> float2
|
||||||
|
template<>
|
||||||
|
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
|
||||||
|
{
|
||||||
|
// fp8x2 -> half2
|
||||||
|
uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a);
|
||||||
|
// half2 -> float2
|
||||||
|
return half2_to_float2(tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x4 -> float4
|
||||||
|
template<>
|
||||||
|
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a)
|
||||||
|
{
|
||||||
|
Float4_ res;
|
||||||
|
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
|
||||||
|
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x8 -> float8
|
||||||
|
template<>
|
||||||
|
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
|
||||||
|
{
|
||||||
|
Float4_ tmp1, tmp2;
|
||||||
|
tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
|
||||||
|
tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
|
||||||
|
Float8_ res;
|
||||||
|
res.x = tmp1.x;
|
||||||
|
res.y = tmp1.y;
|
||||||
|
res.z = tmp2.x;
|
||||||
|
res.w = tmp2.y;
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// half -> fp8
|
||||||
|
template<>
|
||||||
|
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a)
|
||||||
|
{
|
||||||
|
__half_raw tmp;
|
||||||
|
tmp.x = a;
|
||||||
|
__nv_fp8_storage_t res = __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, __NV_E5M2);
|
||||||
|
return (uint8_t)res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// bf16 -> fp8
|
||||||
|
template<>
|
||||||
|
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a)
|
||||||
|
{
|
||||||
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||||
|
assert(false);
|
||||||
|
#else
|
||||||
|
__nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(__nv_bfloat16_raw(a), __NV_SATFINITE, __NV_E5M2);
|
||||||
|
return (uint8_t)res;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// float -> fp8
|
||||||
|
template<>
|
||||||
|
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a)
|
||||||
|
{
|
||||||
|
__nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E5M2);
|
||||||
|
return (uint8_t)res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x4 -> float4
|
||||||
|
template<>
|
||||||
|
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
|
||||||
|
{
|
||||||
|
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
|
||||||
|
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// float2 -> half2
template<>
|
||||||
|
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
|
||||||
|
{
|
||||||
|
union {
|
||||||
|
half2 float16;
|
||||||
|
uint32_t uint32;
|
||||||
|
};
|
||||||
|
|
||||||
|
float16 = __float22half2_rn(a);
|
||||||
|
return uint32;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Float4_ -> half2x2
template<>
|
||||||
|
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
|
||||||
|
{
|
||||||
|
uint2 b;
|
||||||
|
float2 val;
|
||||||
|
val.x = a.x.x;
|
||||||
|
val.y = a.x.y;
|
||||||
|
b.x = vec_conversion<uint32_t, float2>(val);
|
||||||
|
|
||||||
|
val.x = a.y.x;
|
||||||
|
val.y = a.y.y;
|
||||||
|
b.y = vec_conversion<uint32_t, float2>(val);
|
||||||
|
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Float4_ -> float4
template<>
|
||||||
|
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
|
||||||
|
{
|
||||||
|
float4 b;
|
||||||
|
b.x = a.x.x;
|
||||||
|
b.y = a.x.y;
|
||||||
|
b.z = a.y.x;
|
||||||
|
b.w = a.y.y;
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Float8_ -> half2x4
template<>
|
||||||
|
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
|
||||||
|
{
|
||||||
|
uint4 b;
|
||||||
|
b.x = vec_conversion<uint32_t, float2>(a.x);
|
||||||
|
b.y = vec_conversion<uint32_t, float2>(a.y);
|
||||||
|
b.z = vec_conversion<uint32_t, float2>(a.z);
|
||||||
|
b.w = vec_conversion<uint32_t, float2>(a.w);
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
|
||||||
|
// float2 -> __nv_bfloat162
template<>
|
||||||
|
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) {
|
||||||
|
__nv_bfloat162 b;
|
||||||
|
from_float(b, a);
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Float4_ -> bf16_4_t
template<>
|
||||||
|
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_ &a) {
|
||||||
|
bf16_4_t b;
|
||||||
|
from_float(b, a);
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Float8_ -> bf16_8_t
template<>
|
||||||
|
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_ &a) {
|
||||||
|
bf16_8_t b;
|
||||||
|
from_float(b, a);
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace fp8_e5m2_unscaled
|
||||||
|
#endif // ENABLE_FP8_E5M2
|
||||||
|
} // namespace vllm
|
||||||
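The fp8 -> half specializations in quant_utils.cuh above go through the `__nv_cvt_fp8_to_halfraw` intrinsics. E5M2 has the same sign bit and 5-bit exponent as IEEE half precision and keeps only the top 2 mantissa bits, so widening can be pictured as a left shift by 8. The host-side C++ sketch below illustrates that bit layout only. All function names are hypothetical, it does not call the CUDA intrinsics, and it assumes the low byte of a packed pair is element 0, matching the `(uint8_t)a` / `a >> 8U` casts in the bf16 specializations.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical host-side helper (not part of the diff): widen one FP8 E5M2
// value to the bit pattern of an IEEE half. E5M2 is 1 sign + 5 exponent +
// 2 mantissa bits, i.e. the upper byte of a half, so widening is a shift.
inline uint16_t fp8_e5m2_to_half_bits(uint8_t a)
{
    return static_cast<uint16_t>(static_cast<uint16_t>(a) << 8);
}

// Hypothetical helper: widen two packed FP8 values to two packed half patterns.
// Assumption: the low byte holds element 0 and lands in the low 16 bits.
inline uint32_t fp8x2_e5m2_to_half2_bits(uint16_t a)
{
    const uint32_t lo = fp8_e5m2_to_half_bits(static_cast<uint8_t>(a));
    const uint32_t hi = fp8_e5m2_to_half_bits(static_cast<uint8_t>(a >> 8));
    return lo | (hi << 16);
}

// Hypothetical helper: narrow by plain truncation. The kernels instead call
// __nv_cvt_halfraw_to_fp8 with __NV_SATFINITE, which rounds and saturates,
// so results can differ whenever the dropped mantissa bits are non-zero.
inline uint8_t half_bits_to_fp8_e5m2_truncate(uint16_t h)
{
    return static_cast<uint8_t>(h >> 8);
}
```

For example, the E5M2 byte 0x3C (1.0) widens to the half pattern 0x3C00 (also 1.0). The truncating narrow is only a stand-in for the rounding intrinsic and should not be read as its exact behavior.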
64
csrc/quantization/gptq/compat.cuh
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
/*
|
||||||
|
Copied from https://github.com/turboderp/exllamav2
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef _compat_cuh
|
||||||
|
#define _compat_cuh
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
namespace gptq {
|
||||||
|
// atomicAdd for half types, to support CC < 7.x
|
||||||
|
|
||||||
|
__device__ __forceinline__ void atomicAdd_half(half* address, half val)
|
||||||
|
{
|
||||||
|
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
|
||||||
|
unsigned int old = *address_as_ui;
|
||||||
|
unsigned int assumed;
|
||||||
|
|
||||||
|
do
|
||||||
|
{
|
||||||
|
assumed = old;
|
||||||
|
__half_raw hsum;
|
||||||
|
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
|
||||||
|
half tmpres = __hadd(hsum, val);
|
||||||
|
hsum = __half_raw(tmpres);
|
||||||
|
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
|
||||||
|
old = atomicCAS(address_as_ui, assumed, old);
|
||||||
|
}
|
||||||
|
while (assumed != old);
|
||||||
|
}
|
||||||
|
|
||||||
|
// atomicAdd for half2 types
|
||||||
|
|
||||||
|
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
|
||||||
|
{
|
||||||
|
unsigned int* address_as_ui = (unsigned int*)address;
|
||||||
|
unsigned int old = *address_as_ui;
|
||||||
|
unsigned int assumed;
|
||||||
|
do
|
||||||
|
{
|
||||||
|
assumed = old;
|
||||||
|
half2 old_val = *((half2*)&old);
|
||||||
|
half2 new_val = __hadd2(old_val, val);
|
||||||
|
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
|
||||||
|
}
|
||||||
|
while (assumed != old);
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
|
||||||
|
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
|
||||||
|
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
|
||||||
|
|
||||||
|
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
|
||||||
|
|
||||||
|
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
|
||||||
|
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
|
} // namespace gptq
|
||||||
|
} // namespace vllm
|
||||||
|
#endif
|
||||||
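The `atomicAdd_half` fallback in compat.cuh above updates one 16-bit value by running a compare-and-swap loop on the aligned 32-bit word that contains it. A minimal host-side C++ sketch of the same pattern follows. It is an illustration under stated assumptions, not the device code: `atomic_add_u16_lane` is a hypothetical name, `std::atomic` stands in for `atomicCAS`, and wrapping integer addition stands in for `__hadd`.

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>

// Hypothetical host-side stand-in: add `val` to one 16-bit lane of a 32-bit
// word using a compare-and-swap loop. Integer addition replaces __hadd here.
inline void atomic_add_u16_lane(std::atomic<uint32_t>& word, bool high_lane, uint16_t val)
{
    uint32_t assumed = word.load();
    uint32_t desired = 0;
    do
    {
        // Extract the lane being updated from the value we last observed.
        const uint16_t lane = high_lane ? static_cast<uint16_t>(assumed >> 16)
                                        : static_cast<uint16_t>(assumed & 0xffffu);
        const uint16_t sum = static_cast<uint16_t>(lane + val);
        // Re-insert the updated lane while leaving the other lane untouched.
        desired = high_lane ? ((assumed & 0x0000ffffu) | (static_cast<uint32_t>(sum) << 16))
                            : ((assumed & 0xffff0000u) | sum);
        // On failure compare_exchange_weak reloads `assumed`, and the loop retries.
    } while (!word.compare_exchange_weak(assumed, desired));
}
```

The retry loop is what makes the read-modify-write safe when another thread changes either lane of the same word between the load and the swap.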
274
csrc/quantization/gptq/matrix_view.cuh
Normal file
@@ -0,0 +1,274 @@
|
|||||||
|
/*
|
||||||
|
Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef _matrix_view_cuh
|
||||||
|
#define _matrix_view_cuh
|
||||||
|
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
|
||||||
|
#include "qdq_util.cuh"
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
namespace gptq {
|
||||||
|
|
||||||
|
class MatrixView_half
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
const half* data;
|
||||||
|
const int height;
|
||||||
|
const int width;
|
||||||
|
|
||||||
|
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
|
||||||
|
: data(data), height(height), width(width)
|
||||||
|
{ }
|
||||||
|
|
||||||
|
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
|
||||||
|
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
|
||||||
|
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
|
||||||
|
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
|
||||||
|
|
||||||
|
__device__ __forceinline__ void item4(half (&items)[4], int row, int column) const
|
||||||
|
{
|
||||||
|
half2* ptr = (half2*) item_ptr(row, column);
|
||||||
|
half2 i01 = ptr[0];
|
||||||
|
half2 i23 = ptr[1];
|
||||||
|
items[0] = __low2half(i01);
|
||||||
|
items[1] = __high2half(i01);
|
||||||
|
items[2] = __low2half(i23);
|
||||||
|
items[3] = __high2half(i23);
|
||||||
|
}
|
||||||
|
__device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const
|
||||||
|
{
|
||||||
|
half2* ptr = (half2*)item_ptr(row, column);
|
||||||
|
half2 i01 = ptr[0];
|
||||||
|
half2 i23 = ptr[1];
|
||||||
|
items[0] = __half2float(__low2half(i01));
|
||||||
|
items[1] = __half2float(__high2half(i01));
|
||||||
|
items[2] = __half2float(__low2half(i23));
|
||||||
|
items[3] = __half2float(__high2half(i23));
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const
|
||||||
|
{
|
||||||
|
half2* ptr = (half2*)item_ptr(row, column);
|
||||||
|
half2 i01 = ptr[0];
|
||||||
|
half2 i23 = ptr[1];
|
||||||
|
items[0] = __half2half2(__low2half(i01));
|
||||||
|
items[1] = __half2half2(__high2half(i01));
|
||||||
|
items[2] = __half2half2(__low2half(i23));
|
||||||
|
items[3] = __half2half2(__high2half(i23));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class MatrixView_half_rw
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
half* data;
|
||||||
|
const int height;
|
||||||
|
const int width;
|
||||||
|
|
||||||
|
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
|
||||||
|
: data(data), height(height), width(width)
|
||||||
|
{ }
|
||||||
|
|
||||||
|
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
|
||||||
|
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
|
||||||
|
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
|
||||||
|
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
|
||||||
|
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
|
||||||
|
|
||||||
|
__device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3)
|
||||||
|
{
|
||||||
|
half2 v01 = __halves2half2(v0, v1);
|
||||||
|
half2 v23 = __halves2half2(v2, v3);
|
||||||
|
half2* ptr = (half2*) item_ptr(row, column);
|
||||||
|
ptr[0] = v01;
|
||||||
|
ptr[1] = v23;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class MatrixView_q4_row
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
const uint32_t* data;
|
||||||
|
const int height;
|
||||||
|
const int width;
|
||||||
|
|
||||||
|
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
|
||||||
|
: data(data), height(height), width(width)
|
||||||
|
{ }
|
||||||
|
|
||||||
|
__device__ __forceinline__ int item(int row, int column) const
|
||||||
|
{
|
||||||
|
int shift = (column & 0x07) * 4;
|
||||||
|
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
|
||||||
|
{
|
||||||
|
int shift = (column & 0x07) * 4;
|
||||||
|
uint32_t d = data[row * width / 8 + column / 8] >> shift;
|
||||||
|
items[0] = d & 0x0f;
|
||||||
|
items[1] = (d >> 4) & 0x0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
|
||||||
|
{
|
||||||
|
int shift = (column & 0x07) * 4;
|
||||||
|
uint32_t d = data[row * width / 8 + column / 8] >> shift;
|
||||||
|
items[0] = d & 0x0f;
|
||||||
|
items[1] = (d >> 4) & 0x0f;
|
||||||
|
items[2] = (d >> 8) & 0x0f;
|
||||||
|
items[3] = (d >> 12) & 0x0f;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class MatrixView_q4_column
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
const uint32_t* data;
|
||||||
|
const int height;
|
||||||
|
const int width;
|
||||||
|
|
||||||
|
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
|
||||||
|
: data(data), height(height), width(width)
|
||||||
|
{ }
|
||||||
|
|
||||||
|
__device__ __forceinline__ int item(int row, int column) const
|
||||||
|
{
|
||||||
|
int shift = (row & 0x07) * 4;
|
||||||
|
return (data[row / 8 * width + column] >> shift) & 0x0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
|
||||||
|
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
|
||||||
|
};
|
||||||
|
|
||||||
|
class MatrixView_q2_row
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
const uint32_t* data;
|
||||||
|
const int height;
|
||||||
|
const int width;
|
||||||
|
|
||||||
|
__device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, const int height, const int width)
|
||||||
|
: data(data), height(height), width(width)
|
||||||
|
{ }
|
||||||
|
|
||||||
|
__device__ __forceinline__ int item(int row, int column) const
|
||||||
|
{
|
||||||
|
int shift = (column & 0x0f) * 2;
|
||||||
|
return (data[row * width / 16 + column / 16] >> shift) & 0x03;
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
|
||||||
|
{
|
||||||
|
int shift = (column & 0x0f) * 2;
|
||||||
|
uint32_t d = data[row * width / 16 + column / 16] >> shift;
|
||||||
|
items[0] = d & 0x03;
|
||||||
|
items[1] = (d >> 2) & 0x03;
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
|
||||||
|
{
|
||||||
|
int shift = (column & 0x0f) * 2;
|
||||||
|
uint32_t d = data[row * width / 16 + column / 16] >> shift;
|
||||||
|
items[0] = d & 0x03;
|
||||||
|
items[1] = (d >> 2) & 0x03;
|
||||||
|
items[2] = (d >> 4) & 0x03;
|
||||||
|
items[3] = (d >> 6) & 0x03;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class MatrixView_q3_row
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
const uint32_t* data;
|
||||||
|
const int height;
|
||||||
|
const int width;
|
||||||
|
|
||||||
|
__device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, const int height, const int width)
|
||||||
|
: data(data), height(height), width(width)
|
||||||
|
{ }
|
||||||
|
|
||||||
|
__device__ __forceinline__ int item(int row, int column) const
|
||||||
|
{
|
||||||
|
int z_w = column * 3 / 32;
|
||||||
|
int z_mod = column & 0x1f;
|
||||||
|
|
||||||
|
if (z_mod == 10) {
|
||||||
|
return (data[row * width * 3 / 32 + z_w] >> 30) | ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4);
|
||||||
|
} else if (z_mod == 21) {
|
||||||
|
return (data[row * width * 3 / 32 + z_w] >> 31) | ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6);
|
||||||
|
} else if (z_mod < 10) {
|
||||||
|
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07;
|
||||||
|
} else if (z_mod < 21) {
|
||||||
|
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07;
|
||||||
|
} else {
|
||||||
|
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
|
||||||
|
{
|
||||||
|
int shift = (column & 0x1f);
|
||||||
|
uint32_t d;
|
||||||
|
if (shift <= 4) {
|
||||||
|
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3);
|
||||||
|
} else if (shift == 8) {
|
||||||
|
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8);
|
||||||
|
} else if (shift <= 16) {
|
||||||
|
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32);
|
||||||
|
} else if (shift == 20) {
|
||||||
|
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4);
|
||||||
|
} else {
|
||||||
|
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64);
|
||||||
|
}
|
||||||
|
items[0] = d & 0x07;
|
||||||
|
items[1] = (d >> 3) & 0x07;
|
||||||
|
items[2] = (d >> 6) & 0x07;
|
||||||
|
items[3] = (d >> 9) & 0x07;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class MatrixView_q8_row
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
const uint32_t* data;
|
||||||
|
const int height;
|
||||||
|
const int width;
|
||||||
|
|
||||||
|
__device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, const int height, const int width)
|
||||||
|
: data(data), height(height), width(width)
|
||||||
|
{ }
|
||||||
|
|
||||||
|
__device__ __forceinline__ int item(int row, int column) const
|
||||||
|
{
|
||||||
|
int shift = (column & 0x03) * 8;
|
||||||
|
return (data[row * width / 4 + column / 4] >> shift) & 0xff;
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
|
||||||
|
{
|
||||||
|
int shift = (column & 0x03) * 8;
|
||||||
|
uint32_t d = data[row * width / 4 + column / 4] >> shift;
|
||||||
|
items[0] = d & 0xff;
|
||||||
|
items[1] = (d >> 8) & 0xff;
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
|
||||||
|
{
|
||||||
|
int shift = (column & 0x03) * 8;
|
||||||
|
uint32_t d = data[row * width / 4 + column / 4] >> shift;
|
||||||
|
items[0] = d & 0xff;
|
||||||
|
items[1] = (d >> 8) & 0xff;
|
||||||
|
items[2] = (d >> 16) & 0xff;
|
||||||
|
items[3] = (d >> 24) & 0xff;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace gptq
|
||||||
|
} // namespace vllm
|
||||||
|
#endif
|
||||||
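The `MatrixView_q4_row` class in matrix_view.cuh above reads 4-bit values that are packed eight to a 32-bit word, with the column index selecting the nibble. The host-side C++ sketch below mirrors that indexing so it can be checked without a GPU. `HostQ4RowView` and `pack_q4_rows` are hypothetical names introduced for illustration, and the packer assumes the width is a multiple of 8.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side mirror of the 4-bit row view: eight values per
// 32-bit word, with the first value of each group in the lowest nibble.
struct HostQ4RowView
{
    const uint32_t* data;
    int height;
    int width;

    int item(int row, int column) const
    {
        assert(row >= 0 && row < height && column >= 0 && column < width);
        const int shift = (column & 0x07) * 4;  // nibble position inside the word
        return static_cast<int>((data[row * width / 8 + column / 8] >> shift) & 0x0f);
    }
};

// Hypothetical packer, used only to build input for the view above.
inline std::vector<uint32_t> pack_q4_rows(const std::vector<int>& values, int height, int width)
{
    assert(width % 8 == 0);
    assert(values.size() == static_cast<std::size_t>(height) * static_cast<std::size_t>(width));
    std::vector<uint32_t> packed(values.size() / 8, 0u);
    for (int row = 0; row < height; ++row)
    {
        for (int column = 0; column < width; ++column)
        {
            const uint32_t nibble = static_cast<uint32_t>(values[row * width + column]) & 0x0fu;
            packed[row * width / 8 + column / 8] |= nibble << ((column & 0x07) * 4);
        }
    }
    return packed;
}
```

Packing the values 0 through 7 into one row yields the word 0x76543210, which makes the nibble order easy to see when debugging.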
2075
csrc/quantization/gptq/q_gemm.cu
Normal file
File diff suppressed because it is too large
87
csrc/quantization/gptq/qdq_2.cuh
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
/*
|
||||||
|
Copied from https://github.com/turboderp/exllamav2
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef _qdq_2_cuh
|
||||||
|
#define _qdq_2_cuh
|
||||||
|
|
||||||
|
#include "qdq_util.cuh"
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
namespace gptq {
|
||||||
|
|
||||||
|
// Permutation:
|
||||||
|
//
|
||||||
|
// ffddbb99 77553311 eeccaa88 66442200
|
||||||
|
|
||||||
|
__forceinline__ __device__ void shuffle_2bit_16
|
||||||
|
(
|
||||||
|
uint32_t* q,
|
||||||
|
int stride
|
||||||
|
)
|
||||||
|
{
|
||||||
|
uint32_t qa = q[0];
|
||||||
|
uint32_t qb = 0;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 8; i++)
|
||||||
|
{
|
||||||
|
uint32_t qa0 = qa & 0x03;
|
||||||
|
uint32_t qa1 = (qa & 0x0c) >> 2;
|
||||||
|
qa >>= 4;
|
||||||
|
qb |= (qa1 << (i * 2 + 16));
|
||||||
|
qb |= (qa0 << (i * 2));
|
||||||
|
}
|
||||||
|
q[0] = qb;
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ void dequant_2bit_16
|
||||||
|
(
|
||||||
|
const uint32_t q_0,
|
||||||
|
half2 (&dq)[8],
|
||||||
|
int stride,
|
||||||
|
const uint32_t zero
|
||||||
|
)
|
||||||
|
{
|
||||||
|
const uint32_t c0 = 0x64006400;
|
||||||
|
const half y4_ = __float2half_rn(1.0f / 4.0f);
|
||||||
|
const half y16_ = __float2half_rn(1.0f / 16.0f);
|
||||||
|
const half y64_ = __float2half_rn(1.0f / 64.0f);
|
||||||
|
const half2 y4 = __halves2half2(y4_, y4_);
|
||||||
|
const half2 y16 = __halves2half2(y16_, y16_);
|
||||||
|
const half2 y64 = __halves2half2(y64_, y64_);
|
||||||
|
|
||||||
|
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
|
||||||
|
const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero));
|
||||||
|
const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
|
||||||
|
const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero));
|
||||||
|
const half2 z1 = __half2half2(z1_.as_half);
|
||||||
|
const half2 z4 = __half2half2(z4_);
|
||||||
|
const half2 z16 = __half2half2(z16_);
|
||||||
|
const half2 z64 = __half2half2(z64_);
|
||||||
|
|
||||||
|
uint32_t qa = q_0;
|
||||||
|
half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024
|
||||||
|
half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024
|
||||||
|
half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024
|
||||||
|
half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024
|
||||||
|
qa >>= 8;
|
||||||
|
half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 9]) + 1024
|
||||||
|
half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024
|
||||||
|
half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024
|
||||||
|
half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024
|
||||||
|
|
||||||
|
dq[0] = __hadd2(q0.as_half2, z1);
|
||||||
|
dq[1] = __hfma2(q1.as_half2, y4, z4);
|
||||||
|
dq[2] = __hfma2(q2.as_half2, y16, z16);
|
||||||
|
dq[3] = __hfma2(q3.as_half2, y64, z64);
|
||||||
|
dq[4] = __hadd2(q4.as_half2, z1);
|
||||||
|
dq[5] = __hfma2(q5.as_half2, y4, z4);
|
||||||
|
dq[6] = __hfma2(q6.as_half2, y16, z16);
|
||||||
|
dq[7] = __hfma2(q7.as_half2, y64, z64);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gptq
|
||||||
|
} // namespace vllm
|
||||||
|
|
||||||
|
#endif
|
||||||
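The permutation comment in qdq_2.cuh above (`ffddbb99 77553311 eeccaa88 66442200`) says that `shuffle_2bit_16` moves even-indexed values into the low 16 bits and odd-indexed values into the high 16 bits, so that one mask such as `0x00030003` picks up a (q[0], q[1]) pair. The host-side C++ sketch below ports the loop to a plain function and adds a lookup to confirm that reading. Both function names are hypothetical, and the sketch works on a single word instead of a strided pointer.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical host-side port of the shuffle loop for one 32-bit word that
// holds sixteen 2-bit values in linear order (value i at bits 2*i).
inline uint32_t shuffle_2bit_16_host(uint32_t qa)
{
    uint32_t qb = 0;
    for (int i = 0; i < 8; i++)
    {
        const uint32_t even = qa & 0x03;         // value 2*i
        const uint32_t odd = (qa & 0x0c) >> 2;   // value 2*i + 1
        qa >>= 4;
        qb |= (odd << (i * 2 + 16));             // odd values go to the high lane
        qb |= (even << (i * 2));                 // even values go to the low lane
    }
    return qb;
}

// Hypothetical lookup: read value `index` back out of a shuffled word.
inline int shuffled_item(uint32_t shuffled, int index)
{
    assert(index >= 0 && index < 16);
    const int lane_shift = (index & 1) ? 16 : 0;
    return static_cast<int>((shuffled >> (lane_shift + (index >> 1) * 2)) & 0x03);
}
```

With the linear input 0xE4E4E4E4 (values 0, 1, 2, 3 repeating), the low lane collects 0, 2, 0, 2, ... and the high lane collects 1, 3, 1, 3, ..., which is the layout the dequantization masks rely on.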
141
csrc/quantization/gptq/qdq_3.cuh
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
#ifndef _qdq_3_cuh
|
||||||
|
#define _qdq_3_cuh
|
||||||
|
|
||||||
|
#include "qdq_util.cuh"
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
namespace gptq {
|
||||||
|
// Permutation:
|
||||||
|
//
|
||||||
|
// v9997775 55333111 u8886664 44222000 (u, v lsb)
|
||||||
|
// vjjjhhhf ffdddbbb uiiiggge eecccaaa
|
||||||
|
// vtttrrrp ppnnnlll usssqqqo oommmkkk
|
||||||
|
|
||||||
|
__forceinline__ __device__ void shuffle_3bit_32
|
||||||
|
(
|
||||||
|
uint32_t* q,
|
||||||
|
int stride
|
||||||
|
)
|
||||||
|
{
|
||||||
|
uint32_t qa = q[0 * stride];
|
||||||
|
uint32_t qb = q[1 * stride];
|
||||||
|
uint32_t qc = q[2 * stride];
|
||||||
|
|
||||||
|
// qa: aa999888 77766655 54443332 22111000
|
||||||
|
// qb: lkkkjjji iihhhggg fffeeedd dcccbbba
|
||||||
|
// qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll
|
||||||
|
|
||||||
|
uint32_t qd = qc >> 26;
|
||||||
|
qc <<= 4;
|
||||||
|
qc |= qb >> 28;
|
||||||
|
qb <<= 2;
|
||||||
|
qb |= qa >> 30;
|
||||||
|
|
||||||
|
// qa: ..999888 77766655 54443332 22111000
|
||||||
|
// qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
|
||||||
|
// qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
|
||||||
|
// qd: vvvuuu
|
||||||
|
|
||||||
|
uint32_t za = 0;
|
||||||
|
uint32_t zb = 0;
|
||||||
|
uint32_t zc = 0;
|
||||||
|
|
||||||
|
for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); }
|
||||||
|
for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); }
|
||||||
|
for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); }
|
||||||
|
|
||||||
|
// za: 9997775 55333111 8886664 44222000
|
||||||
|
// zb: jjjhhhf ffdddbbb iiiggge eecccaaa
|
||||||
|
// zc: tttrrrp ppnnnlll sssqqqo oommmkkk
|
||||||
|
// qd: vvvuuu
|
||||||
|
|
||||||
|
za |= ((qd & 0x01) >> 0) << 15;
|
||||||
|
zb |= ((qd & 0x02) >> 1) << 15;
|
||||||
|
zc |= ((qd & 0x04) >> 2) << 15;
|
||||||
|
za |= ((qd & 0x08) >> 3) << 31;
|
||||||
|
zb |= ((qd & 0x10) >> 4) << 31;
|
||||||
|
zc |= ((qd & 0x20) >> 5) << 31;
|
||||||
|
|
||||||
|
// za: v9997775 55333111 u8886664 44222000 (u, v lsb)
|
||||||
|
// zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
|
||||||
|
// zc: vtttrrrp ppnnnlll usssqqqo oommmkkk
|
||||||
|
|
||||||
|
q[0 * stride] = za;
|
||||||
|
q[1 * stride] = zb;
|
||||||
|
q[2 * stride] = zc;
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ void dequant_3bit_32
|
||||||
|
(
|
||||||
|
const uint32_t q_0,
|
||||||
|
const uint32_t q_1,
|
||||||
|
const uint32_t q_2,
|
||||||
|
half2 (&dq)[16],
|
||||||
|
int stride,
|
||||||
|
const uint32_t zero
|
||||||
|
)
|
||||||
|
{
|
||||||
|
const uint32_t c0 = 0x64006400;
|
||||||
|
const half y8_ = __float2half_rn(1.0f / 8.0f);
|
||||||
|
const half y64_ = __float2half_rn(1.0f / 64.0f);
|
||||||
|
const half2 y8 = __halves2half2(y8_, y8_);
|
||||||
|
const half2 y64 = __halves2half2(y64_, y64_);
|
||||||
|
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
|
||||||
|
const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero));
|
||||||
|
const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero));
|
||||||
|
const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half);
|
||||||
|
const half2 z8 = __halves2half2(z8_, z8_);
|
||||||
|
const half2 z64 = __halves2half2(z64_, z64_);
|
||||||
|
|
||||||
|
uint32_t qa = q_0;
|
||||||
|
uint32_t qb = q_1;
|
||||||
|
uint32_t qc = q_2;
|
||||||
|
|
||||||
|
half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024
|
||||||
|
half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024
|
||||||
|
qa >>= 6;
|
||||||
|
half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024
|
||||||
|
half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024
|
||||||
|
half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024
|
||||||
|
qa >>= 9;
|
||||||
|
qa &= 0x00010001;
|
||||||
|
half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024
|
||||||
|
half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024
|
||||||
|
qb >>= 6;
|
||||||
|
half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024
|
||||||
|
half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024
|
||||||
|
half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024
|
||||||
|
qb >>= 8;
|
||||||
|
qb &= 0x00020002;
|
||||||
|
half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024
|
||||||
|
half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024
|
||||||
|
qc >>= 6;
|
||||||
|
half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024
|
||||||
|
half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024
|
||||||
|
half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024
|
||||||
|
qc >>= 7;
|
||||||
|
qc &= 0x00040004;
|
||||||
|
half2_uint32 q15((qa | qb | qc) | c0); // half2(q[30], q[31]) + 1024, rebuilt from the top bit of each 16-bit half of q_0, q_1, q_2
|
||||||
|
|
||||||
|
dq[ 0] = __hadd2( q0.as_half2, z1);
|
||||||
|
dq[ 1] = __hfma2( q1.as_half2, y8, z8);
|
||||||
|
dq[ 2] = __hadd2( q2.as_half2, z1);
|
||||||
|
dq[ 3] = __hfma2( q3.as_half2, y8, z8);
|
||||||
|
dq[ 4] = __hfma2( q4.as_half2, y64, z64);
|
||||||
|
dq[ 5] = __hadd2( q5.as_half2, z1);
|
||||||
|
dq[ 6] = __hfma2( q6.as_half2, y8, z8);
|
||||||
|
dq[ 7] = __hadd2( q7.as_half2, z1);
|
||||||
|
dq[ 8] = __hfma2( q8.as_half2, y8, z8);
|
||||||
|
dq[ 9] = __hfma2( q9.as_half2, y64, z64);
|
||||||
|
dq[10] = __hadd2(q10.as_half2, z1);
|
||||||
|
dq[11] = __hfma2(q11.as_half2, y8, z8);
|
||||||
|
dq[12] = __hadd2(q12.as_half2, z1);
|
||||||
|
dq[13] = __hfma2(q13.as_half2, y8, z8);
|
||||||
|
dq[14] = __hfma2(q14.as_half2, y64, z64);
|
||||||
|
dq[15] = __hadd2(q15.as_half2, z1);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gptq
|
||||||
|
} // namespace vllm
|
||||||
|
|
||||||
|
#endif
|
||||||
147
csrc/quantization/gptq/qdq_4.cuh
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
/*
|
||||||
|
Copied from https://github.com/turboderp/exllamav2
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef _qdq_4_cuh
|
||||||
|
#define _qdq_4_cuh
|
||||||
|
|
||||||
|
#include "qdq_util.cuh"
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
namespace gptq {
|
||||||
|
// Permutation:
|
||||||
|
//
|
||||||
|
// 77775555 33331111 66664444 22220000
|
||||||
|
|
||||||
|
__forceinline__ __device__ void shuffle_4bit_8
|
||||||
|
(
|
||||||
|
uint32_t* q,
|
||||||
|
int stride
|
||||||
|
)
|
||||||
|
{
|
||||||
|
uint32_t qa = q[0];
|
||||||
|
uint32_t qb = 0;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 4; i++)
|
||||||
|
{
|
||||||
|
uint32_t qa0 = qa & 0x0f;
|
||||||
|
uint32_t qa1 = (qa & 0xf0) >> 4;
|
||||||
|
qa >>= 8;
|
||||||
|
qb |= (qa1 << (i * 4 + 16));
|
||||||
|
qb |= (qa0 << (i * 4));
|
||||||
|
}
|
||||||
|
q[0] = qb;
|
||||||
|
}
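The permutation comment above shuffle_4bit_8 can be checked on the host. The following is a minimal sketch, assuming a plain C++ copy of the loop body; `shuffle_4bit_8_host` and `check_shuffle` are hypothetical names used only for illustration and are not part of the diff.

```cpp
#include <cassert>
#include <cstdint>

// Host-side copy of the loop in shuffle_4bit_8, taking the word by value
// instead of through a pointer. Even-indexed nibbles are gathered into the
// low 16 bits, odd-indexed nibbles into the high 16 bits.
uint32_t shuffle_4bit_8_host(uint32_t qa) {
    uint32_t qb = 0;
    for (int i = 0; i < 4; i++) {
        uint32_t qa0 = qa & 0x0f;         // nibble 2*i
        uint32_t qa1 = (qa & 0xf0) >> 4;  // nibble 2*i + 1
        qa >>= 8;
        qb |= (qa1 << (i * 4 + 16));
        qb |= (qa0 << (i * 4));
    }
    return qb;
}

void check_shuffle() {
    // Nibble k of the input holds the value k, so the output spells out the
    // permutation 7 5 3 1 6 4 2 0 from the most significant nibble down.
    assert(shuffle_4bit_8_host(0x76543210u) == 0x75316420u);
}
```

With this layout, a single mask such as 0x000f000f picks q[0] from the low half and q[1] from the high half of the word at once, which is what the half2 dequantization routines rely on.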
|
||||||
|
|
||||||
|
__forceinline__ __device__ void dequant_4bit_8
|
||||||
|
(
|
||||||
|
const uint32_t q_0,
|
||||||
|
half2 (&dq)[4],
|
||||||
|
int stride,
|
||||||
|
const uint32_t zero
|
||||||
|
)
|
||||||
|
{
|
||||||
|
const uint32_t c0 = 0x64006400;
|
||||||
|
const half y16_ = __float2half_rn(1.0f / 16.0f);
|
||||||
|
const half2 y16 = __halves2half2(y16_, y16_);
|
||||||
|
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
|
||||||
|
const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
|
||||||
|
const half2 z1 = __half2half2(z1_.as_half);
|
||||||
|
const half2 z16 = __half2half2(z16_);
|
||||||
|
|
||||||
|
uint32_t qa = q_0;
|
||||||
|
half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024
|
||||||
|
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
|
||||||
|
qa >>= 8;
|
||||||
|
half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024
|
||||||
|
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024
|
||||||
|
|
||||||
|
dq[0] = __hadd2(q0.as_half2, z1);
|
||||||
|
dq[1] = __hfma2(q1.as_half2, y16, z16);
|
||||||
|
dq[2] = __hadd2(q2.as_half2, z1);
|
||||||
|
dq[3] = __hfma2(q3.as_half2, y16, z16);
|
||||||
|
}
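The constants 0x6400 and 0xe400 in dequant_4bit_8 rely on the binary16 encoding: 0x6400 is 1024.0, and at that exponent one mantissa step is exactly 1.0. The sketch below illustrates this on the host under the assumption that a hand-written binary16 decoder is an acceptable stand-in for the GPU's half type; every function name in it is hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Decodes an IEEE 754 binary16 bit pattern (zero and normal numbers only,
// which is all this illustration needs).
float half_bits_to_float(uint16_t h) {
    int sign = (h >> 15) & 0x1;
    int exponent = (h >> 10) & 0x1f;
    int mantissa = h & 0x3ff;
    if (exponent == 0 && mantissa == 0) return sign ? -0.0f : 0.0f;
    float value = std::ldexp(1.0f + mantissa / 1024.0f, exponent - 15);
    return sign ? -value : value;
}

// OR-ing an integer v < 1024 into the mantissa of 0x6400 yields 1024 + v.
float biased_value(uint16_t v) {
    return half_bits_to_float(static_cast<uint16_t>(0x6400 | v));
}

// 0xe400 is -1024.0, so (0xe400 | zero) encodes -(1024 + zero).
float negative_bias(uint16_t zero) {
    return half_bits_to_float(static_cast<uint16_t>(0xe400 | zero));
}

// Low-nibble path: (1024 + q) + (-(1024 + zero)) = q - zero.
float dequant_low_nibble(uint16_t q, uint16_t zero) {
    return biased_value(q) + negative_bias(zero);
}

// High-nibble path: the nibble stays in place as 16 * q, so the result is
// (1024 + 16 * q) / 16 + (-64 - zero) = q - zero.
float dequant_high_nibble(uint16_t q, uint16_t zero) {
    float biased = biased_value(static_cast<uint16_t>(q << 4));
    return biased * (1.0f / 16.0f) + (-64.0f - static_cast<float>(zero));
}

void check_bias_trick() {
    assert(biased_value(5) == 1029.0f);
    assert(dequant_low_nibble(5, 8) == -3.0f);
    assert(dequant_high_nibble(5, 8) == -3.0f);
}
```

Both paths give the same q - zero, which is why the kernel can leave the high nibble unshifted and fold the correction into a single fused multiply-add.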
|
||||||
|
|
||||||
|
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
|
||||||
|
(
|
||||||
|
const uint32_t zero,
|
||||||
|
const half scale,
|
||||||
|
half2 (&z1z16)[2],
|
||||||
|
half2 (&y1y16)[2]
|
||||||
|
)
|
||||||
|
{
|
||||||
|
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
|
||||||
|
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
|
||||||
|
|
||||||
|
half2 scale2 = __half2half2(scale);
|
||||||
|
|
||||||
|
z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
|
||||||
|
z1z16[1] = __hmul2(scale2, __half2half2(z16));
|
||||||
|
|
||||||
|
const half y1 = __float2half_rn(1.0f);
|
||||||
|
const half y16 = __float2half_rn(1.0f / 16.0f);
|
||||||
|
|
||||||
|
y1y16[0] = __hmul2(scale2, __half2half2(y1));
|
||||||
|
y1y16[1] = __hmul2(scale2, __half2half2(y16));
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ void dequant_4bit_8_prep_zero
|
||||||
|
(
|
||||||
|
const uint32_t zero,
|
||||||
|
half2(&z1z16)[2],
|
||||||
|
half2(&y1y16)[2]
|
||||||
|
)
|
||||||
|
{
|
||||||
|
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
|
||||||
|
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
|
||||||
|
|
||||||
|
z1z16[0] = __half2half2(z1.as_half);
|
||||||
|
z1z16[1] = __half2half2(z16);
|
||||||
|
|
||||||
|
const half y1 = __float2half_rn(1.0f);
|
||||||
|
const half y16 = __float2half_rn(1.0f / 16.0f);
|
||||||
|
|
||||||
|
y1y16[0] = __half2half2(y1);
|
||||||
|
y1y16[1] = __half2half2(y16);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
__forceinline__ __device__ void dequant_4bit_8_gptq
|
||||||
|
(
|
||||||
|
const uint32_t q_0,
|
||||||
|
half2 (&dq)[4],
|
||||||
|
half2 (&z1z16)[2],
|
||||||
|
half2 (&y1y16)[2],
|
||||||
|
int stride,
|
||||||
|
bool scaled
|
||||||
|
)
|
||||||
|
{
|
||||||
|
const uint32_t c0 = 0x64006400;
|
||||||
|
|
||||||
|
uint32_t qa = q_0;
|
||||||
|
half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 )
|
||||||
|
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
|
||||||
|
qa >>= 8;
|
||||||
|
half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 )
|
||||||
|
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )
|
||||||
|
|
||||||
|
if (scaled)
|
||||||
|
{
|
||||||
|
dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s)
|
||||||
|
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s)
|
||||||
|
dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
|
||||||
|
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z )
|
||||||
|
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z )
|
||||||
|
dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z )
|
||||||
|
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z )
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace gptq
|
||||||
|
} // namespace vllm
|
||||||
|
|
||||||
|
#endif
|
||||||
40
csrc/quantization/gptq/qdq_8.cuh
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
/*
|
||||||
|
Copied from https://github.com/turboderp/exllamav2
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef _qdq_8_cuh
|
||||||
|
#define _qdq_8_cuh
|
||||||
|
|
||||||
|
#include "qdq_util.cuh"
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
namespace gptq {
|
||||||
|
|
||||||
|
__forceinline__ __device__ void shuffle_8bit_4
|
||||||
|
(
|
||||||
|
uint32_t* q,
|
||||||
|
int stride
|
||||||
|
)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ void dequant_8bit_8
|
||||||
|
(
|
||||||
|
const uint32_t q_0,
|
||||||
|
const uint32_t q_1,
|
||||||
|
half2 (&dq)[4],
|
||||||
|
int stride,
|
||||||
|
const uint32_t zero
|
||||||
|
)
|
||||||
|
{
|
||||||
|
half dqh[8];
|
||||||
|
for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), zero);
|
||||||
|
for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero);
|
||||||
|
|
||||||
|
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gptq
|
||||||
|
} // namespace vllm
|
||||||
|
|
||||||
|
#endif
|
||||||
60
csrc/quantization/gptq/qdq_util.cuh
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
/*
|
||||||
|
Copied from https://github.com/turboderp/exllamav2
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef _qdq_util_cuh
|
||||||
|
#define _qdq_util_cuh
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
namespace gptq {
|
||||||
|
|
||||||
|
union half2_uint32
|
||||||
|
{
|
||||||
|
uint32_t as_uint32;
|
||||||
|
half2 as_half2;
|
||||||
|
__device__ half2_uint32(uint32_t val) : as_uint32(val) {}
|
||||||
|
__device__ half2_uint32(half2 val) : as_half2(val) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
union half_uint16
|
||||||
|
{
|
||||||
|
uint16_t as_uint16;
|
||||||
|
half as_half;
|
||||||
|
__device__ half_uint16(uint16_t val) : as_uint16(val) {}
|
||||||
|
__device__ half_uint16(half val) : as_half(val) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Max_scale premultiplied by 1/256
|
||||||
|
|
||||||
|
__forceinline__ __device__ half dq_scale(const int qs, const half max_scale)
|
||||||
|
{
|
||||||
|
int qs_i = qs + 1;
|
||||||
|
half qs_h = __int2half_rn(qs_i * qs_i);
|
||||||
|
qs_h = __hmul(qs_h, max_scale);
|
||||||
|
return qs_h;
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ half dq(const int q, const int qzero, const half scale)
|
||||||
|
{
|
||||||
|
return __hmul(__int2half_rn(q - qzero), scale);
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ half dq_ns(const int q, const int qzero)
|
||||||
|
{
|
||||||
|
//return __hsub(__int2half_rn(q), __int2half_rn(qzero));
|
||||||
|
return __int2half_rn(q - qzero);
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask)
|
||||||
|
{
|
||||||
|
return (int)((q >> shift) & mask);
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask)
|
||||||
|
{
|
||||||
|
return (int)(__funnelshift_rc(q0, q1, shift) & mask);
|
||||||
|
}
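The two exb overloads extract a bit field from one word or from a field that straddles two words. A host-side sketch follows, assuming that `__funnelshift_rc(q0, q1, shift)` can be modeled for 0 <= shift <= 32 as a right shift of the 64-bit concatenation q1:q0; `exb_host` is a hypothetical name.

```cpp
#include <cassert>
#include <cstdint>

// Single-word variant: same arithmetic as the device function.
int exb_host(uint32_t q, int shift, int mask) {
    return static_cast<int>((q >> shift) & mask);
}

// Two-word variant, emulated with 64-bit arithmetic. q1 supplies the upper
// 32 bits and q0 the lower 32 bits. Assumes 0 <= shift <= 32.
int exb_host(uint32_t q1, uint32_t q0, int shift, int mask) {
    uint64_t wide = (static_cast<uint64_t>(q1) << 32) | q0;
    return static_cast<int>(static_cast<uint32_t>(wide >> shift) & mask);
}

void check_exb() {
    // Byte 1 of 0xABCD1234 is 0x12.
    assert(exb_host(0xABCD1234u, 8, 0xff) == 0x12);
    // A 3-bit field at bits 30..32: bit 31 of q0 and bit 0 of q1 are set.
    assert(exb_host(0x1u, 0x80000000u, 30, 0x7) == 6);
}
```

The second case is the one that matters for 3-bit packing, where 32 is not a multiple of the field width and some values cross a word boundary.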
|
||||||
|
|
||||||
|
} // namespace gptq
|
||||||
|
} // namespace vllm
|
||||||
|
#endif
|
||||||
209
csrc/quantization/marlin/LICENSE
Normal file
@@ -0,0 +1,209 @@
|
|||||||
|
Contains code from https://github.com/IST-DASLab/marlin
|
||||||
|
|
||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "{}"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright {yyyy} {name of copyright owner}
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
|
||||||
|
------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
This product bundles various third-party components under other open source licenses.
|
||||||
|
This section summarizes those components and their licenses. See licenses/
|
||||||
|
for text of these licenses.
|
||||||
1145
csrc/quantization/marlin/marlin_cuda_kernel.cu
Normal file
File diff suppressed because it is too large
225
csrc/quantization/squeezellm/quant_cuda_kernel.cu
Normal file
@@ -0,0 +1,225 @@
|
|||||||
|
#include <torch/all.h>
|
||||||
|
#include <torch/python.h>
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
|
||||||
|
// half-tensor
|
||||||
|
#include <c10/cuda/CUDAStream.h>
|
||||||
|
#include <ATen/cuda/CUDATensorMethods.cuh>
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|
||||||
|
#define BLOCKWIDTH 128
|
||||||
|
#define BLOCKHEIGHT4 16
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
namespace squeezellm {
|
||||||
|
|
||||||
|
__device__ inline unsigned int as_unsigned(int i) {
|
||||||
|
return *reinterpret_cast<unsigned int*>(&i);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4-bit matvec kernel (LUT-based)
|
||||||
|
__global__ void NUQ4MatMulKernel(
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
const half2* __restrict__ vec,
|
||||||
|
#else
|
||||||
|
const __half2* __restrict__ vec,
|
||||||
|
#endif
|
||||||
|
const int* __restrict__ mat,
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
half2* __restrict__ mul,
|
||||||
|
#else
|
||||||
|
float2* __restrict__ mul,
|
||||||
|
#endif
|
||||||
|
const __half* __restrict__ lookup_table,
|
||||||
|
int height,
|
||||||
|
int width,
|
||||||
|
int batch,
|
||||||
|
int vec_height
|
||||||
|
) {
|
||||||
|
|
||||||
|
const int blockwidth2 = BLOCKWIDTH / 2;
|
||||||
|
|
||||||
|
int row = BLOCKHEIGHT4 * blockIdx.x;
|
||||||
|
int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;
|
||||||
|
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
__shared__ half2 blockvec[blockwidth2];
|
||||||
|
#else
|
||||||
|
__shared__ __half2 blockvec[blockwidth2];
|
||||||
|
#endif
|
||||||
|
|
||||||
|
__shared__ __half deq2[16][BLOCKWIDTH];
|
||||||
|
int off = threadIdx.x;
|
||||||
|
int column_offset = col * 16;
|
||||||
|
for (int val = 0; val < 16; val += 1) {
|
||||||
|
int lut_index = column_offset + val;
|
||||||
|
deq2[val][off] = lookup_table[lut_index];
|
||||||
|
}
|
||||||
|
|
||||||
|
__half res;
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
half2 res2;
|
||||||
|
half2 tmp2;
|
||||||
|
#else
|
||||||
|
__half2 res2;
|
||||||
|
__half2 tmp2;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
int i;
|
||||||
|
int k;
|
||||||
|
|
||||||
|
unsigned int tmp1;
|
||||||
|
unsigned int lut_index1, lut_index2;
|
||||||
|
|
||||||
|
for (int b = 0; b < batch; ++b){
|
||||||
|
i = width * row + col;
|
||||||
|
res = __int2half_rd(0);
|
||||||
|
k = 0;
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
if (threadIdx.x < blockwidth2)
|
||||||
|
blockvec[threadIdx.x] = vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + threadIdx.x];
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
while (k < blockwidth2) {
|
||||||
|
tmp1 = as_unsigned(mat[i]);
|
||||||
|
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
res2 = {};
|
||||||
|
tmp2 = {};
|
||||||
|
#else
|
||||||
|
res2.x = __half_as_ushort(__float2half(0));
|
||||||
|
res2.y = __half_as_ushort(__float2half(0));
|
||||||
|
tmp2.x = __half_as_ushort(__float2half(0));
|
||||||
|
tmp2.y = __half_as_ushort(__float2half(0));
|
||||||
|
#endif
|
||||||
|
|
||||||
|
lut_index1 = tmp1 & 0xF;
|
||||||
|
lut_index2 = (tmp1 >> 4) & 0xF;
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
tmp2.x = deq2[lut_index1][off];
|
||||||
|
tmp2.y = deq2[lut_index2][off];
|
||||||
|
#else
|
||||||
|
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
|
||||||
|
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
|
||||||
|
#endif
|
||||||
|
res2 = __hfma2(tmp2, blockvec[k + 0], res2);
|
||||||
|
|
||||||
|
lut_index1 = (tmp1 >> 8) & 0xF;
|
||||||
|
lut_index2 = (tmp1 >> 12) & 0xF;
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
tmp2.x = deq2[lut_index1][off];
|
||||||
|
tmp2.y = deq2[lut_index2][off];
|
||||||
|
#else
|
||||||
|
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
|
||||||
|
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
|
||||||
|
#endif
|
||||||
|
res2 = __hfma2(tmp2, blockvec[k + 1], res2);
|
||||||
|
|
||||||
|
lut_index1 = (tmp1 >> 16) & 0xF;
|
||||||
|
lut_index2 = (tmp1 >> 20) & 0xF;
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
tmp2.x = deq2[lut_index1][off];
|
||||||
|
tmp2.y = deq2[lut_index2][off];
|
||||||
|
#else
|
||||||
|
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
|
||||||
|
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
|
||||||
|
#endif
|
||||||
|
res2 = __hfma2(tmp2, blockvec[k + 2], res2);
|
||||||
|
|
||||||
|
lut_index1 = (tmp1 >> 24) & 0xF;
|
||||||
|
lut_index2 = (tmp1 >> 28) & 0xF;
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
tmp2.x = deq2[lut_index1][off];
|
||||||
|
tmp2.y = deq2[lut_index2][off];
|
||||||
|
#else
|
||||||
|
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
|
||||||
|
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
|
||||||
|
#endif
|
||||||
|
res2 = __hfma2(tmp2, blockvec[k + 3], res2);
|
||||||
|
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
res = __hadd(__hadd(res2.x, res2.y), res);
|
||||||
|
#else
|
||||||
|
res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
i += width;
|
||||||
|
k += 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
// col%2 -> only set one of the two values
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
half2 res3 = {};
|
||||||
|
if (col % 2 == 0) {
|
||||||
|
res3.x = res;
|
||||||
|
} else {
|
||||||
|
res3.y = res;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
__half2 res3;
|
||||||
|
res3.x = __half_as_ushort(__float2half(0));
|
||||||
|
res3.y = __half_as_ushort(__float2half(0));
|
||||||
|
if (col % 2 == 0) {
|
||||||
|
res3.x = __half_as_ushort(res);
|
||||||
|
} else {
|
||||||
|
res3.y = __half_as_ushort(res);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
atomicAdd(&mul[b * width / 2 + col / 2], res3);
|
||||||
|
#else
|
||||||
|
int tmp_addr = b * width / 2 + col / 2;
|
||||||
|
atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x)));
|
||||||
|
atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y)));
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
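As a reading aid for NUQ4MatMulKernel, here is a CPU reference for a single batch row. It is a sketch under stated assumptions: it uses float instead of half, it returns fresh sums rather than accumulating into an existing output with atomicAdd, and `nuq4_matvec_reference` is a hypothetical helper that does not exist in the repository.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// vec: height * 8 activations for one batch row
// mat: height x width packed words (row-major); each word holds eight 4-bit
//      lookup indices, lowest nibble first
// lut: width x 16 table; lut[col * 16 + index] is the dequantized weight
std::vector<float> nuq4_matvec_reference(const std::vector<float>& vec,
                                         const std::vector<uint32_t>& mat,
                                         const std::vector<float>& lut,
                                         int height, int width) {
    std::vector<float> out(width, 0.0f);
    for (int col = 0; col < width; ++col) {
        for (int row = 0; row < height; ++row) {
            uint32_t packed = mat[row * width + col];
            for (int j = 0; j < 8; ++j) {
                // Nibble j of the packed word selects a per-column LUT entry,
                // which multiplies activation row * 8 + j.
                uint32_t index = (packed >> (4 * j)) & 0xF;
                out[col] += lut[col * 16 + index] * vec[row * 8 + j];
            }
        }
    }
    return out;
}

void check_reference() {
    std::vector<float> lut(16);
    for (int i = 0; i < 16; ++i) lut[i] = static_cast<float>(i);
    std::vector<uint32_t> mat = {0x76543210u};
    std::vector<float> vec(8, 1.0f);
    // Indices 0..7 map to weights 0..7, all activations are 1: sum is 28.
    assert(nuq4_matvec_reference(vec, mat, lut, 1, 1)[0] == 28.0f);
}
```

The kernel computes the same sums, but each thread block handles 16 packed rows (128 activations) of one 128-column tile and adds its partial result into the output, so results from different blocks combine through the atomic add.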
|
||||||
|
|
||||||
|
} // namespace squeezellm
|
||||||
|
} // namespace vllm
|
||||||
|
|
||||||
|
// 4-bit matvec kernel (LUT-based)
|
||||||
|
void squeezellm_gemm(
|
||||||
|
torch::Tensor vec,
|
||||||
|
torch::Tensor mat,
|
||||||
|
torch::Tensor mul,
|
||||||
|
torch::Tensor lookup_table
|
||||||
|
) {
|
||||||
|
int height = mat.size(0);
|
||||||
|
int width = mat.size(1);
|
||||||
|
|
||||||
|
int batch = vec.size(0);
|
||||||
|
int vec_height = vec.size(1);
|
||||||
|
|
||||||
|
dim3 blocks(
|
||||||
|
(height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
|
||||||
|
(width + BLOCKWIDTH - 1) / BLOCKWIDTH
|
||||||
|
);
|
||||||
|
dim3 threads(BLOCKWIDTH);
|
||||||
|
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
(half2*) vec.data<at::Half>(),
|
||||||
|
#else
|
||||||
|
(__half2*) vec.data_ptr<at::Half>(),
|
||||||
|
#endif
|
||||||
|
mat.data_ptr<int>(),
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
(half2*) mul.data<at::Half>(),
|
||||||
|
(__half*) lookup_table.data<at::Half>(),
|
||||||
|
#else
|
||||||
|
(float2*) mul.data_ptr<float>(),
|
||||||
|
(__half*) lookup_table.data_ptr<at::Half>(),
|
||||||
|
#endif
|
||||||
|
height, width, batch, vec_height
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#undef BLOCKWIDTH
|
||||||
|
#undef BLOCKHEIGHT4
|
||||||
@@ -17,13 +17,15 @@
|
|||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "cuda_compat.h"
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
__inline__ __device__ T warpReduceSum(T val) {
|
__inline__ __device__ T warpReduceSum(T val) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = 16; mask > 0; mask >>= 1)
|
for (int mask = 16; mask > 0; mask >>= 1)
|
||||||
val += __shfl_xor_sync(0xffffffff, val, mask, 32);
|
val += VLLM_SHFL_XOR_SYNC(val, mask);
|
||||||
return val;
|
return val;
|
||||||
}
|
}
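The XOR butterfly in warpReduceSum can be emulated on the host to see why every lane ends up with the full sum. This is only an illustration, assuming a 32-lane warp modeled as an array; `warp_reduce_sum_emulated` is a hypothetical name, and the real exchange is done by the shuffle intrinsic behind VLLM_SHFL_XOR_SYNC.

```cpp
#include <array>
#include <cassert>

// In each round, lane i adds the value held by lane (i ^ mask). After the
// rounds with mask = 16, 8, 4, 2, 1, every lane holds the sum of all lanes.
std::array<int, 32> warp_reduce_sum_emulated(std::array<int, 32> lanes) {
    for (int mask = 16; mask > 0; mask >>= 1) {
        std::array<int, 32> next{};
        for (int lane = 0; lane < 32; ++lane) {
            next[lane] = lanes[lane] + lanes[lane ^ mask];
        }
        lanes = next;
    }
    return lanes;
}

void check_butterfly() {
    std::array<int, 32> lanes{};
    for (int i = 0; i < 32; ++i) lanes[i] = i;
    std::array<int, 32> reduced = warp_reduce_sum_emulated(lanes);
    // 0 + 1 + ... + 31 = 496, present in every lane.
    for (int i = 0; i < 32; ++i) assert(reduced[i] == 496);
}
```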
|
||||||
|
|
||||||
|
|||||||
@@ -9,11 +9,15 @@
|
|||||||
# If extensions (or modules to document with autodoc) are in another directory,
|
# If extensions (or modules to document with autodoc) are in another directory,
|
||||||
# add these directories to sys.path here. If the directory is relative to the
|
# add these directories to sys.path here. If the directory is relative to the
|
||||||
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
||||||
#
|
|
||||||
# import os
|
|
||||||
# import sys
|
|
||||||
# sys.path.insert(0, os.path.abspath('.'))
|
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from sphinx.ext import autodoc
|
||||||
|
import logging
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.abspath(os.path.join('..', '..')))
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# -- Project information -----------------------------------------------------
|
# -- Project information -----------------------------------------------------
|
||||||
|
|
||||||
@@ -21,7 +25,6 @@ project = 'vLLM'
|
|||||||
copyright = '2023, vLLM Team'
|
copyright = '2023, vLLM Team'
|
||||||
author = 'the vLLM Team'
|
author = 'the vLLM Team'
|
||||||
|
|
||||||
|
|
||||||
# -- General configuration ---------------------------------------------------
|
# -- General configuration ---------------------------------------------------
|
||||||
|
|
||||||
# Add any Sphinx extension module names here, as strings. They can be
|
# Add any Sphinx extension module names here, as strings. They can be
|
||||||
@@ -32,6 +35,8 @@ extensions = [
|
|||||||
"sphinx.ext.viewcode",
|
"sphinx.ext.viewcode",
|
||||||
"sphinx.ext.intersphinx",
|
"sphinx.ext.intersphinx",
|
||||||
"sphinx_copybutton",
|
"sphinx_copybutton",
|
||||||
|
"sphinx.ext.autodoc",
|
||||||
|
"sphinx.ext.autosummary",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Add any paths that contain templates here, relative to this directory.
|
# Add any paths that contain templates here, relative to this directory.
|
||||||
@@ -55,7 +60,6 @@ html_title = project
|
|||||||
html_theme = 'sphinx_book_theme'
|
html_theme = 'sphinx_book_theme'
|
||||||
html_logo = 'assets/logos/vllm-logo-text-light.png'
|
html_logo = 'assets/logos/vllm-logo-text-light.png'
|
||||||
html_theme_options = {
|
html_theme_options = {
|
||||||
'logo_only': True,
|
|
||||||
'path_to_docs': 'docs/source',
|
'path_to_docs': 'docs/source',
|
||||||
'repository_url': 'https://github.com/vllm-project/vllm',
|
'repository_url': 'https://github.com/vllm-project/vllm',
|
||||||
'use_repository_button': True,
|
'use_repository_button': True,
|
||||||
@@ -64,4 +68,31 @@ html_theme_options = {
|
|||||||
# Add any paths that contain custom static files (such as style sheets) here,
|
# Add any paths that contain custom static files (such as style sheets) here,
|
||||||
# relative to this directory. They are copied after the builtin static files,
|
# relative to this directory. They are copied after the builtin static files,
|
||||||
# so a file named "default.css" will overwrite the builtin "default.css".
|
# so a file named "default.css" will overwrite the builtin "default.css".
|
||||||
html_static_path = ['_static']
|
# html_static_path = ['_static']
|
||||||
|
|
||||||
|
# Mock out external dependencies here.
|
||||||
|
autodoc_mock_imports = [
|
||||||
|
"torch", "transformers", "psutil", "prometheus_client", "sentencepiece",
|
||||||
|
"vllm.cuda_utils", "vllm._C"
|
||||||
|
]
|
||||||
|
|
||||||
|
for mock_target in autodoc_mock_imports:
|
||||||
|
if mock_target in sys.modules:
|
||||||
|
logger.info(
|
||||||
|
f"Potentially problematic mock target ({mock_target}) found; "
|
||||||
|
"autodoc_mock_imports cannot mock modules that have already "
|
||||||
|
"been loaded into sys.modules when the sphinx build starts.")
|
||||||
|
|
||||||
|
|
||||||
|
class MockedClassDocumenter(autodoc.ClassDocumenter):
|
||||||
|
"""Remove note about base class when a class is derived from object."""
|
||||||
|
|
||||||
|
def add_line(self, line: str, source: str, *lineno: int) -> None:
|
||||||
|
if line == " Bases: :py:class:`object`":
|
||||||
|
return
|
||||||
|
super().add_line(line, source, *lineno)
|
||||||
|
|
||||||
|
|
||||||
|
autodoc.ClassDocumenter = MockedClassDocumenter
|
||||||
|
|
||||||
|
navigation_with_keys = False
|
||||||
|
|||||||
7
docs/source/dev/engine/async_llm_engine.rst
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
|
||||||
|
AsyncLLMEngine
|
||||||
|
=================================
|
||||||
|
|
||||||
|
.. autoclass:: vllm.engine.async_llm_engine.AsyncLLMEngine
|
||||||
|
:members: generate, abort
|
||||||
|
:show-inheritance:
|
||||||
13
docs/source/dev/engine/engine_index.rst
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
vLLM Engine
|
||||||
|
=================================
|
||||||
|
|
||||||
|
.. automodule:: vllm.engine
|
||||||
|
.. currentmodule:: vllm.engine
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:maxdepth: 2
|
||||||
|
:caption: Engines
|
||||||
|
|
||||||
|
llm_engine
|
||||||
|
async_llm_engine
|
||||||
|
|
||||||
6
docs/source/dev/engine/llm_engine.rst
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
LLMEngine
|
||||||
|
=================================
|
||||||
|
|
||||||
|
.. autoclass:: vllm.engine.llm_engine.LLMEngine
|
||||||
|
:members: add_request, abort_request, step, _init_cache
|
||||||
|
:show-inheritance:
|
||||||
172
docs/source/getting_started/amd-installation.rst
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
.. _installation_rocm:
|
||||||
|
|
||||||
|
Installation with ROCm
|
||||||
|
======================
|
||||||
|
|
||||||
|
vLLM 0.2.4 onwards supports model inferencing and serving on AMD GPUs with ROCm.
|
||||||
|
At the moment AWQ quantization is not supported in ROCm, but SqueezeLLM quantization has been ported.
|
||||||
|
Data types currently supported in ROCm are FP16 and BF16.
|
||||||
|
|
||||||
|
Requirements
|
||||||
|
------------
|
||||||
|
|
||||||
|
* OS: Linux
|
||||||
|
* Python: 3.8 -- 3.11
|
||||||
|
* GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100)
|
||||||
|
* PyTorch 2.0.1/2.1.1/2.2
|
||||||
|
* ROCm 5.7 (verified on Python 3.10) or ROCm 6.0 (verified on Python 3.9)
|
||||||
|
|
||||||
|
Installation options:
|
||||||
|
|
||||||
|
#. :ref:`(Recommended) Quick start with vLLM pre-installed in Docker Image <quick_start_docker_rocm>`
|
||||||
|
#. :ref:`Build from source <build_from_source_rocm>`
|
||||||
|
#. :ref:`Build from source with docker <build_from_source_docker_rocm>`
|
||||||
|
|
||||||
|
.. _quick_start_docker_rocm:
|
||||||
|
|
||||||
|
(Recommended) Option 1: Quick start with vLLM pre-installed in Docker Image
|
||||||
|
---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
This option is for ROCm 5.7 only:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ docker pull embeddedllminfo/vllm-rocm:vllm-v0.2.4
|
||||||
|
$ docker run -it \
|
||||||
|
--network=host \
|
||||||
|
--group-add=video \
|
||||||
|
--ipc=host \
|
||||||
|
--cap-add=SYS_PTRACE \
|
||||||
|
--security-opt seccomp=unconfined \
|
||||||
|
--device /dev/kfd \
|
||||||
|
--device /dev/dri \
|
||||||
|
-v <path/to/model>:/app/model \
|
||||||
|
embeddedllminfo/vllm-rocm \
|
||||||
|
bash
|
||||||
|
|
||||||
|
|
||||||
|
.. _build_from_source_rocm:
|
||||||
|
|
||||||
|
Option 2: Build from source
|
||||||
|
---------------------------
|
||||||
|
|
||||||
|
You can build and install vLLM from source:
|
||||||
|
|
||||||
|
The instructions below are for ROCm 5.7 only.
|
||||||
|
At the time of this documentation update, the PyTorch wheel for ROCm 6.0 is not yet available on the PyTorch website.
|
||||||
|
|
||||||
|
0. Install prerequisites (skip if you are already in an environment/docker with the following installed):
|
||||||
|
|
||||||
|
- `ROCm <https://rocm.docs.amd.com/en/latest/deploy/linux/index.html>`_
|
||||||
|
- `Pytorch <https://pytorch.org/>`_
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ pip install torch==2.2.0.dev20231206+rocm5.7 --index-url https://download.pytorch.org/whl/nightly/rocm5.7 # tested version
|
||||||
|
|
||||||
|
|
||||||
|
1. Install `flash attention for ROCm <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm>`_
|
||||||
|
|
||||||
|
Install ROCm's flash attention (v2.0.4) following the instructions from `ROCmSoftwarePlatform/flash-attention <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm#amd-gpurocm-support>`_
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
- If you are using rocm5.7 with pytorch 2.1.0 onwards, you don't need to apply the `hipify_python.patch`. You can build the ROCm flash attention directly.
|
||||||
|
- If you fail to install `ROCmSoftwarePlatform/flash-attention`, try cloning from the commit `6fd2f8e572805681cd67ef8596c7e2ce521ed3c6`.
|
||||||
|
- ROCm's Flash-attention-2 (v2.0.4) does not support sliding window attention.
|
||||||
|
- You might need to downgrade the "ninja" version to 1.10 if it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
|
||||||
|
|
||||||
|
2. Set up `xformers==0.0.23` without dependencies, and apply patches to adapt it for ROCm flash attention
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ pip install xformers==0.0.23 --no-deps
|
||||||
|
$ bash patch_xformers.rocm.sh
|
||||||
|
|
||||||
|
3. Build vLLM.
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ cd vllm
|
||||||
|
$ pip install -U -r requirements-rocm.txt
|
||||||
|
$ python setup.py install # This may take 5-10 minutes. Currently, `pip install .` does not work for ROCm installation
|
||||||
|
|
||||||
|
|
||||||
|
.. _build_from_source_docker_rocm:
|
||||||
|
|
||||||
|
Option 3: Build from source with docker
|
||||||
|
-----------------------------------------------------
|
||||||
|
|
||||||
|
You can build and install vLLM from source:
|
||||||
|
|
||||||
|
Build a docker image from `Dockerfile.rocm`, and launch a docker container.
|
||||||
|
|
||||||
|
The `Dockerfile.rocm` is designed to support both ROCm 5.7 and ROCm 6.0 and later versions. It lets you customize the Docker image build using the following arguments:
|
||||||
|
|
||||||
|
* `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image. We have tested ROCm 5.7 and ROCm 6.0. The default is `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`
|
||||||
|
* `FA_GFX_ARCHS`: specifies the GFX architecture that is used to build flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942`
|
||||||
|
* `FA_BRANCH`: specifies the branch used to build the flash-attention in `ROCmSoftwarePlatform's flash-attention repo <https://github.com/ROCmSoftwarePlatform/flash-attention>`_. The default is `3d2b6f5`
|
||||||
|
* `BUILD_FA`: specifies whether to build flash-attention. For `Radeon RX 7900 series (gfx1100) <https://rocm.docs.amd.com/projects/radeon/en/latest/index.html>`_, this should be set to 0 until flash-attention supports this target.
|
||||||
|
|
||||||
|
Their values can be passed in when running ``docker build`` with ``--build-arg`` options.
|
||||||
|
|
||||||
|
For example, to build a Docker image for vLLM on ROCm 5.7, you can run:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ docker build --build-arg BASE_IMAGE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \
|
||||||
|
-f Dockerfile.rocm -t vllm-rocm .
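When the build is driven from a script, the ``--build-arg`` options can be assembled programmatically. The following is a minimal sketch, assuming only the argument names documented above; `docker_build_command` is a hypothetical helper written for illustration, and the combination of values shown is not a tested configuration.

```python
import shlex


def docker_build_command(build_args, dockerfile="Dockerfile.rocm", tag="vllm-rocm"):
    """Return the argv list for `docker build` with one --build-arg per entry."""
    cmd = ["docker", "build"]
    for name, value in build_args.items():
        cmd += ["--build-arg", f"{name}={value}"]
    cmd += ["-f", dockerfile, "-t", tag, "."]
    return cmd


cmd = docker_build_command({
    # Base image taken from the ROCm 5.7 example above.
    "BASE_IMAGE": "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1",
    # Skip building flash-attention, e.g. for gfx1100 targets.
    "BUILD_FA": "0",
})

# Print a copy-pastable shell command; pass `cmd` to subprocess.run to execute it.
print(shlex.join(cmd))
```

Keeping the command as a list rather than a single string avoids shell-quoting problems if a value ever contains spaces or semicolons, as a list of GFX architectures would.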
|
||||||
|
|
||||||
|
To build vllm on ROCm 6.0, you can use the default:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ docker build -f Dockerfile.rocm -t vllm-rocm .
|
||||||
|
$ docker run -it \
|
||||||
|
--network=host \
|
||||||
|
--group-add=video \
|
||||||
|
--ipc=host \
|
||||||
|
--cap-add=SYS_PTRACE \
|
||||||
|
--security-opt seccomp=unconfined \
|
||||||
|
--device /dev/kfd \
|
||||||
|
--device /dev/dri \
|
||||||
|
-v <path/to/model>:/app/model \
|
||||||
|
vllm-rocm \
|
||||||
|
bash
|
||||||
|
|
||||||
|
Alternatively, if you plan to install vLLM-ROCm on a local machine or start from a fresh docker image (e.g. rocm/pytorch), you can follow the steps below:
|
||||||
|
|
||||||
|
0. Install prerequisites (skip if you are already in an environment/docker with the following installed):
|
||||||
|
|
||||||
|
- `ROCm <https://rocm.docs.amd.com/en/latest/deploy/linux/index.html>`_
|
||||||
|
- `Pytorch <https://pytorch.org/>`_
|
||||||
|
- `hipBLAS <https://rocm.docs.amd.com/projects/hipBLAS/en/latest/install.html>`_
|
||||||
|
|
||||||
|
1. Install `flash attention for ROCm <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm>`_
|
||||||
|
|
||||||
|
Install ROCm's flash attention (v2.0.4) following the instructions from `ROCmSoftwarePlatform/flash-attention <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm#amd-gpurocm-support>`_
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
- If you are using rocm5.7 with pytorch 2.1.0 onwards, you don't need to apply the `hipify_python.patch`. You can build the ROCm flash attention directly.
|
||||||
|
- If you fail to install `ROCmSoftwarePlatform/flash-attention`, try cloning from the commit `6fd2f8e572805681cd67ef8596c7e2ce521ed3c6`.
|
||||||
|
- ROCm's Flash-attention-2 (v2.0.4) does not support sliding window attention.
|
||||||
|
- You might need to downgrade the "ninja" version to 1.10 if it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
|
||||||
|
|
||||||
|
2. Set up `xformers==0.0.23` without dependencies, and apply patches to adapt it for ROCm flash attention
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ pip install xformers==0.0.23 --no-deps
|
||||||
|
$ bash patch_xformers.rocm.sh
|
||||||
|
|
||||||
|
3. Build vLLM.
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ cd vllm
|
||||||
|
$ pip install -U -r requirements-rocm.txt
|
||||||
|
$ python setup.py install # This may take 5-10 minutes.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
- You may need to turn on the ``--enforce-eager`` flag if you experience a process hang when running the `benchmark_throughput.py` script to test your installation.
|
||||||
|
|
||||||
@@ -3,14 +3,14 @@
|
|||||||
Installation
|
Installation
|
||||||
============
|
============
|
||||||
|
|
||||||
vLLM is a Python library that also contains pre-compiled C++ and CUDA (11.8) binaries.
|
vLLM is a Python library that also contains pre-compiled C++ and CUDA (12.1) binaries.
|
||||||
|
|
||||||
Requirements
|
Requirements
|
||||||
------------
|
------------
|
||||||
|
|
||||||
* OS: Linux
|
* OS: Linux
|
||||||
* Python: 3.8 -- 3.11
|
* Python: 3.8 -- 3.11
|
||||||
* GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, etc.)
|
* GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, H100, etc.)
|
||||||
|
|
||||||
Install with pip
|
Install with pip
|
||||||
----------------
|
----------------
|
||||||
@@ -20,12 +20,32 @@ You can install vLLM using pip:
|
|||||||
.. code-block:: console
|
.. code-block:: console
|
||||||
|
|
||||||
$ # (Optional) Create a new conda environment.
|
$ # (Optional) Create a new conda environment.
|
||||||
$ conda create -n myenv python=3.8 -y
|
$ conda create -n myenv python=3.9 -y
|
||||||
$ conda activate myenv
|
$ conda activate myenv
|
||||||
|
|
||||||
$ # Install vLLM.
|
$ # Install vLLM with CUDA 12.1.
|
||||||
$ pip install vllm
|
$ pip install vllm
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
As of now, vLLM's binaries are compiled on CUDA 12.1 by default.
|
||||||
|
However, you can install vLLM with CUDA 11.8 by running:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ # Install vLLM with CUDA 11.8.
|
||||||
|
$ export VLLM_VERSION=0.2.4
|
||||||
|
$ export PYTHON_VERSION=39
|
||||||
|
$ pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux1_x86_64.whl
|
||||||
|
|
||||||
|
$ # Re-install PyTorch with CUDA 11.8.
|
||||||
|
$ pip uninstall torch -y
|
||||||
|
$ pip install torch --upgrade --index-url https://download.pytorch.org/whl/cu118
|
||||||
|
|
||||||
|
$ # Re-install xFormers with CUDA 11.8.
|
||||||
|
$ pip uninstall xformers -y
|
||||||
|
$ pip install --upgrade xformers --index-url https://download.pytorch.org/whl/cu118
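The wheel URL in the command above is built from two variables: the vLLM version and the CPython tag. As an illustration, the sketch below assembles the same URL in Python and derives the tag from the running interpreter. The helper names are hypothetical, and the URL pattern is copied from the command above; it is an assumption that other releases follow the same naming.

```python
import sys


def current_python_tag() -> str:
    """Return the CPython tag digits for this interpreter, e.g. '39' for Python 3.9."""
    return f"{sys.version_info.major}{sys.version_info.minor}"


def cu118_wheel_url(vllm_version: str, python_tag: str) -> str:
    """Build the CUDA 11.8 wheel URL using the pattern from the install command."""
    base = "https://github.com/vllm-project/vllm/releases/download"
    wheel = (
        f"vllm-{vllm_version}+cu118-"
        f"cp{python_tag}-cp{python_tag}-manylinux1_x86_64.whl"
    )
    return f"{base}/v{vllm_version}/{wheel}"


# Same values as the shell example: VLLM_VERSION=0.2.4, PYTHON_VERSION=39.
url = cu118_wheel_url("0.2.4", "39")
print(url)

# Tag for whichever interpreter runs this script.
print(current_python_tag())
```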
|
||||||
|
|
||||||
|
|
||||||
.. _build_from_source:
|
.. _build_from_source:
|
||||||
|
|
||||||
@@ -45,6 +65,15 @@ You can also build and install vLLM from source:
|
|||||||
|
|
||||||
.. code-block:: console
|
.. code-block:: console
|
||||||
|
|
||||||
$ # Pull the Docker image with CUDA 11.8.
|
|
||||||
$ # Use `--ipc=host` to make sure the shared memory is large enough.
|
$ # Use `--ipc=host` to make sure the shared memory is large enough.
|
||||||
$ docker run --gpus all -it --rm --ipc=host nvcr.io/nvidia/pytorch:22.12-py3
|
$ docker run --gpus all -it --rm --ipc=host nvcr.io/nvidia/pytorch:23.10-py3
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
If you are developing the C++ backend of vLLM, consider building vLLM with
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ python setup.py develop
|
||||||
|
|
||||||
|
since it will give you incremental builds. The downside is that this method
|
||||||
|
is `deprecated by setuptools <https://github.com/pypa/setuptools/issues/917>`_.
|
||||||
|
|||||||
@@ -11,6 +11,14 @@ This guide shows how to use vLLM to:
|
|||||||
|
|
||||||
Be sure to complete the :ref:`installation instructions <installation>` before continuing with this guide.
|
Be sure to complete the :ref:`installation instructions <installation>` before continuing with this guide.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
By default, vLLM downloads models from `HuggingFace <https://huggingface.co/>`_. If you would like to use models from `ModelScope <https://www.modelscope.cn>`_ in the following examples, please set the environment variable:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
export VLLM_USE_MODELSCOPE=True
|
||||||
|
|
||||||
Offline Batched Inference
|
Offline Batched Inference
|
||||||
-------------------------
|
-------------------------
|
||||||
|
|
||||||
@@ -55,38 +63,11 @@ Call ``llm.generate`` to generate the outputs. It adds the input prompts to vLLM
|
|||||||
|
|
||||||
The code example can also be found in `examples/offline_inference.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference.py>`_.
|
The code example can also be found in `examples/offline_inference.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference.py>`_.
|
||||||
|
|
||||||
|
|
||||||
API Server
|
|
||||||
----------
|
|
||||||
|
|
||||||
vLLM can be deployed as an LLM service. We provide an example `FastAPI <https://fastapi.tiangolo.com/>`_ server. Check `vllm/entrypoints/api_server.py <https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/api_server.py>`_ for the server implementation. The server uses ``AsyncLLMEngine`` class to support asynchronous processing of incoming requests.
|
|
||||||
|
|
||||||
Start the server:
|
|
||||||
|
|
||||||
.. code-block:: console
|
|
||||||
|
|
||||||
$ python -m vllm.entrypoints.api_server
|
|
||||||
|
|
||||||
By default, this command starts the server at ``http://localhost:8000`` with the OPT-125M model.
|
|
||||||
|
|
||||||
Query the model in shell:
|
|
||||||
|
|
||||||
.. code-block:: console
|
|
||||||
|
|
||||||
$ curl http://localhost:8000/generate \
|
|
||||||
$ -d '{
|
|
||||||
$ "prompt": "San Francisco is a",
|
|
||||||
$ "use_beam_search": true,
|
|
||||||
$ "n": 4,
|
|
||||||
$ "temperature": 0
|
|
||||||
$ }'
|
|
||||||
|
|
||||||
See `examples/api_client.py <https://github.com/vllm-project/vllm/blob/main/examples/api_client.py>`_ for a more detailed client example.
|
|
||||||
|
|
||||||
OpenAI-Compatible Server
|
OpenAI-Compatible Server
|
||||||
------------------------
|
------------------------
|
||||||
|
|
||||||
vLLM can be deployed as a server that mimics the OpenAI API protocol. This allows vLLM to be used as a drop-in replacement for applications using OpenAI API.
|
vLLM can be deployed as a server that implements the OpenAI API protocol. This allows vLLM to be used as a drop-in replacement for applications using OpenAI API.
|
||||||
|
By default, it starts the server at ``http://localhost:8000``. You can specify the address with ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the command below) and implements `list models <https://platform.openai.com/docs/api-reference/models/list>`_, `create chat completion <https://platform.openai.com/docs/api-reference/chat/completions/create>`_, and `create completion <https://platform.openai.com/docs/api-reference/completions/create>`_ endpoints. We are actively adding support for more endpoints.
|
||||||
|
|
||||||
Start the server:
|
Start the server:
|
||||||
|
|
||||||
@@ -95,7 +76,13 @@ Start the server:
|
|||||||
$ python -m vllm.entrypoints.openai.api_server \
|
$ python -m vllm.entrypoints.openai.api_server \
|
||||||
$ --model facebook/opt-125m
|
$ --model facebook/opt-125m
|
||||||
|
|
||||||
By default, it starts the server at ``http://localhost:8000``. You can specify the address with ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the above command) and implements `list models <https://platform.openai.com/docs/api-reference/models/list>`_ and `create completion <https://platform.openai.com/docs/api-reference/completions/create>`_ endpoints. We are actively adding support for more endpoints.
|
By default, the server uses a predefined chat template stored in the tokenizer. You can override this template by using the ``--chat-template`` argument:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ python -m vllm.entrypoints.openai.api_server \
|
||||||
|
$ --model facebook/opt-125m \
|
||||||
|
$ --chat-template ./examples/template_chatml.jinja
|
||||||
|
|
||||||
This server can be queried in the same format as OpenAI API. For example, list the models:
|
This server can be queried in the same format as OpenAI API. For example, list the models:
|
||||||
|
|
||||||
@@ -103,6 +90,11 @@ This server can be queried in the same format as OpenAI API. For example, list t
|
|||||||
|
|
||||||
$ curl http://localhost:8000/v1/models
|
$ curl http://localhost:8000/v1/models
|
||||||
|
|
||||||
|
You can pass the ``--api-key`` argument or set the ``VLLM_API_KEY`` environment variable to make the server check for an API key in the request header.
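A client then has to send the key with each request. The sketch below builds such a request with the Python standard library, assuming the key is expected as a bearer token in the ``Authorization`` header, which is how OpenAI clients transmit it. `build_models_request` and the key value are hypothetical, chosen only for illustration.

```python
import urllib.request


def build_models_request(base_url: str, api_key: str) -> urllib.request.Request:
    """Build a GET request for the list-models endpoint carrying the API key."""
    return urllib.request.Request(
        f"{base_url}/v1/models",
        headers={"Authorization": f"Bearer {api_key}"},
        method="GET",
    )


req = build_models_request("http://localhost:8000", "my-secret-key")
print(req.get_method(), req.full_url)

# Sending the request needs a running server, so it is left commented out:
# with urllib.request.urlopen(req) as resp:
#     print(resp.read().decode())
```

With the `openai` Python package, passing the same key as ``api_key`` when constructing the client should have the equivalent effect.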
|
||||||
|
|
||||||
|
Using OpenAI Completions API with vLLM
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
Query the model with input prompts:
|
Query the model with input prompts:
|
||||||
|
|
||||||
.. code-block:: console
|
.. code-block:: console
|
||||||
@@ -120,12 +112,65 @@ Since this server is compatible with OpenAI API, you can use it as a drop-in rep
|
|||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
import openai
|
from openai import OpenAI
|
||||||
|
|
||||||
# Modify OpenAI's API key and API base to use vLLM's API server.
|
# Modify OpenAI's API key and API base to use vLLM's API server.
|
||||||
openai.api_key = "EMPTY"
|
openai_api_key = "EMPTY"
|
||||||
openai.api_base = "http://localhost:8000/v1"
|
openai_api_base = "http://localhost:8000/v1"
|
||||||
completion = openai.Completion.create(model="facebook/opt-125m",
|
client = OpenAI(
|
||||||
|
api_key=openai_api_key,
|
||||||
|
base_url=openai_api_base,
|
||||||
|
)
|
||||||
|
completion = client.completions.create(model="facebook/opt-125m",
|
||||||
prompt="San Francisco is a")
|
prompt="San Francisco is a")
|
||||||
print("Completion result:", completion)
|
print("Completion result:", completion)
|
||||||
|
|
||||||
For a more detailed client example, refer to `examples/openai_completion_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_completion_client.py>`_.
|
For a more detailed client example, refer to `examples/openai_completion_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_completion_client.py>`_.
|
||||||
|
|
||||||
|
Using OpenAI Chat API with vLLM
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
The vLLM server is designed to support the OpenAI Chat API, allowing you to engage in dynamic conversations with the model. The chat interface is a more interactive way to communicate with the model, allowing back-and-forth exchanges that can be stored in the chat history. This is useful for tasks that require context or more detailed explanations.
|
||||||
|
|
||||||
|
Querying the model using OpenAI Chat API:
|
||||||
|
|
||||||
|
You can use the `create chat completion <https://platform.openai.com/docs/api-reference/chat/completions/create>`_ endpoint to communicate with the model in a chat-like interface:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ curl http://localhost:8000/v1/chat/completions \
|
||||||
|
$ -H "Content-Type: application/json" \
|
||||||
|
$ -d '{
|
||||||
|
$ "model": "facebook/opt-125m",
|
||||||
|
$ "messages": [
|
||||||
|
$ {"role": "system", "content": "You are a helpful assistant."},
|
||||||
|
$ {"role": "user", "content": "Who won the world series in 2020?"}
|
||||||
|
$ ]
|
||||||
|
$ }'
|
||||||
|
|
||||||
|
Python Client Example:
|
||||||
|
|
||||||
|
Using the `openai` python package, you can also communicate with the model in a chat-like manner:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from openai import OpenAI
|
||||||
|
# Set OpenAI's API key and API base to use vLLM's API server.
|
||||||
|
openai_api_key = "EMPTY"
|
||||||
|
openai_api_base = "http://localhost:8000/v1"
|
||||||
|
|
||||||
|
client = OpenAI(
|
||||||
|
api_key=openai_api_key,
|
||||||
|
base_url=openai_api_base,
|
||||||
|
)
|
||||||
|
|
||||||
|
chat_response = client.chat.completions.create(
|
||||||
|
model="facebook/opt-125m",
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
|
{"role": "user", "content": "Tell me a joke."},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
print("Chat response:", chat_response)
|
||||||
|
|
||||||
|
For more in-depth examples and advanced features of the chat API, you can refer to the official OpenAI documentation.
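Printing the whole response object is verbose. In practice you usually want only the assistant's reply, which sits in the first entry of ``choices``. The sketch below extracts it from a response in the OpenAI chat completion JSON format; the sample payload is hand-written for illustration and is not real model output.

```python
import json

# Hand-written sample in the chat completion response shape (not real output).
sample_response = json.loads("""
{
  "object": "chat.completion",
  "model": "facebook/opt-125m",
  "choices": [
    {
      "index": 0,
      "message": {"role": "assistant", "content": "Example reply text."},
      "finish_reason": "stop"
    }
  ]
}
""")


def first_reply(response: dict) -> str:
    """Return the text of the first choice's assistant message."""
    return response["choices"][0]["message"]["content"]


print(first_reply(sample_response))
```

With the `openai` client shown earlier, the same field is available as ``chat_response.choices[0].message.content``.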
|
||||||
|
|||||||
@@ -30,6 +30,8 @@ vLLM is fast with:
|
|||||||
* State-of-the-art serving throughput
|
* State-of-the-art serving throughput
|
||||||
* Efficient management of attention key and value memory with **PagedAttention**
|
* Efficient management of attention key and value memory with **PagedAttention**
|
||||||
* Continuous batching of incoming requests
|
* Continuous batching of incoming requests
|
||||||
|
* Fast model execution with CUDA/HIP graph
|
||||||
|
* Quantization: `GPTQ <https://arxiv.org/abs/2210.17323>`_, `AWQ <https://arxiv.org/abs/2306.00978>`_, `SqueezeLLM <https://arxiv.org/abs/2306.07629>`_, FP8 KV Cache
|
||||||
* Optimized CUDA kernels
|
* Optimized CUDA kernels
|
||||||
|
|
||||||
vLLM is flexible and easy to use with:
|
vLLM is flexible and easy to use with:
|
||||||
@@ -39,6 +41,9 @@ vLLM is flexible and easy to use with:
|
|||||||
* Tensor parallelism support for distributed inference
|
* Tensor parallelism support for distributed inference
|
||||||
* Streaming outputs
|
* Streaming outputs
|
||||||
* OpenAI-compatible API server
|
* OpenAI-compatible API server
|
||||||
|
* Support NVIDIA GPUs and AMD GPUs
|
||||||
|
* (Experimental) Prefix caching support
|
||||||
|
* (Experimental) Multi-lora support
|
||||||
|
|
||||||
For more information, check out the following:
|
For more information, check out the following:
|
||||||
|
|
||||||
@@ -56,6 +61,7 @@ Documentation
|
|||||||
:caption: Getting Started
|
:caption: Getting Started
|
||||||
|
|
||||||
getting_started/installation
|
getting_started/installation
|
||||||
|
getting_started/amd-installation
|
||||||
getting_started/quickstart
|
getting_started/quickstart
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
@@ -64,7 +70,11 @@ Documentation
|
|||||||
|
|
||||||
serving/distributed_serving
|
serving/distributed_serving
|
||||||
serving/run_on_sky
|
serving/run_on_sky
|
||||||
|
serving/deploying_with_kserve
|
||||||
serving/deploying_with_triton
|
serving/deploying_with_triton
|
||||||
|
serving/deploying_with_docker
|
||||||
|
serving/serving_with_langchain
|
||||||
|
serving/metrics
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:maxdepth: 1
|
:maxdepth: 1
|
||||||
@@ -72,3 +82,24 @@ Documentation
|
|||||||
|
|
||||||
models/supported_models
|
models/supported_models
|
||||||
models/adding_model
|
models/adding_model
|
||||||
|
models/engine_args
|
||||||
|
models/lora
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:maxdepth: 1
|
||||||
|
:caption: Quantization
|
||||||
|
|
||||||
|
quantization/auto_awq
|
||||||
|
quantization/fp8_e5m2_kv_cache
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:maxdepth: 2
|
||||||
|
:caption: Developer Documentation
|
||||||
|
|
||||||
|
dev/engine/engine_index
|
||||||
|
|
||||||
|
Indices and tables
|
||||||
|
==================
|
||||||
|
|
||||||
|
* :ref:`genindex`
|
||||||
|
* :ref:`modindex`
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ This document provides a high-level guide on integrating a `HuggingFace Transfor
|
|||||||
0. Fork the vLLM repository
|
0. Fork the vLLM repository
|
||||||
--------------------------------
|
--------------------------------
|
||||||
|
|
||||||
Start by forking our `GitHub <https://github.com/vllm-project/vllm/>`_ repository and then :ref:`build it from source <build_from_source>`.
|
Start by forking our `GitHub`_ repository and then :ref:`build it from source <build_from_source>`.
|
||||||
This gives you the ability to modify the codebase and test your model.
|
This gives you the ability to modify the codebase and test your model.
|
||||||
|
|
||||||
|
|
||||||
@@ -26,7 +26,7 @@ This gives you the ability to modify the codebase and test your model.
|
|||||||
------------------------
|
------------------------
|
||||||
|
|
||||||
Clone the PyTorch model code from the HuggingFace Transformers repository and put it into the `vllm/model_executor/models <https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models>`_ directory.
|
Clone the PyTorch model code from the HuggingFace Transformers repository and put it into the `vllm/model_executor/models <https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models>`_ directory.
|
||||||
For instance, vLLM's `OPT model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/opt.py>`_ was adpated from the HuggingFace's `modeling_opt.py <https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py>`_ file.
|
For instance, vLLM's `OPT model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/opt.py>`_ was adapted from the HuggingFace's `modeling_opt.py <https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py>`_ file.
|
||||||
|
|
||||||
.. warning::
|
.. warning::
|
||||||
When copying the model code, make sure to review and adhere to the code's copyright and licensing terms.
|
When copying the model code, make sure to review and adhere to the code's copyright and licensing terms.
|
||||||
@@ -58,35 +58,37 @@ Next, you need to rewrite the :code:`forward` methods of your model by following
|
|||||||
+ positions: torch.Tensor,
|
+ positions: torch.Tensor,
|
||||||
+ kv_caches: List[KVCache],
|
+ kv_caches: List[KVCache],
|
||||||
+ input_metadata: InputMetadata,
|
+ input_metadata: InputMetadata,
|
||||||
+ cache_events: Optional[List[torch.cuda.Event]],
|
+) -> Optional[SamplerOutput]:
|
||||||
+) -> SamplerOutput:
|
|
||||||
|
|
||||||
3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
|
1. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
|
||||||
4. Replace the attention operation with either :code:`GPTPagedAttention` or :code:`GPTNeoXPagedAttention`, depending on the model's architecture.
|
2. Replace the attention operation with either :code:`PagedAttention`, :code:`PagedAttentionWithRoPE`, or :code:`PagedAttentionWithALiBi` depending on the model's architecture.
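
To illustrate what "flattened" means in the step above, here is a minimal sketch in plain Python. It assumes that the tokens of all sequences in a batch are concatenated into one 1-D sequence, with a parallel 1-D list of per-token positions. The helper name `flatten_batch` is hypothetical and is not part of vLLM.

```python
from typing import List, Tuple


def flatten_batch(prompts: List[List[int]]) -> Tuple[List[int], List[int]]:
    """Concatenate per-sequence token ids into one flat list.

    Hypothetical helper for illustration only: the batch dimension is
    dropped, and each token keeps its position within its own sequence.
    """
    input_ids: List[int] = []
    positions: List[int] = []
    for token_ids in prompts:
        input_ids.extend(token_ids)
        # Positions restart at 0 for every sequence.
        positions.extend(range(len(token_ids)))
    return input_ids, positions


batch = [[11, 12, 13], [21, 22]]
flat_ids, flat_positions = flatten_batch(batch)
```

A `forward` method written against this layout should not assume a `[batch_size, seq_len]` shape for `input_ids` or `positions`.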
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings.
|
Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings.
|
||||||
If your model employs a different attention mechanism, you will need to implement a new attention layer in vLLM.
|
If your model employs a different attention mechanism, you will need to implement a new attention layer in vLLM.
|
||||||
|
|
||||||
|
|
||||||
3. (Optional) Implement tensor parallelism support
|
3. (Optional) Implement tensor parallelism and quantization support
|
||||||
--------------------------------------------------
|
-------------------------------------------------------------------
|
||||||
|
|
||||||
If your model is too large to fit into a single GPU, you can use tensor parallelism to manage it.
|
If your model is too large to fit into a single GPU, you can use tensor parallelism to manage it.
|
||||||
To do this, substitute your model's linear and embedding layers with their tensor-parallel versions.
|
To do this, substitute your model's linear and embedding layers with their tensor-parallel versions.
|
||||||
For the embedding layer, you can simply replace :code:`nn.Embedding` with :code:`VocabParallelEmbedding`.
|
For the embedding layer, you can simply replace :code:`nn.Embedding` with :code:`VocabParallelEmbedding`. For the output LM head, you can use :code:`ParallelLMHead`.
|
||||||
When it comes to the linear layers, you should use either :code:`RowParallelLinear` or :code:`ColumnParallelLinear`.
|
When it comes to the linear layers, we provide the following options to parallelize them:
|
||||||
Typically, :code:`ColumnParallelLinear` is used for QKV linear layers and the first linear layers of the MLP blocks.
|
|
||||||
For the remaining linear layers, :code:`RowParallelLinear` is used.
|
|
||||||
|
|
||||||
|
* :code:`ReplicatedLinear`: Replicates the inputs and weights across multiple GPUs. No memory saving.
|
||||||
|
* :code:`RowParallelLinear`: The input tensor is partitioned along the hidden dimension. The weight matrix is partitioned along the rows (input dimension). An *all-reduce* operation is performed after the matrix multiplication to reduce the results. Typically used for the second FFN layer and the output linear transformation of the attention layer.
|
||||||
|
* :code:`ColumnParallelLinear`: The input tensor is replicated. The weight matrix is partitioned along the columns (output dimension). The result is partitioned along the column dimension. Typically used for the first FFN layer and the separated QKV transformation of the attention layer in the original Transformer.
|
||||||
|
* :code:`MergedColumnParallelLinear`: Column-parallel linear that merges multiple `ColumnParallelLinear` operators. Typically used for the first FFN layer with weighted activation functions (e.g., SiLU). This class handles the sharded weight loading logic of multiple weight matrices.
|
||||||
|
* :code:`QKVParallelLinear`: Parallel linear layer for the query, key, and value projections of the multi-head and grouped-query attention mechanisms. When number of key/value heads are less than the world size, this class replicates the key/value heads properly. This class handles the weight loading and replication of the weight matrices.
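
The partitioning described for :code:`ColumnParallelLinear` and :code:`RowParallelLinear` can be sketched with plain Python lists. This is an illustration of the arithmetic only, not vLLM's implementation: it assumes the convention `Y = X @ W` with `W` of shape `[input_dim, output_dim]`, simulates the "GPUs" as loop iterations, and uses hypothetical helper names.

```python
from typing import List

Matrix = List[List[int]]


def matmul(x: Matrix, w: Matrix) -> Matrix:
    """Plain matrix product of x ([n, k]) and w ([k, m])."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w)] for row in x]


def column_parallel(x: Matrix, w: Matrix, world_size: int) -> Matrix:
    """Input replicated, weight split along columns (output dimension)."""
    shard = len(w[0]) // world_size
    partials = [
        matmul(x, [row[r * shard:(r + 1) * shard] for row in w])
        for r in range(world_size)
    ]
    # Each rank holds a slice of the output columns; concatenate them.
    return [sum((p[i] for p in partials), []) for i in range(len(x))]


def row_parallel(x: Matrix, w: Matrix, world_size: int) -> Matrix:
    """Input split along the hidden dimension, weight split along rows."""
    shard = len(w) // world_size
    partials = [
        matmul([row[r * shard:(r + 1) * shard] for row in x],
               w[r * shard:(r + 1) * shard])
        for r in range(world_size)
    ]
    # The element-wise sum stands in for the all-reduce across ranks.
    return [
        [sum(p[i][j] for p in partials) for j in range(len(w[0]))]
        for i in range(len(x))
    ]


x = [[1, 2, 3, 4], [5, 6, 7, 8]]
w = [[1, 0], [0, 1], [2, 1], [1, 2]]
full = matmul(x, w)
col_out = column_parallel(x, w, world_size=2)
row_out = row_parallel(x, w, world_size=2)
```

Both variants reproduce the unpartitioned result. The difference is where communication happens: the column-parallel output stays partitioned, while the row-parallel output needs a reduction.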
|
||||||
|
|
||||||
|
Note that all the linear layers above take `linear_method` as an input. vLLM will set this parameter according to different quantization schemes to support weight quantization.
|
||||||
|
|
||||||
4. Implement the weight loading logic
|
4. Implement the weight loading logic
|
||||||
-------------------------------------
|
-------------------------------------
|
||||||
|
|
||||||
You now need to implement the :code:`load_weights` method in your :code:`*ForCausalLM` class.
|
You now need to implement the :code:`load_weights` method in your :code:`*ForCausalLM` class.
|
||||||
This method should load the weights from the HuggingFace's checkpoint file and assign them to the corresponding layers in your model.
|
This method should load the weights from the HuggingFace's checkpoint file and assign them to the corresponding layers in your model. Specifically, for `MergedColumnParallelLinear` and `QKVParallelLinear` layers, if the original model has separated weight matrices, you need to load the different parts separately.
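
As a rough sketch of loading the different parts separately, the snippet below copies separate query, key, and value matrices into one merged matrix at fixed row offsets. It is plain Python with a hypothetical helper name (`merge_qkv`), it assumes the PyTorch `[output_dim, input_dim]` weight layout, and it ignores tensor-parallel sharding. The real layers handle this inside their own weight loaders.

```python
from typing import Dict, List, Tuple

Matrix = List[List[float]]


def merge_qkv(q: Matrix, k: Matrix, v: Matrix) -> Tuple[Matrix, Dict[str, Tuple[int, int]]]:
    """Stack separate q/k/v weights into one merged weight (illustrative only)."""
    total_rows = len(q) + len(k) + len(v)
    merged: Matrix = [[] for _ in range(total_rows)]
    offsets: Dict[str, Tuple[int, int]] = {}
    offset = 0
    for name, weight in (("q", q), ("k", k), ("v", v)):
        end = offset + len(weight)
        # Copy this part into its own row range of the merged matrix.
        merged[offset:end] = [list(row) for row in weight]
        offsets[name] = (offset, end)
        offset = end
    return merged, offsets


# Grouped-query style shapes: fewer key/value rows than query rows.
q_w = [[1.0, 2.0], [3.0, 4.0]]
k_w = [[5.0, 6.0]]
v_w = [[7.0, 8.0]]
merged_w, shard_offsets = merge_qkv(q_w, k_w, v_w)
```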
|
||||||
While the process is straightforward for most layers, the tensor-parallel layers necessitate some additional care as their weights should be partitioned to multiple GPUs.
|
|
||||||
|
|
||||||
|
|
||||||
5. Register your model
|
5. Register your model
|
||||||
----------------------
|
----------------------
|
||||||
|
|||||||
116
docs/source/models/engine_args.rst
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
.. _engine_args:
|
||||||
|
|
||||||
|
Engine Arguments
|
||||||
|
================
|
||||||
|
|
||||||
|
Below, you can find an explanation of every engine argument for vLLM:
|
||||||
|
|
||||||
|
.. option:: --model <model_name_or_path>
|
||||||
|
|
||||||
|
Name or path of the huggingface model to use.
|
||||||
|
|
||||||
|
.. option:: --tokenizer <tokenizer_name_or_path>
|
||||||
|
|
||||||
|
Name or path of the huggingface tokenizer to use.
|
||||||
|
|
||||||
|
.. option:: --revision <revision>
|
||||||
|
|
||||||
|
The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.
|
||||||
|
|
||||||
|
.. option:: --tokenizer-revision <revision>
|
||||||
|
|
||||||
|
The specific tokenizer version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.
|
||||||
|
|
||||||
|
.. option:: --tokenizer-mode {auto,slow}
|
||||||
|
|
||||||
|
The tokenizer mode.
|
||||||
|
|
||||||
|
* "auto" will use the fast tokenizer if available.
|
||||||
|
* "slow" will always use the slow tokenizer.
|
||||||
|
|
||||||
|
.. option:: --trust-remote-code
|
||||||
|
|
||||||
|
Trust remote code from huggingface.
|
||||||
|
|
||||||
|
.. option:: --download-dir <directory>
|
||||||
|
|
||||||
|
Directory to download and load the weights, default to the default cache dir of huggingface.
|
||||||
|
|
||||||
|
.. option:: --load-format {auto,pt,safetensors,npcache,dummy}
|
||||||
|
|
||||||
|
The format of the model weights to load.
|
||||||
|
|
||||||
|
* "auto" will try to load the weights in the safetensors format and fall back to the pytorch bin format if safetensors format is not available.
|
||||||
|
* "pt" will load the weights in the pytorch bin format.
|
||||||
|
* "safetensors" will load the weights in the safetensors format.
|
||||||
|
* "npcache" will load the weights in pytorch format and store a numpy cache to speed up the loading.
|
||||||
|
* "dummy" will initialize the weights with random values, mainly for profiling.
|
||||||
|
|
||||||
|
.. option:: --dtype {auto,half,float16,bfloat16,float,float32}
|
||||||
|
|
||||||
|
Data type for model weights and activations.
|
||||||
|
|
||||||
|
* "auto" will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models.
|
||||||
|
* "half" for FP16. Recommended for AWQ quantization.
|
||||||
|
* "float16" is the same as "half".
|
||||||
|
* "bfloat16" for a balance between precision and range.
|
||||||
|
* "float" is shorthand for FP32 precision.
|
||||||
|
* "float32" for FP32 precision.
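
The dtype rules listed above can be summarized in a small sketch. The function `resolve_dtype` is hypothetical and only mirrors the documented behavior; it is not vLLM's actual code.

```python
def resolve_dtype(requested: str, checkpoint_dtype: str) -> str:
    """Map a --dtype value to a concrete precision name (illustrative only)."""
    # "half" and "float" are documented as aliases.
    aliases = {"half": "float16", "float": "float32"}
    requested = aliases.get(requested, requested)
    if requested == "auto":
        # FP32 and FP16 checkpoints run in FP16; BF16 checkpoints stay in BF16.
        return "bfloat16" if checkpoint_dtype == "bfloat16" else "float16"
    if requested not in {"float16", "bfloat16", "float32"}:
        raise ValueError(f"Unknown dtype: {requested}")
    return requested
```

For example, under these assumptions an FP32 checkpoint loaded with "auto" runs in FP16, and FP32 has to be requested explicitly with "float" or "float32".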
|
||||||
|
|
||||||
|
.. option:: --max-model-len <length>
|
||||||
|
|
||||||
|
Model context length. If unspecified, will be automatically derived from the model config.
|
||||||
|
|
||||||
|
.. option:: --worker-use-ray
|
||||||
|
|
||||||
|
Use Ray for distributed serving, will be automatically set when using more than 1 GPU.
|
||||||
|
|
||||||
|
.. option:: --pipeline-parallel-size (-pp) <size>
|
||||||
|
|
||||||
|
Number of pipeline stages.
|
||||||
|
|
||||||
|
.. option:: --tensor-parallel-size (-tp) <size>
|
||||||
|
|
||||||
|
Number of tensor parallel replicas.
|
||||||
|
|
||||||
|
.. option:: --max-parallel-loading-workers <workers>
|
||||||
|
|
||||||
|
Load model sequentially in multiple batches, to avoid RAM OOM when using tensor parallel and large models.
|
||||||
|
|
||||||
|
.. option:: --block-size {8,16,32}
|
||||||
|
|
||||||
|
Token block size for contiguous chunks of tokens.
|
||||||
|
|
||||||
|
.. option:: --seed <seed>
|
||||||
|
|
||||||
|
Random seed for operations.
|
||||||
|
|
||||||
|
.. option:: --swap-space <size>
|
||||||
|
|
||||||
|
CPU swap space size (GiB) per GPU.
|
||||||
|
|
||||||
|
.. option:: --gpu-memory-utilization <fraction>
|
||||||
|
|
||||||
|
The fraction of GPU memory to be used for the model executor, which can range from 0 to 1.
|
||||||
|
For example, a value of 0.5 would imply 50% GPU memory utilization.
|
||||||
|
If unspecified, will use the default value of 0.9.
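
As a worked example of the fraction, the sketch below computes the memory budget implied by a given value. The 80 GiB device size and the helper name `memory_budget_gib` are assumptions for illustration.

```python
def memory_budget_gib(total_gib: float, gpu_memory_utilization: float = 0.9) -> float:
    """Return the GiB available to the model executor (illustrative only)."""
    if not 0.0 < gpu_memory_utilization <= 1.0:
        raise ValueError("gpu_memory_utilization must be in (0, 1]")
    return total_gib * gpu_memory_utilization


# Hypothetical 80 GiB GPU.
default_budget = memory_budget_gib(80.0)    # default fraction of 0.9
half_budget = memory_budget_gib(80.0, 0.5)  # 50% utilization
```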
|
||||||
|
|
||||||
|
.. option:: --max-num-batched-tokens <tokens>
|
||||||
|
|
||||||
|
Maximum number of batched tokens per iteration.
|
||||||
|
|
||||||
|
.. option:: --max-num-seqs <sequences>
|
||||||
|
|
||||||
|
Maximum number of sequences per iteration.
|
||||||
|
|
||||||
|
.. option:: --max-paddings <paddings>
|
||||||
|
|
||||||
|
Maximum number of paddings in a batch.
|
||||||
|
|
||||||
|
.. option:: --disable-log-stats
|
||||||
|
|
||||||
|
Disable logging statistics.
|
||||||
|
|
||||||
|
.. option:: --quantization (-q) {awq,squeezellm,None}
|
||||||
|
|
||||||
|
Method used to quantize the weights.
|
||||||