Compare commits

279 Commits

| SHA1 |
|---|
| 8f89d72090 |
| 99dac099ab |
| c4bd03c7c5 |
| dcbf4286af |
| 00e6a2dc53 |
| 2e02311a1b |
| 89ec06c33b |
| 9fde251bf0 |
| 4c2ffb28ff |
| 246598a6b1 |
| 8bab4959be |
| 3c4cebf751 |
| d8f31f2f8b |
| 640052b069 |
| 351d5e7b82 |
| a008629807 |
| 76477a93b7 |
| 77c87beb06 |
| 114332b88e |
| cb77ad836f |
| 856c990041 |
| c5602f0baa |
| f7f9c5f97b |
| 2c0d933594 |
| 774d1035e4 |
| 6b29d6fe70 |
| 0bfa1c4f13 |
| c81da5f56d |
| 68bc81703e |
| 5884c2b454 |
| 45f92c00cf |
| 5467ac3196 |
| 5d7e3d0176 |
| 0373e1837e |
| c09dade2a2 |
| 8ea5e44a43 |
| 9fb900f90c |
| c96fc06747 |
| b3376e5c76 |
| e69ded7d1c |
| 767c727a81 |
| 6840a71610 |
| 7a9cb294ae |
| ca3ea51bde |
| dc49fb892c |
| 18a277b52d |
| 8d75fe48ca |
| 388596c914 |
| baa15a9ec3 |
| 15063741e3 |
| ccdc490dda |
| a31cab7556 |
| 828da0d44e |
| abe855d637 |
| 4efff036f0 |
| 89c920785f |
| 7b0a0dfb22 |
| 3a6ae1d33c |
| 8f1729b829 |
| 6a7c7711a2 |
| 0f83ddd4d7 |
| 065aff6c16 |
| 3d33e372a1 |
| faf71bcd4b |
| f270a39537 |
| 51a08e7d8f |
| eb8fcd2666 |
| 5563a4dea8 |
| ccd4f129e8 |
| 02cc3b51a7 |
| d5b1eb081e |
| f0a500545f |
| c65146e75e |
| 41ca62cf03 |
| 974fc9b845 |
| fee4dcc33a |
| 650a4cc55e |
| 9ca62d8668 |
| 45c35f0d58 |
| 9ba093b4f4 |
| 27208be66e |
| 87d5abef75 |
| ec784b2526 |
| a58f24e590 |
| f42a006b15 |
| 3a434b07ed |
| bd0e7802e0 |
| 06b2550cbb |
| f775a07e30 |
| 4f0d17c05c |
| 10c38e3e46 |
| cafb8e06c5 |
| cbb2f59cc8 |
| 0ab278ca31 |
| 7a64d24aad |
| dfbe60dc62 |
| a66cf40b20 |
| f790ad3c50 |
| ed59a7ed23 |
| 044793d8df |
| c2d6d2f960 |
| 8279078e21 |
| b9c0605a8e |
| 37464a0f74 |
| c354072828 |
| f081c3ce4b |
| 260d119e86 |
| a360ff80bb |
| 1197e02141 |
| 657579113f |
| e9899fb7a4 |
| a377f0bd5e |
| e9d3aa04f6 |
| a22dea54d3 |
| 533c217792 |
| 6d21fa1cad |
| b35be5403f |
| 45a1a69b98 |
| 87a658c812 |
| 429d89720e |
| a9bcc7afb2 |
| d79d9eaaff |
| f758505c73 |
| d910816c73 |
| 87d41c849d |
| e07aff9e52 |
| 5bf185a1c4 |
| 4fbcb0f27e |
| 7c3604fb68 |
| b1c255630d |
| eb6c50cdc2 |
| eecd864388 |
| ae495c74ea |
| 4238bc82f2 |
| 594392d27a |
| 18c1f16d86 |
| 5bd3c65072 |
| 616e600e0b |
| dfba529b40 |
| 5ae5ed1e60 |
| 290f4ada2b |
| dd8de11f0a |
| 9ba415588a |
| d4f3985907 |
| 890aa93d27 |
| fbdb7b3ee2 |
| 1102bef219 |
| f17a1a8f96 |
| d5a1697772 |
| 325c119961 |
| 8e192ff967 |
| e64fde4b01 |
| 919770957f |
| 6a50f4cafa |
| e3470f8753 |
| a1242324c9 |
| 5eda2ea02a |
| 2ba80bed27 |
| 6066253296 |
| ee3eea0a1b |
| a36de682d4 |
| eb6d3c264d |
| 97b030005c |
| a3a73ab069 |
| 8674f9880e |
| c74c913bfb |
| 5f6d10c14c |
| 9b9a10d6cb |
| 99eff67ba9 |
| 14772eeb8e |
| 757b62c495 |
| e941f88584 |
| f12c3b5b3d |
| d130b573a0 |
| 65ae8c2c8f |
| c3af44722c |
| 1937e29848 |
| f0eecee610 |
| 943e72ca56 |
| 546a97ef69 |
| da5a0b539d |
| 6287537a0c |
| b57e6c5949 |
| 27ce85476e |
| f68470e803 |
| 2e9a2227ec |
| c0724fc915 |
| 86b45ae065 |
| c5711ef985 |
| 48d5985a08 |
| 33e0823de5 |
| 26148120b3 |
| 0150a10630 |
| 8e7fb5d43a |
| 9a31a817a8 |
| 2060e93659 |
| 8435b207af |
| 10fa9eea21 |
| e08188081b |
| b5853f9963 |
| f09edd8a25 |
| 6979ade384 |
| 9216b9cc38 |
| 5e0391c040 |
| dbc0754ddf |
| 99caa49106 |
| 5c342570d7 |
| 973617ae02 |
| 30e754390c |
| 52f8107cf2 |
| fc0d9dfc3a |
| 361c461a12 |
| a5675d348b |
| e9cdd2b1e2 |
| 65bf2ac165 |
| 8a7cc254a0 |
| 29bc01bf3b |
| 676a99982f |
| dc72402b57 |
| ccb63a8245 |
| c579b750a0 |
| 4bfa7e7f75 |
| ac1fbf7fd2 |
| 33d3914b1e |
| 1356df53bd |
| ce532ff45c |
| 8bc68e198c |
| 0fca3cdcf2 |
| e7c46b9527 |
| 350f9e107f |
| 702bee461f |
| a7be4d0072 |
| a709e87a4f |
| 6eaccb7353 |
| e254497b66 |
| 4e12131089 |
| fcc2994be6 |
| 2e7796f2cf |
| 706588a77d |
| 6a0f617210 |
| dac6a3f6ed |
| 64b77dfd7e |
| 51d4094fda |
| e965d46184 |
| 208b71bcc1 |
| c833101740 |
| 379da6dcb5 |
| ebce310b74 |
| be0c5180ac |
| cea64430f6 |
| a3c124570a |
| ff5abcd746 |
| 0ee535b294 |
| 190bc838e1 |
| f12b20decc |
| 16bc0a098f |
| e288df0632 |
| 8b9241be3a |
| f942efb5a3 |
| 89579a201f |
| 230c4b38c1 |
| 20cfcdec99 |
| ad932a221d |
| 5510cf0e8a |
| 0f9a6e3d22 |
| f6a593093a |
| d7740ea4dc |
| cc466a3290 |
| 8344f7742b |
| 469f85c782 |
| 10760da800 |
| 478aed5827 |
| 63575bc2e1 |
| a98187cf72 |
| bd99d22629 |
| 19cb4716ee |
| e186d37cb1 |
| 323f27b904 |
| 0650e5935b |
````diff
@@ -1,7 +1,7 @@
 import os
 import zipfile
 
-MAX_SIZE_MB = 100
+MAX_SIZE_MB = 200
 
 
 def print_top_10_largest_files(zip_file):
````
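The hunk above only raises the wheel-size limit from 100 MB to 200 MB. For orientation, here is a minimal sketch of what a check of this shape could look like; the constant and the `print_top_10_largest_files` name come from the diff, while everything else (the walk over a `dist/` directory, the exit behavior) is an assumption for illustration:

```python
import os
import sys
import zipfile

MAX_SIZE_MB = 200


def print_top_10_largest_files(zip_file):
    """Print the ten largest members of a wheel (a wheel is a zip archive)."""
    with zipfile.ZipFile(zip_file) as z:
        entries = [(info.file_size, info.filename) for info in z.infolist()]
    for size, name in sorted(entries, reverse=True)[:10]:
        print(f"{size / (1024 * 1024):8.2f} MB  {name}")


def check_wheel_sizes(directory):
    """Exit non-zero if any .whl under `directory` exceeds MAX_SIZE_MB."""
    for root, _, files in os.walk(directory):
        for name in files:
            if not name.endswith(".whl"):
                continue
            path = os.path.join(root, name)
            size_mb = os.path.getsize(path) / (1024 * 1024)
            if size_mb > MAX_SIZE_MB:
                print(f"{path} is {size_mb:.2f} MB, over the "
                      f"{MAX_SIZE_MB} MB limit; largest members:")
                print_top_10_largest_files(path)
                sys.exit(1)
            print(f"{path} is {size_mb:.2f} MB (within limit).")


if __name__ == "__main__":
    check_wheel_sizes(sys.argv[1] if len(sys.argv) > 1 else "dist")
```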
.buildkite/nightly-benchmarks/kickoff-pipeline.sh (new executable file, 26 lines)

````diff
@@ -0,0 +1,26 @@
+#!/usr/bin/env bash
+
+set -euo pipefail
+
+# Install system packages
+apt update
+apt install -y curl jq
+
+# Install minijinja for templating
+curl -sSfL https://github.com/mitsuhiko/minijinja/releases/latest/download/minijinja-cli-installer.sh | sh
+source $HOME/.cargo/env
+
+# If BUILDKITE_PULL_REQUEST != "false", then we check the PR labels using curl and jq
+if [ "$BUILDKITE_PULL_REQUEST" != "false" ]; then
+  PR_LABELS=$(curl -s "https://api.github.com/repos/vllm-project/vllm/pulls/$BUILDKITE_PULL_REQUEST" | jq -r '.labels[].name')
+
+  if [[ $PR_LABELS == *"perf-benchmarks"* ]]; then
+    echo "This PR has the 'perf-benchmarks' label. Proceeding with the nightly benchmarks."
+  else
+    echo "This PR does not have the 'perf-benchmarks' label. Skipping the nightly benchmarks."
+    exit 0
+  fi
+fi
+
+# Upload sample.yaml
+buildkite-agent pipeline upload .buildkite/nightly-benchmarks/sample.yaml
````
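The gating logic above is plain curl plus jq. The same check, sketched in Python for readers who want to reuse it outside CI; the endpoint and label name are taken from the script, while the `requests` dependency and the function name are assumptions for illustration:

```python
import os
import sys

import requests  # assumed dependency; the CI script itself uses curl + jq


def has_perf_benchmarks_label(pr_number: str) -> bool:
    """Return True if the vLLM PR carries the 'perf-benchmarks' label."""
    url = f"https://api.github.com/repos/vllm-project/vllm/pulls/{pr_number}"
    labels = [label["name"]
              for label in requests.get(url, timeout=30).json()["labels"]]
    return "perf-benchmarks" in labels


if __name__ == "__main__":
    pr = os.environ.get("BUILDKITE_PULL_REQUEST", "false")
    if pr != "false" and not has_perf_benchmarks_label(pr):
        print("No 'perf-benchmarks' label; skipping nightly benchmarks.")
        sys.exit(0)
```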
.buildkite/nightly-benchmarks/sample.yaml (new file, 39 lines)

````diff
@@ -0,0 +1,39 @@
+steps:
+  # NOTE(simon): You can create separate blocks for different jobs
+  - label: "A100: NVIDIA SMI"
+    agents:
+      queue: A100
+    plugins:
+      - kubernetes:
+          podSpec:
+            containers:
+              # - image: us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:$BUILDKITE_COMMIT
+              # TODO(simon): check latest main branch or use the PR image.
+              - image: us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:45c35f0d58f4508bf43bd6af1d3d0d0ec0c915e6
+                command:
+                  - bash -c 'nvidia-smi && nvidia-smi topo -m && pwd && ls'
+                resources:
+                  limits:
+                    nvidia.com/gpu: 8
+                volumeMounts:
+                  - name: devshm
+                    mountPath: /dev/shm
+            nodeSelector:
+              nvidia.com/gpu.product: NVIDIA-A100-SXM4-80GB
+            volumes:
+              - name: devshm
+                emptyDir:
+                  medium: Memory
+  # TODO(simon): bring H100 online
+  # - label: "H100: NVIDIA SMI"
+  #   agents:
+  #     queue: H100
+  #   plugins:
+  #     - docker#v5.11.0:
+  #         image: us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:45c35f0d58f4508bf43bd6af1d3d0d0ec0c915e6
+  #         command:
+  #           - bash -c 'nvidia-smi && nvidia-smi topo -m'
+  #         propagate-environment: true
+  #         ipc: host
+  #         gpus: all
+
````
````diff
@@ -1,10 +1,38 @@
-# This script build the ROCm docker image and runs test inside it.
+# This script runs test inside the corresponding ROCm docker container.
 set -ex
 
 # Print ROCm version
 echo "--- ROCm info"
 rocminfo
 
+# cleanup older docker images
+cleanup_docker() {
+  # Get Docker's root directory
+  docker_root=$(docker info -f '{{.DockerRootDir}}')
+  if [ -z "$docker_root" ]; then
+    echo "Failed to determine Docker root directory."
+    exit 1
+  fi
+  echo "Docker root directory: $docker_root"
+  # Check disk usage of the filesystem where Docker's root directory is located
+  disk_usage=$(df "$docker_root" | tail -1 | awk '{print $5}' | sed 's/%//')
+  # Define the threshold
+  threshold=70
+  if [ "$disk_usage" -gt "$threshold" ]; then
+    echo "Disk usage is above $threshold%. Cleaning up Docker images and volumes..."
+    # Remove dangling images (those that are not tagged and not used by any container)
+    docker image prune -f
+    # Remove unused volumes
+    docker volume prune -f
+    echo "Docker images and volumes cleanup completed."
+  else
+    echo "Disk usage is below $threshold%. No cleanup needed."
+  fi
+}
+
+# Call the cleanup docker function
+cleanup_docker
+
 echo "--- Resetting GPUs"
 
 echo "reset" > /opt/amdgpu/etc/gpu_state
@@ -19,15 +47,16 @@ done
 
 echo "--- Building container"
 sha=$(git rev-parse --short HEAD)
-container_name=rocm_${sha}
+image_name=rocm_${sha}
+container_name=rocm_${sha}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)
 docker build \
-       -t ${container_name} \
+       -t ${image_name} \
        -f Dockerfile.rocm \
        --progress plain \
        .
 
 remove_docker_container() {
-   docker rm -f ${container_name} || docker image rm -f ${container_name} || true
+   docker rm -f ${container_name} || docker image rm -f ${image_name} || true
 }
 trap remove_docker_container EXIT
 
@@ -39,6 +68,6 @@ docker run \
        --rm \
        -e HF_TOKEN \
        --name ${container_name} \
-       ${container_name} \
-       /bin/bash -c $(echo $1 | sed "s/^'//" | sed "s/'$//")
+       ${image_name} \
+       /bin/bash -c "${@}"
````
````diff
@@ -9,10 +9,10 @@ cd "$(dirname "${BASH_SOURCE[0]}")/.."
 (which wget && which curl) || (apt-get update && apt-get install -y wget curl)
 
 # run python-based benchmarks and upload the result to buildkite
-python3 benchmarks/benchmark_latency.py 2>&1 | tee benchmark_latency.txt
+python3 benchmarks/benchmark_latency.py --output-json latency_results.json 2>&1 | tee benchmark_latency.txt
 bench_latency_exit_code=$?
 
-python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 2>&1 | tee benchmark_throughput.txt
+python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 --output-json throughput_results.json 2>&1 | tee benchmark_throughput.txt
 bench_throughput_exit_code=$?
 
 # run server-based benchmarks and upload the result to buildkite
@@ -50,16 +50,16 @@ echo "### Serving Benchmarks" >> benchmark_results.md
 sed -n '1p' benchmark_serving.txt >> benchmark_results.md # first line
 echo "" >> benchmark_results.md
 echo '```' >> benchmark_results.md
-tail -n 20 benchmark_serving.txt >> benchmark_results.md # last 20 lines
+tail -n 24 benchmark_serving.txt >> benchmark_results.md # last 24 lines
 echo '```' >> benchmark_results.md
 
 # if the agent binary is not found, skip uploading the results, exit 0
-if [ ! -f /workspace/buildkite-agent ]; then
+if [ ! -f /usr/bin/buildkite-agent ]; then
   exit 0
 fi
 
 # upload the results to buildkite
-/workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md
+buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md
 
 # exit with the exit code of the benchmarks
 if [ $bench_latency_exit_code -ne 0 ]; then
@@ -74,4 +74,5 @@ if [ $bench_serving_exit_code -ne 0 ]; then
   exit $bench_serving_exit_code
 fi
 
-/workspace/buildkite-agent artifact upload openai-*.json
+rm ShareGPT_V3_unfiltered_cleaned_split.json
+buildkite-agent artifact upload "*.json"
````
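The new `--output-json` flags mean both benchmark scripts can now persist machine-readable results alongside the text logs, which is what the widened `buildkite-agent artifact upload "*.json"` picks up. A minimal sketch of that pattern (the result fields here are illustrative, not the scripts' actual schema):

```python
import argparse
import json


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--output-json", type=str, default=None,
                        help="Path to write results as JSON.")
    args = parser.parse_args()

    # ... run the benchmark; illustrative numbers only ...
    results = {"avg_latency_s": 1.23, "num_iters": 30}

    if args.output_json:
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)


if __name__ == "__main__":
    main()
```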
````diff
@@ -10,5 +10,15 @@ remove_docker_container() { docker rm -f cpu-test || true; }
 trap remove_docker_container EXIT
 remove_docker_container
 
-# Run the image and launch offline inference
-docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-test cpu-test python3 examples/offline_inference.py
+# Run the image
+docker run -itd -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 --cpuset-mems=1 --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test cpu-test
+
+# offline inference
+docker exec cpu-test bash -c "python3 examples/offline_inference.py"
+
+# Run basic model test
+docker exec cpu-test bash -c "cd tests;
+  pip install pytest Pillow protobuf
+  bash ../.buildkite/download-images.sh
+  cd ../
+  pytest -v -s tests/models --ignore=tests/models/test_llava.py --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py"
````
````diff
@@ -5,13 +5,16 @@
 
 steps:
 - label: Regression Test
+  mirror_hardwares: [amd]
   command: pytest -v -s test_regression.py
   working_dir: "/vllm-workspace/tests" # optional
 
 - label: AsyncEngine Test
+  #mirror_hardwares: [amd]
   command: pytest -v -s async_engine
 
 - label: Basic Correctness Test
+  mirror_hardwares: [amd]
   commands:
   - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py
   - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
@@ -24,65 +27,81 @@ steps:
   command: pytest -v -s core
 
 - label: Distributed Comm Ops Test
-  command: pytest -v -s test_comm_ops.py
-  working_dir: "/vllm-workspace/tests/distributed"
+  #mirror_hardwares: [amd]
+  command: pytest -v -s distributed/test_comm_ops.py
+  working_dir: "/vllm-workspace/tests"
   num_gpus: 2
 
 - label: Distributed Tests
-  working_dir: "/vllm-workspace/tests/distributed"
-  num_gpus: 2 # only support 1 or 2 for now.
   mirror_hardwares: [amd]
+  working_dir: "/vllm-workspace/tests"
+  num_gpus: 2
   commands:
-  - pytest -v -s test_pynccl_library.py
-  - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py
-  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py
-  - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py
-  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py
+  - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
+  - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
+  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
+  - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
+  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
+  - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
+  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
+  - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
+  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
+  - pytest -v -s spec_decode/e2e/test_integration_dist.py
+  - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
 
 - label: Distributed Tests (Multiple Groups)
-  working_dir: "/vllm-workspace/tests/distributed"
+  #mirror_hardwares: [amd]
+  working_dir: "/vllm-workspace/tests"
   num_gpus: 4
   commands:
-  - pytest -v -s test_pynccl.py
+  - pytest -v -s distributed/test_pynccl.py
 
 - label: Engine Test
   mirror_hardwares: [amd]
   command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py
 
 - label: Entrypoints Test
+  mirror_hardwares: [amd]
+
   commands:
-  # these tests have to be separated, because each one will allocate all posible GPU memory
-  - pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py
-  - pytest -v -s entrypoints/test_server_oot_registration.py
+  - pytest -v -s entrypoints -m llm
+  - pytest -v -s entrypoints -m openai
 
 - label: Examples Test
   working_dir: "/vllm-workspace/examples"
   mirror_hardwares: [amd]
   commands:
   # install aws cli for llava_example.py
-  - pip install awscli
+  # install tensorizer for tensorize_vllm_model.py
+  - pip install awscli tensorizer
   - python3 offline_inference.py
   - python3 offline_inference_with_prefix.py
   - python3 llm_engine_example.py
   - python3 llava_example.py
+  - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
 
+- label: Inputs Test
+  #mirror_hardwares: [amd]
+  commands:
+  - bash ../.buildkite/download-images.sh
+  - pytest -v -s test_inputs.py
+  - pytest -v -s multimodal
+
 - label: Kernels Test %N
+  #mirror_hardwares: [amd]
   command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
   parallelism: 4
 
 - label: Models Test
-  mirror_hardwares: [amd]
+  #mirror_hardwares: [amd]
   commands:
-  - bash ../.buildkite/download-images.sh
-  - pytest -v -s models --ignore=models/test_llava.py --ignore=models/test_mistral.py
+  - pytest -v -s models -m \"not llava\"
 
 - label: Llava Test
   mirror_hardwares: [amd]
   commands:
   - bash ../.buildkite/download-images.sh
-  - pytest -v -s models/test_llava.py
+  - pytest -v -s models -m llava
 
 - label: Prefix Caching Test
   mirror_hardwares: [amd]
@@ -90,31 +109,49 @@ steps:
   - pytest -v -s prefix_caching
 
 - label: Samplers Test
+  #mirror_hardwares: [amd]
  command: pytest -v -s samplers
 
 - label: LogitsProcessor Test
   mirror_hardwares: [amd]
   command: pytest -v -s test_logits_processor.py
 
+- label: Utils Test
+  command: pytest -v -s test_utils.py
+
 - label: Worker Test
   mirror_hardwares: [amd]
   command: pytest -v -s worker
 
 - label: Speculative decoding tests
-  mirror_hardwares: [amd]
-  command: pytest -v -s spec_decode
+  #mirror_hardwares: [amd]
+  commands:
+  # See https://github.com/vllm-project/vllm/issues/5152
+  - export VLLM_ATTENTION_BACKEND=XFORMERS
+  - pytest -v -s spec_decode
 
 - label: LoRA Test %N
-  command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
+  #mirror_hardwares: [amd]
+  command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
   parallelism: 4
 
+- label: LoRA Long Context (Distributed)
+  #mirror_hardwares: [amd]
+  num_gpus: 4
+  # This test runs llama 13B, so it is required to run on 4 GPUs.
+  commands:
+  - pytest -v -s -x lora/test_long_context.py
+
 - label: Tensorizer Test
+  #mirror_hardwares: [amd]
   command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader
 
 - label: Metrics Test
+  mirror_hardwares: [amd]
   command: pytest -v -s metrics
 
 - label: Quantization Test
+  #mirror_hardwares: [amd]
   command: pytest -v -s quantization
 
 - label: Benchmarks
````
.buildkite/test-template-aws.j2 (new file, 64 lines)

````diff
@@ -0,0 +1,64 @@
+{% set docker_image = "public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT" %}
+{% set default_working_dir = "/vllm-workspace/tests" %}
+
+steps:
+  - label: ":docker: build image"
+    agents:
+      queue: cpu_queue
+    commands:
+      - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
+      - "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."
+      - "docker push {{ docker_image }}"
+    env:
+      DOCKER_BUILDKIT: "1"
+    retry:
+      automatic:
+        - exit_status: -1  # Agent was lost
+          limit: 5
+        - exit_status: -10  # Agent was lost
+          limit: 5
+  - wait
+
+  {% for step in steps %}
+  - label: "{{ step.label }}"
+    agents:
+      {% if step.label == "Documentation Build" %}
+      queue: small_cpu_queue
+      {% elif step.no_gpu %}
+      queue: cpu_queue
+      {% elif step.num_gpus == 2 or step.num_gpus == 4 %}
+      queue: gpu_4_queue
+      {% else %}
+      queue: gpu_1_queue
+      {% endif %}
+    soft_fail: true
+    {% if step.parallelism %}
+    parallelism: {{ step.parallelism }}
+    {% endif %}
+    retry:
+      automatic:
+        - exit_status: -1  # Agent was lost
+          limit: 5
+        - exit_status: -10  # Agent was lost
+          limit: 5
+    plugins:
+      - docker#v5.2.0:
+          image: {{ docker_image }}
+          always-pull: true
+          propagate-environment: true
+          {% if not step.no_gpu %}
+          gpus: all
+          {% endif %}
+          {% if step.label == "Benchmarks" %}
+          mount-buildkite-agent: true
+          {% endif %}
+          command: ["bash", "-c", "cd {{ (step.working_dir or default_working_dir) | safe }} && {{ step.command or (step.commands | join(' && ')) | safe }}"]
+          environment:
+            - VLLM_USAGE_SOURCE=ci-test
+            - HF_TOKEN
+            {% if step.label == "Speculative decoding tests" %}
+            - VLLM_ATTENTION_BACKEND=XFORMERS
+            {% endif %}
+          volumes:
+            - /dev/shm:/dev/shm
+  {% endfor %}
````
````diff
@@ -3,7 +3,6 @@
 {% set default_working_dir = "/vllm-workspace/tests" %}
 
 steps:
-
 - label: ":docker: build image"
   commands:
     - "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."
@@ -14,6 +13,8 @@ steps:
     automatic:
       - exit_status: -1  # Agent was lost
        limit: 5
+      - exit_status: -10  # Agent was lost
+        limit: 5
 - wait
 
 - group: "AMD Tests"
@@ -24,9 +25,10 @@ steps:
   - label: "AMD: {{ step.label }}"
     agents:
       queue: amd
-    command: bash .buildkite/run-amd-test.sh "'cd {{ (step.working_dir or default_working_dir) | safe }} && {{ step.command or (step.commands | join(' && ')) | safe }}'"
+    command: bash .buildkite/run-amd-test.sh "cd {{ (step.working_dir or default_working_dir) | safe }} ; {{ step.command or (step.commands | join(" ; ")) | safe }}"
     env:
       DOCKER_BUILDKIT: "1"
+    soft_fail: true
   {% endif %}
   {% endfor %}
 
@@ -35,10 +37,12 @@ steps:
   agents:
     queue: neuron
   command: bash .buildkite/run-neuron-test.sh
-  soft_fail: true
+  soft_fail: false
 
 - label: "Intel Test"
   depends_on: ~
+  agents:
+    queue: intel
   command: bash .buildkite/run-cpu-test.sh
 
 {% for step in steps %}
@@ -53,6 +57,8 @@ steps:
     automatic:
      - exit_status: -1  # Agent was lost
         limit: 5
+      - exit_status: -10  # Agent was lost
+        limit: 5
   plugins:
   - kubernetes:
       podSpec:
````
.clang-format (new file, 26 lines)

````diff
@@ -0,0 +1,26 @@
+BasedOnStyle: Google
+UseTab: Never
+IndentWidth: 2
+ColumnLimit: 80
+
+# Force pointers to the type for C++.
+DerivePointerAlignment: false
+PointerAlignment: Left
+
+# Reordering #include statements can (and currently will) introduce errors
+SortIncludes: false
+
+# Style choices
+AlignConsecutiveAssignments: false
+AlignConsecutiveDeclarations: false
+IndentPPDirectives: BeforeHash
+
+IncludeCategories:
+  - Regex: '^<'
+    Priority: 4
+  - Regex: '^"(llvm|llvm-c|clang|clang-c|mlir|mlir-c)/'
+    Priority: 3
+  - Regex: '^"(qoda|\.\.)/'
+    Priority: 2
+  - Regex: '.*'
+    Priority: 1
````
.github/ISSUE_TEMPLATE/400-bug report.yml (vendored, 2 lines changed)

````diff
@@ -59,6 +59,8 @@ body:
 
       Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````.
 
+      Please set the environment variable `export VLLM_LOGGING_LEVEL=DEBUG` to turn on more logging to help debugging potential issues.
+
       If you experienced crashes or hangs, it would be helpful to run vllm with `export VLLM_TRACE_FUNCTION=1` . All the function calls in vllm will be recorded. Inspect these log files, and tell which function crashes or hangs.
     placeholder: |
       A clear and concise description of what the bug is.
````
.github/workflows/clang-format.yml (vendored, new file, 42 lines)

````diff
@@ -0,0 +1,42 @@
+name: clang-format
+
+on:
+  # Trigger the workflow on push or pull request,
+  # but only for the main branch
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+
+jobs:
+  clang-format:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.11"]
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v2
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install clang-format==18.1.5
+    - name: Running clang-format
+      run: |
+        EXCLUDES=(
+            'csrc/moe/topk_softmax_kernels.cu'
+            'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu'
+            'csrc/punica/bgmv/bgmv_config.h'
+            'csrc/punica/bgmv/bgmv_impl.cuh'
+            'csrc/punica/bgmv/vec_dtypes.cuh'
+            'csrc/punica/punica_ops.cu'
+            'csrc/punica/type_convert.h'
+        )
+        find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
+            | grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \
+            | xargs clang-format --dry-run --Werror
````
.github/workflows/mypy.yaml (vendored, 1 line changed)

````diff
@@ -37,6 +37,7 @@ jobs:
         mypy vllm/distributed --config-file pyproject.toml
         mypy vllm/entrypoints --config-file pyproject.toml
         mypy vllm/executor --config-file pyproject.toml
+        mypy vllm/multimodal --config-file pyproject.toml
         mypy vllm/usage --config-file pyproject.toml
         mypy vllm/*.py --config-file pyproject.toml
         mypy vllm/transformers_utils --config-file pyproject.toml
````
.github/workflows/publish.yml (vendored, 3 lines changed)

````diff
@@ -58,6 +58,9 @@ jobs:
 
     - name: Setup ccache
       uses: hendrikmuhs/ccache-action@v1.2
+      with:
+        create-symlink: true
+        key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }}
 
     - name: Set up Linux Env
       if: ${{ runner.os == 'Linux' }}
````
````diff
@@ -66,19 +66,6 @@ endif()
 #
 find_package(Torch REQUIRED)
 
-#
-# Normally `torch.utils.cpp_extension.CUDAExtension` would add
-# `libtorch_python.so` for linking against an extension. Torch's cmake
-# configuration does not include this library (presumably since the cmake
-# config is used for standalone C++ binaries that link against torch).
-# The `libtorch_python.so` library defines some of the glue code between
-# torch/python via pybind and is required by VLLM extensions for this
-# reason. So, add it by manually with `find_library` using torch's
-# installed library path.
-#
-find_library(torch_python_LIBRARY torch_python PATHS
-  "${TORCH_INSTALL_PREFIX}/lib")
-
 #
 # Forward the non-CUDA device extensions to external CMake scripts.
 #
@@ -167,19 +154,47 @@ set(VLLM_EXT_SRC
   "csrc/layernorm_kernels.cu"
   "csrc/quantization/squeezellm/quant_cuda_kernel.cu"
   "csrc/quantization/gptq/q_gemm.cu"
-  "csrc/quantization/fp8/fp8_cuda_kernels.cu"
+  "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
+  "csrc/quantization/fp8/common.cu"
   "csrc/cuda_utils_kernels.cu"
   "csrc/moe_align_block_size_kernels.cu"
-  "csrc/pybind.cpp")
+  "csrc/torch_bindings.cpp")
 
 if(VLLM_GPU_LANG STREQUAL "CUDA")
+  include(FetchContent)
+  SET(CUTLASS_ENABLE_HEADERS_ONLY=ON)
+  FetchContent_Declare(
+        cutlass
+        GIT_REPOSITORY https://github.com/nvidia/cutlass.git
+        # CUTLASS 3.5.0
+        GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc
+  )
+  FetchContent_MakeAvailable(cutlass)
+
   list(APPEND VLLM_EXT_SRC
     "csrc/quantization/aqlm/gemm_kernels.cu"
     "csrc/quantization/awq/gemm_kernels.cu"
-    "csrc/quantization/marlin/marlin_cuda_kernel.cu"
+    "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
+    "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
     "csrc/quantization/gptq_marlin/gptq_marlin.cu"
     "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
-    "csrc/custom_all_reduce.cu")
+    "csrc/custom_all_reduce.cu"
+    "csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu"
+    "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu"
+    "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu")
+
+  #
+  # The CUTLASS kernels for Hopper require sm90a to be enabled.
+  # This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a.
+  # That adds an extra 17MB to compiled binary, so instead we selectively enable it.
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
+    set_source_files_properties(
+          "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu"
+          PROPERTIES
+          COMPILE_FLAGS
+          "-gencode arch=compute_90a,code=sm_90a")
+  endif()
+
 endif()
 
 define_gpu_extension_target(
@@ -189,6 +204,8 @@ define_gpu_extension_target(
   SOURCES ${VLLM_EXT_SRC}
   COMPILE_FLAGS ${VLLM_GPU_FLAGS}
   ARCHITECTURES ${VLLM_GPU_ARCHES}
+  INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
+  USE_SABI 3
   WITH_SOABI)
 
 #
@@ -196,7 +213,7 @@ define_gpu_extension_target(
 #
 
 set(VLLM_MOE_EXT_SRC
-  "csrc/moe/moe_ops.cpp"
+  "csrc/moe/torch_bindings.cpp"
   "csrc/moe/topk_softmax_kernels.cu")
 
 define_gpu_extension_target(
@@ -206,6 +223,7 @@ define_gpu_extension_target(
   SOURCES ${VLLM_MOE_EXT_SRC}
   COMPILE_FLAGS ${VLLM_GPU_FLAGS}
   ARCHITECTURES ${VLLM_GPU_ARCHES}
+  USE_SABI 3
   WITH_SOABI)
 
 #
@@ -219,7 +237,8 @@ set(VLLM_PUNICA_EXT_SRC
   "csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu"
   "csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
   "csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
-  "csrc/punica/punica_ops.cc")
+  "csrc/punica/punica_ops.cu"
+  "csrc/punica/torch_bindings.cpp")
 
 #
 # Copy GPU compilation flags+update for punica
@@ -243,6 +262,9 @@ if (${VLLM_GPU_LANG} STREQUAL "CUDA")
     endif()
   endforeach()
   message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
+elseif(${VLLM_GPU_LANG} STREQUAL "HIP")
+  set(VLLM_PUNICA_GPU_ARCHES ${VLLM_GPU_ARCHES})
+  message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
 endif()
 
 if (VLLM_PUNICA_GPU_ARCHES)
@@ -253,6 +275,7 @@ if (VLLM_PUNICA_GPU_ARCHES)
     SOURCES ${VLLM_PUNICA_EXT_SRC}
     COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS}
     ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES}
+    USE_SABI 3
     WITH_SOABI)
 else()
   message(WARNING "Unable to create _punica_C target because none of the "
@@ -277,9 +300,7 @@ add_custom_target(default)
 if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
   message(STATUS "Enabling C extension.")
   add_dependencies(default _C)
-endif()
 
-if(VLLM_GPU_LANG STREQUAL "CUDA")
   message(STATUS "Enabling moe extension.")
   add_dependencies(default _moe_C)
 
````
Dockerfile (27 lines changed)

````diff
@@ -79,31 +79,8 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
 COPY .buildkite/check-wheel-size.py check-wheel-size.py
 RUN python3 check-wheel-size.py dist
 
-# the `vllm_nccl` package must be installed from source distribution
-# pip is too smart to store a wheel in the cache, and other CI jobs
-# will directly use the wheel from the cache, which is not what we want.
-# we need to remove it manually
-RUN --mount=type=cache,target=/root/.cache/pip \
-    pip cache remove vllm_nccl*
 #################### EXTENSION Build IMAGE ####################
 
-#################### FLASH_ATTENTION Build IMAGE ####################
-FROM dev as flash-attn-builder
-# max jobs used for build
-ARG max_jobs=2
-ENV MAX_JOBS=${max_jobs}
-# flash attention version
-ARG flash_attn_version=v2.5.8
-ENV FLASH_ATTN_VERSION=${flash_attn_version}
-
-WORKDIR /usr/src/flash-attention-v2
-
-# Download the wheel or build it if a pre-compiled release doesn't exist
-RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \
-    --no-build-isolation --no-deps --no-cache-dir
-
-#################### FLASH_ATTENTION Build IMAGE ####################
-
-
 #################### vLLM installation IMAGE ####################
 # image with vLLM installed
 FROM nvidia/cuda:12.4.1-base-ubuntu22.04 AS vllm-base
@@ -122,10 +99,6 @@ RUN ldconfig /usr/local/cuda-12.4/compat/
 RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
     --mount=type=cache,target=/root/.cache/pip \
     pip install dist/*.whl --verbose
-
-RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
-    --mount=type=cache,target=/root/.cache/pip \
-    pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
 #################### vLLM installation IMAGE ####################
 
 
````
````diff
@@ -1,13 +1,15 @@
 # This vLLM Dockerfile is used to construct image that can build and run vLLM on x86 CPU platform.
 
-FROM ubuntu:22.04
+FROM ubuntu:22.04 AS cpu-test-1
 
 RUN apt-get update -y \
     && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \
     && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
 
 RUN pip install --upgrade pip \
-    && pip install wheel packaging ninja setuptools>=49.4.0 numpy
+    && pip install wheel packaging ninja "setuptools>=49.4.0" numpy
+
+FROM cpu-test-1 AS build
 
 COPY ./ /workspace/vllm
 
@@ -17,4 +19,8 @@ RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.py
 
 RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install
 
+WORKDIR /workspace/
+
+RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks
+
 CMD ["/bin/bash"]
````
````diff
@@ -28,7 +28,7 @@ COPY ./requirements-neuron.txt /app/vllm/requirements-neuron.txt
 RUN cd /app/vllm \
     && python3 -m pip install -U -r requirements-neuron.txt
 
-ENV VLLM_BUILD_WITH_NEURON 1
+ENV VLLM_TARGET_DEVICE neuron
 RUN cd /app/vllm \
     && pip install -e . \
     && cd ..
````
````diff
@@ -92,16 +92,24 @@ RUN if [ "$BUILD_TRITON" = "1" ]; then \
 WORKDIR /vllm-workspace
 COPY . .
 
+#RUN python3 -m pip install pynvml # to be removed eventually
 RUN python3 -m pip install --upgrade pip numba
 
+# make sure punica kernels are built (for LoRA)
+ENV VLLM_INSTALL_PUNICA_KERNELS=1
+# Workaround for ray >= 2.10.0
+ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
+
+ENV VLLM_NCCL_SO_PATH=/opt/rocm/lib/librccl.so
+
 RUN --mount=type=cache,target=/root/.cache/pip \
     pip install -U -r requirements-rocm.txt \
     && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \
     && python3 setup.py install \
-    && cp build/lib.linux-x86_64-cpython-39/vllm/_C.cpython-39-x86_64-linux-gnu.so vllm/ \
+    && cp build/lib.linux-x86_64-cpython-39/vllm/_C.abi3.so vllm/ \
+    && cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.abi3.so vllm/ \
+    && cp build/lib.linux-x86_64-cpython-39/vllm/_moe_C.abi3.so vllm/ \
     && cd ..
 
-RUN python3 -m pip install --upgrade pip
-RUN python3 -m pip install --no-cache-dir ray[all]==2.9.3
-
 CMD ["/bin/bash"]
````
90
README.md
90
README.md
@@ -14,6 +14,24 @@ Easy, fast, and cheap LLM serving for everyone
|
|||||||
|
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Ray Summit CPF is Open (June 4th to June 20th)!**
|
||||||
|
|
||||||
|
There will be a track for vLLM at the Ray Summit (09/30-10/02, SF) this year!
|
||||||
|
If you have cool projects related to vLLM or LLM inference, we would love to see your proposals.
|
||||||
|
This will be a great chance for everyone in the community to get together and learn.
|
||||||
|
Please submit your proposal [here](https://raysummit.anyscale.com/flow/anyscale/raysummit2024/landing/page/eventsite)
|
||||||
|
|
||||||
|
**The Fourth vLLM Bay Area Meetup (June 11th 5:30pm-8pm PT)**
|
||||||
|
|
||||||
|
We are thrilled to announce our fourth vLLM Meetup!
|
||||||
|
The vLLM team will share recent updates and roadmap.
|
||||||
|
We will also have vLLM collaborators from BentoML and Cloudflare coming up to the stage to discuss their experience in deploying LLMs with vLLM.
|
||||||
|
Please register [here](https://lu.ma/agivllm) and join us!
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
*Latest News* 🔥
|
*Latest News* 🔥
|
||||||
- [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing).
|
- [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing).
|
||||||
- [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing).
|
- [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing).
|
||||||
@@ -51,41 +69,14 @@ vLLM is flexible and easy to use with:
|
|||||||
- (Experimental) Prefix caching support
|
- (Experimental) Prefix caching support
|
||||||
- (Experimental) Multi-lora support
|
- (Experimental) Multi-lora support
|
||||||
|
|
||||||
vLLM seamlessly supports many Hugging Face models, including the following architectures:
|
vLLM seamlessly supports most popular open-source models on HuggingFace, including:
|
||||||
|
- Transformer-like LLMs (e.g., Llama)
|
||||||
|
- Mixture-of-Expert LLMs (e.g., Mixtral)
|
||||||
|
- Multi-modal LLMs (e.g., LLaVA)
|
||||||
|
|
||||||
- Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
|
Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html).
|
||||||
- Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.)
|
|
||||||
- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
|
## Getting Started
|
||||||
- ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.)
|
|
||||||
- Command-R (`CohereForAI/c4ai-command-r-v01`, etc.)
|
|
||||||
- DBRX (`databricks/dbrx-base`, `databricks/dbrx-instruct` etc.)
|
|
||||||
- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.)
|
|
||||||
- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
|
|
||||||
- Gemma (`google/gemma-2b`, `google/gemma-7b`, etc.)
|
|
||||||
- GPT-2 (`gpt2`, `gpt2-xl`, etc.)
|
|
||||||
- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
|
|
||||||
- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
|
|
||||||
- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
|
|
||||||
- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
|
|
||||||
- InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.)
|
|
||||||
- Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.)
|
|
||||||
- LLaMA, Llama 2, and Meta Llama 3 (`meta-llama/Meta-Llama-3-8B-Instruct`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
|
|
||||||
- MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.)
|
|
||||||
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
|
|
||||||
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc.)
|
|
||||||
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
|
|
||||||
- OLMo (`allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc.)
|
|
||||||
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
|
|
||||||
- Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.)
|
|
||||||
- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
|
|
||||||
- Phi-3 (`microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, etc.)
|
|
||||||
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
|
|
||||||
- Qwen2 (`Qwen/Qwen1.5-7B`, `Qwen/Qwen1.5-7B-Chat`, etc.)
|
|
||||||
- Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.)
|
|
||||||
- StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.)
|
|
||||||
- Starcoder2(`bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc.)
|
|
||||||
- Xverse (`xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.)
|
|
||||||
- Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.)
|
|
||||||
|
|
||||||
@@ -93,9 +84,7 @@ Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/get
 Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):

 ```
 pip install vllm
 ```

-## Getting Started
-
-Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started.
+Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to learn more.

 - [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html)
 - [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html)
 - [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html)
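For quick reference, here is a minimal offline-inference sketch using the `LLM` and `SamplingParams` classes that the benchmark scripts below import from `vllm`; the model name is only illustrative, and any model from the supported list above should work the same way:

```
from vllm import LLM, SamplingParams

# Load a small model (name is illustrative) and generate from two prompts.
llm = LLM(model="facebook/opt-125m")
sampling_params = SamplingParams(temperature=0.8, max_tokens=64)
outputs = llm.generate(["Hello, my name is", "The capital of France is"],
                       sampling_params)
for output in outputs:
    # Each output carries the prompt and its generated completions.
    print(output.outputs[0].text)
```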
@@ -105,6 +94,33 @@ Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started
 We welcome and value any contributions and collaborations.
 Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.

+## Sponsors
+
+vLLM is a community project. Our compute resources for development and testing are supported by the following organizations. Thank you for your support!
+
+<!-- Note: Please sort them in alphabetical order. -->
+<!-- Note: Please keep these consistent with docs/source/community/sponsors.md -->
+
+- a16z
+- AMD
+- Anyscale
+- AWS
+- Crusoe Cloud
+- Databricks
+- DeepInfra
+- Dropbox
+- Lambda Lab
+- NVIDIA
+- Replicate
+- Roblox
+- RunPod
+- Sequoia Capital
+- Trainy
+- UC Berkeley
+- UC San Diego
+
+We also have an official fundraising venue through [OpenCollective](https://opencollective.com/vllm). We plan to use the fund to support the development, maintenance, and adoption of vLLM.
+
 ## Citation

 If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180):
@@ -89,6 +89,9 @@ async def async_request_tgi(
                     output.latency = most_recent_timestamp - st
                     output.success = True
                     output.generated_text = data["generated_text"]
+            else:
+                output.error = response.reason or ""
+                output.success = False
         except Exception:
             output.success = False
             exc_info = sys.exc_info()

@@ -276,6 +279,9 @@ async def async_request_openai_completions(
                     output.generated_text = generated_text
                     output.success = True
                     output.latency = latency
+            else:
+                output.error = response.reason or ""
+                output.success = False
         except Exception:
             output.success = False
             exc_info = sys.exc_info()
@@ -1,14 +1,16 @@
 """Benchmark the latency of processing a single batch of requests."""
 import argparse
+import json
 import time
 from pathlib import Path
-from typing import Optional
+from typing import List, Optional

 import numpy as np
 import torch
 from tqdm import tqdm

 from vllm import LLM, SamplingParams
+from vllm.inputs import PromptStrictInputs
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

@@ -18,6 +20,8 @@ def main(args: argparse.Namespace):
     # NOTE(woosuk): If the request cannot be processed in a single batch,
     # the engine will automatically process the request in multiple batches.
     llm = LLM(model=args.model,
+              speculative_model=args.speculative_model,
+              num_speculative_tokens=args.num_speculative_tokens,
               tokenizer=args.tokenizer,
               quantization=args.quantization,
               tensor_parallel_size=args.tensor_parallel_size,

@@ -28,9 +32,12 @@ def main(args: argparse.Namespace):
               quantization_param_path=args.quantization_param_path,
               device=args.device,
               ray_workers_use_nsight=args.ray_workers_use_nsight,
+              use_v2_block_manager=args.use_v2_block_manager,
               enable_chunked_prefill=args.enable_chunked_prefill,
               download_dir=args.download_dir,
-              block_size=args.block_size)
+              block_size=args.block_size,
+              gpu_memory_utilization=args.gpu_memory_utilization,
+              distributed_executor_backend=args.distributed_executor_backend)

     sampling_params = SamplingParams(
         n=args.n,

@@ -44,7 +51,9 @@ def main(args: argparse.Namespace):
     dummy_prompt_token_ids = np.random.randint(10000,
                                                size=(args.batch_size,
                                                      args.input_len))
-    dummy_prompt_token_ids = dummy_prompt_token_ids.tolist()
+    dummy_inputs: List[PromptStrictInputs] = [{
+        "prompt_token_ids": batch
+    } for batch in dummy_prompt_token_ids.tolist()]

     def run_to_completion(profile_dir: Optional[str] = None):
         if profile_dir:

@@ -55,13 +64,13 @@ def main(args: argparse.Namespace):
                 ],
                 on_trace_ready=torch.profiler.tensorboard_trace_handler(
                     str(profile_dir))) as p:
-                llm.generate(prompt_token_ids=dummy_prompt_token_ids,
+                llm.generate(dummy_inputs,
                              sampling_params=sampling_params,
                              use_tqdm=False)
             print(p.key_averages())
         else:
             start_time = time.perf_counter()
-            llm.generate(prompt_token_ids=dummy_prompt_token_ids,
+            llm.generate(dummy_inputs,
                          sampling_params=sampling_params,
                          use_tqdm=False)
             end_time = time.perf_counter()

@@ -93,12 +102,24 @@ def main(args: argparse.Namespace):
     for percentage, percentile in zip(percentages, percentiles):
         print(f'{percentage}% percentile latency: {percentile} seconds')

+    # Output JSON results if specified
+    if args.output_json:
+        results = {
+            "avg_latency": np.mean(latencies),
+            "latencies": latencies.tolist(),
+            "percentiles": dict(zip(percentages, percentiles.tolist())),
+        }
+        with open(args.output_json, "w") as f:
+            json.dump(results, f, indent=4)
+

 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
         description='Benchmark the latency of processing a single batch of '
         'requests till completion.')
     parser.add_argument('--model', type=str, default='facebook/opt-125m')
+    parser.add_argument('--speculative-model', type=str, default=None)
+    parser.add_argument('--num-speculative-tokens', type=int, default=None)
     parser.add_argument('--tokenizer', type=str, default=None)
     parser.add_argument('--quantization',
                         '-q',

@@ -137,15 +158,13 @@ if __name__ == '__main__':
                         action='store_true',
                         help='enforce eager mode and disable CUDA graph')
     parser.add_argument(
-        "--kv-cache-dtype",
+        '--kv-cache-dtype',
         type=str,
-        choices=['auto', 'fp8'],
-        default='auto',
-        help=
-        'Data type for kv cache storage. If "auto", will use model data type. '
-        'FP8_E5M2 (without scaling) is only supported on cuda version greater '
-        'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
-        'common inference criteria.')
+        choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
+        default="auto",
+        help='Data type for kv cache storage. If "auto", will use model '
+        'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
+        'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
     parser.add_argument(
         '--quantization-param-path',
         type=str,

@@ -181,6 +200,7 @@ if __name__ == '__main__':
                         action='store_true',
                         help='If True, the prefill requests can be chunked based on the '
                         'max_num_batched_tokens')
+    parser.add_argument('--use-v2-block-manager', action='store_true')
     parser.add_argument(
         "--ray-workers-use-nsight",
         action='store_true',

@@ -191,5 +211,23 @@ if __name__ == '__main__':
                         default=None,
                         help='directory to download and load the weights, '
                         'default to the default cache dir of huggingface')
+    parser.add_argument(
+        '--output-json',
+        type=str,
+        default=None,
+        help='Path to save the latency results in JSON format.')
+    parser.add_argument('--gpu-memory-utilization',
+                        type=float,
+                        default=0.9,
+                        help='the fraction of GPU memory to be used for '
+                        'the model executor, which can range from 0 to 1.'
+                        'If unspecified, will use the default value of 0.9.')
+    parser.add_argument(
+        '--distributed-executor-backend',
+        choices=['ray', 'mp'],
+        default=None,
+        help='Backend to use for distributed serving. When more than 1 GPU '
+        'is used, will be automatically set to "ray" if installed '
+        'or "mp" (multiprocessing) otherwise.')
     args = parser.parse_args()
     main(args)
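The new `--output-json` path above writes a flat dictionary of latency statistics. As a hedged sketch of how that file can be consumed afterwards (the file name is illustrative; the keys match the `results` dict built in the diff):

```
import json

# Read the results written by benchmark_latency.py --output-json
# (the file name is illustrative).
with open("latency_results.json") as f:
    results = json.load(f)

print(f"avg latency: {results['avg_latency']:.3f} s")
# "percentiles" maps percentage -> latency in seconds.
for pct, value in results["percentiles"].items():
    print(f"p{pct}: {value:.3f} s")
```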
@@ -17,6 +17,10 @@ On the client side, run:
         --dataset-path <path to dataset> \
         --request-rate <request_rate> \ # By default <request_rate> is inf
         --num-prompts <num_prompts> # By default <num_prompts> is 1000
+
+    when using tgi backend, add
+        --endpoint /generate_stream
+    to the end of the command above.
 """
 import argparse
 import asyncio

@@ -52,6 +56,9 @@ class BenchmarkMetrics:
     mean_tpot_ms: float
     median_tpot_ms: float
     p99_tpot_ms: float
+    mean_itl_ms: float
+    median_itl_ms: float
+    p99_itl_ms: float


 def sample_sharegpt_requests(

@@ -196,21 +203,34 @@ def calculate_metrics(
     actual_output_lens = []
     total_input = 0
     completed = 0
+    itls = []
     tpots = []
     ttfts = []
     for i in range(len(outputs)):
         if outputs[i].success:
-            output_len = len(tokenizer(outputs[i].generated_text).input_ids)
+            # We use the tokenizer to count the number of output tokens for all
+            # serving backends instead of looking at len(outputs[i].itl) since
+            # multiple output tokens may be bundled together
+            # Note: this may inflate the output token count slightly
+            output_len = len(
+                tokenizer(outputs[i].generated_text,
+                          add_special_tokens=False).input_ids)
             actual_output_lens.append(output_len)
             total_input += input_requests[i][1]
             if output_len > 1:
                 tpots.append(
                     (outputs[i].latency - outputs[i].ttft) / (output_len - 1))
+            itls += outputs[i].itl
             ttfts.append(outputs[i].ttft)
             completed += 1
         else:
             actual_output_lens.append(0)

+    if completed == 0:
+        warnings.warn(
+            "All requests failed. This is likely due to a misconfiguration "
+            "on the benchmark arguments.",
+            stacklevel=2)
     metrics = BenchmarkMetrics(
         completed=completed,
         total_input=total_input,

@@ -222,9 +242,12 @@ def calculate_metrics(
         1000,  # ttfts is empty if streaming is not supported by backend
         median_ttft_ms=np.median(ttfts or 0) * 1000,
         p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
-        mean_tpot_ms=np.mean(tpots) * 1000,
-        median_tpot_ms=np.median(tpots) * 1000,
-        p99_tpot_ms=np.percentile(tpots, 99) * 1000,
+        mean_tpot_ms=np.mean(tpots or 0) * 1000,
+        median_tpot_ms=np.median(tpots or 0) * 1000,
+        p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
+        mean_itl_ms=np.mean(itls or 0) * 1000,
+        median_itl_ms=np.median(itls or 0) * 1000,
+        p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
     )

     return metrics, actual_output_lens
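For clarity, the per-request TPOT computed above reduces to total decode time divided by the number of decode steps. A small restatement of that logic, not new behavior:

```
# TPOT (time per output token) for one successful request, as in
# calculate_metrics above: decode time divided by decode steps.
def tpot(latency: float, ttft: float, output_len: int) -> float:
    assert output_len > 1, "TPOT is undefined for single-token outputs"
    return (latency - ttft) / (output_len - 1)
```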
@@ -246,6 +269,24 @@ async def benchmark(
     else:
         raise ValueError(f"Unknown backend: {backend}")

+    print("Starting initial single prompt test run...")
+    test_prompt, test_prompt_len, test_output_len = input_requests[0]
+    test_input = RequestFuncInput(
+        model=model_id,
+        prompt=test_prompt,
+        api_url=api_url,
+        prompt_len=test_prompt_len,
+        output_len=test_output_len,
+        best_of=best_of,
+        use_beam_search=use_beam_search,
+    )
+    test_output = await request_func(request_func_input=test_input)
+    if not test_output.success:
+        raise ValueError(
+            "Initial test run failed - Please make sure benchmark arguments "
+            f"are correctly specified. Error: {test_output.error}")
+    else:
+        print("Initial test run completed. Starting main benchmark run...")
     print(f"Traffic request rate: {request_rate}")

     pbar = None if disable_tqdm else tqdm(total=len(input_requests))

@@ -306,6 +347,10 @@ async def benchmark(
     print("{:<40} {:<10.2f}".format("Median TPOT (ms):",
                                     metrics.median_tpot_ms))
     print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
+    print("{s:{c}^{n}}".format(s='Inter-token Latency', n=50, c='-'))
+    print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
+    print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
+    print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
     print("=" * 50)

     result = {

@@ -322,6 +367,9 @@ async def benchmark(
         "mean_tpot_ms": metrics.mean_tpot_ms,
         "median_tpot_ms": metrics.median_tpot_ms,
         "p99_tpot_ms": metrics.p99_tpot_ms,
+        "mean_itl_ms": metrics.mean_itl_ms,
+        "median_itl_ms": metrics.median_itl_ms,
+        "p99_itl_ms": metrics.p99_itl_ms,
         "input_lens": [output.prompt_len for output in outputs],
         "output_lens": actual_output_lens,
         "ttfts": [output.ttft for output in outputs],
@@ -78,6 +78,7 @@ def run_vllm(
     enable_prefix_caching: bool,
     enable_chunked_prefill: bool,
     max_num_batched_tokens: int,
+    distributed_executor_backend: Optional[str],
     gpu_memory_utilization: float = 0.9,
     download_dir: Optional[str] = None,
 ) -> float:

@@ -100,6 +101,7 @@ def run_vllm(
         download_dir=download_dir,
         enable_chunked_prefill=enable_chunked_prefill,
         max_num_batched_tokens=max_num_batched_tokens,
+        distributed_executor_backend=distributed_executor_backend,
     )

     # Add the requests to the engine.

@@ -225,8 +227,8 @@ def main(args: argparse.Namespace):
                                 args.enforce_eager, args.kv_cache_dtype,
                                 args.quantization_param_path, args.device,
                                 args.enable_prefix_caching, args.enable_chunked_prefill,
-                                args.max_num_batched_tokens, args.gpu_memory_utilization,
-                                args.download_dir)
+                                args.max_num_batched_tokens, args.distributed_executor_backend,
+                                args.gpu_memory_utilization, args.download_dir)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,

@@ -242,6 +244,18 @@ def main(args: argparse.Namespace):
     print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
           f"{total_num_tokens / elapsed_time:.2f} tokens/s")

+    # Output JSON results if specified
+    if args.output_json:
+        results = {
+            "elapsed_time": elapsed_time,
+            "num_requests": len(requests),
+            "total_num_tokens": total_num_tokens,
+            "requests_per_second": len(requests) / elapsed_time,
+            "tokens_per_second": total_num_tokens / elapsed_time,
+        }
+        with open(args.output_json, "w") as f:
+            json.dump(results, f, indent=4)
+

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Benchmark the throughput.")

@@ -311,15 +325,13 @@ if __name__ == "__main__":
                         action="store_true",
                         help="enforce eager execution")
     parser.add_argument(
-        "--kv-cache-dtype",
+        '--kv-cache-dtype',
         type=str,
-        choices=["auto", "fp8"],
+        choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
         default="auto",
-        help=
-        'Data type for kv cache storage. If "auto", will use model data type. '
-        'FP8_E5M2 (without scaling) is only supported on cuda version greater '
-        'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
-        'common inference criteria.')
+        help='Data type for kv cache storage. If "auto", will use model '
+        'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
+        'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
     parser.add_argument(
         '--quantization-param-path',
         type=str,

@@ -353,6 +365,18 @@ if __name__ == "__main__":
                         default=None,
                         help='directory to download and load the weights, '
                         'default to the default cache dir of huggingface')
+    parser.add_argument(
+        '--output-json',
+        type=str,
+        default=None,
+        help='Path to save the throughput results in JSON format.')
+    parser.add_argument(
+        '--distributed-executor-backend',
+        choices=['ray', 'mp'],
+        default=None,
+        help='Backend to use for distributed serving. When more than 1 GPU '
+        'is used, will be automatically set to "ray" if installed '
+        'or "mp" (multiprocessing) otherwise.')
     args = parser.parse_args()
     if args.tokenizer is None:
         args.tokenizer = args.model
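The new `--distributed-executor-backend` flag maps directly onto the `LLM` constructor argument of the same name (see the `llm = LLM(...)` call in the latency diff above). A hedged sketch of the programmatic equivalent; the model name is illustrative:

```
from vllm import LLM

# Equivalent of passing --tensor-parallel-size 2
# --distributed-executor-backend mp on the command line.
llm = LLM(model="facebook/opt-125m",
          tensor_parallel_size=2,
          distributed_executor_backend="mp")
```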
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py (new file, 352 lines)
@@ -0,0 +1,352 @@
import argparse
import copy
import itertools
import pickle as pkl
import time
from typing import Callable, Iterable, List, Tuple

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from weight_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops

DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())[1:]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]

# helpers


def to_fp8(tensor: torch.tensor) -> torch.tensor:
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(
        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)


def to_int8(tensor: torch.tensor) -> torch.tensor:
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)


def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
                      k: int) -> Tuple[torch.tensor, torch.tensor]:
    a = torch.randn((m, k), device='cuda') * 5
    b = torch.randn((n, k), device='cuda').t() * 5

    if dtype == torch.int8:
        return to_int8(a), to_int8(b)
    if dtype == torch.float8_e4m3fn:
        return to_fp8(a), to_fp8(b)

    raise ValueError("unsupported dtype")


# impl


def pytorch_i8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
                    scale_b: torch.tensor,
                    out_dtype: torch.dtype) -> torch.tensor:
    return torch.mm(a, b)


def pytorch_fp8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
                     scale_b: torch.tensor,
                     out_dtype: torch.dtype) -> torch.tensor:
    return torch._scaled_mm(a,
                            b,
                            scale_a=scale_a,
                            scale_b=scale_b,
                            out_dtype=out_dtype)


def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor,
                                scale_a: torch.tensor, scale_b: torch.tensor,
                                out_dtype: torch.dtype) -> torch.tensor:
    return torch._scaled_mm(a,
                            b,
                            scale_a=scale_a,
                            scale_b=scale_b,
                            out_dtype=out_dtype,
                            use_fast_accum=True)


def cutlass_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
                 scale_b: torch.tensor,
                 out_dtype: torch.dtype) -> torch.tensor:
    return ops.cutlass_scaled_mm_dq(a,
                                    b,
                                    scale_a,
                                    scale_b,
                                    out_dtype=out_dtype)


# bench
def bench_fn(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
             scale_b: torch.tensor, out_dtype: torch.dtype, label: str,
             sub_label: str, fn: Callable, description: str) -> TMeasurement:

    min_run_time = 1

    globals = {
        "a": a,
        "b": b,
        "scale_a": scale_a,
        "scale_b": scale_b,
        "out_dtype": out_dtype,
        "fn": fn,
    }
    return TBenchmark.Timer(
        stmt="fn(a, b, scale_a, scale_b, out_dtype)",
        globals=globals,
        label=label,
        sub_label=sub_label,
        description=description,
    ).blocked_autorange(min_run_time=min_run_time)


def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
               sub_label: str) -> Iterable[TMeasurement]:
    assert dtype == torch.int8
    a, b = make_rand_tensors(torch.int8, m, n, k)
    scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)

    timers = []
    # pytorch impl
    timers.append(
        bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
                 b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
                 torch.bfloat16, label, sub_label, pytorch_i8_impl,
                 "pytorch_bf16_bf16_bf16_matmul-no-scales"))

    # cutlass impl
    timers.append(
        bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
                 torch.bfloat16, label, sub_label, cutlass_impl,
                 "cutlass_i8_i8_bf16_scaled_mm"))

    return timers


def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
              sub_label: str) -> Iterable[TMeasurement]:
    assert dtype == torch.float8_e4m3fn
    a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
    scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)

    timers = []

    # pytorch impl: bf16 output, without fp8 fast accum
    timers.append(
        bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
                 pytorch_fp8_impl, "pytorch_fp8_fp8_bf16_scaled_mm"))

    # pytorch impl: bf16 output, with fp8 fast accum
    timers.append(
        bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
                 pytorch_fp8_impl_fast_accum,
                 "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum"))

    # pytorch impl: fp16 output, without fp8 fast accum
    timers.append(
        bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
                 pytorch_fp8_impl, "pytorch_fp8_fp8_fp16_scaled_mm"))

    # pytorch impl: fp16 output, with fp8 fast accum
    timers.append(
        bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
                 pytorch_fp8_impl_fast_accum,
                 "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum"))

    # cutlass impl: bf16 output
    timers.append(
        bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
                 torch.bfloat16, label, sub_label, cutlass_impl,
                 "cutlass_fp8_fp8_bf16_scaled_mm"))
    # cutlass impl: fp16 output
    timers.append(
        bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
                 torch.float16, label, sub_label, cutlass_impl,
                 "cutlass_fp8_fp8_fp16_scaled_mm"))
    return timers


def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str,
          sub_label: str) -> Iterable[TMeasurement]:
    if dtype == torch.int8:
        return bench_int8(dtype, m, k, n, label, sub_label)
    if dtype == torch.float8_e4m3fn:
        return bench_fp8(dtype, m, k, n, label, sub_label)
    raise ValueError("unsupported type")


# runner
def print_timers(timers: Iterable[TMeasurement]):
    compare = TBenchmark.Compare(timers)
    compare.print()


def run(dtype: torch.dtype,
        MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
    results = []
    for m, k, n in MKNs:
        timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
                       f"MKN=({m}x{k}x{n})")
        print_timers(timers)
        results.extend(timers)

    return results


# output makers
def make_output(data: Iterable[TMeasurement],
                MKNs: Iterable[Tuple[int, int, int]],
                base_description: str,
                timestamp=None):
    print(f"== All Results {base_description} ====")
    print_timers(data)

    # pickle all the results
    timestamp = int(time.time()) if timestamp is None else timestamp
    with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
        pkl.dump(data, f)


# argparse runners


def run_square_bench(args):
    dim_sizes = list(
        range(args.dim_start, args.dim_end + 1, args.dim_increment))
    MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
    data = run(args.dtype, MKNs)

    make_output(data, MKNs, f"square_bench-{args.dtype}")


def run_range_bench(args):
    dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
    n = len(dim_sizes)
    Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
    Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
    Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
    MKNs = list(zip(Ms, Ks, Ns))
    data = run(args.dtype, MKNs)

    make_output(data, MKNs, f"range_bench-{args.dtype}")


def run_model_bench(args):
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}] {model}")

    def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
        KNs = []
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            KNs.append(KN)
        return KNs

    model_bench_data = []
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    for model, tp_size in models_tps:
        Ms = args.batch_sizes
        KNs = model_shapes(model, tp_size)
        MKNs = []
        for m in Ms:
            for k, n in KNs:
                MKNs.append((m, k, n))

        data = run(args.dtype, MKNs)
        model_bench_data.append(data)

    # Print all results
    for data, model_tp in zip(model_bench_data, models_tps):
        model, tp_size = model_tp
        print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
        print_timers(data)

    timestamp = int(time.time())

    all_data = []
    for d in model_bench_data:
        all_data.extend(d)
    # pickle all data
    with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
        pkl.dump(all_data, f)


if __name__ == '__main__':

    def to_torch_dtype(dt):
        if dt == "int8":
            return torch.int8
        if dt == "fp8":
            return torch.float8_e4m3fn
        raise ValueError("unsupported dtype")

    parser = argparse.ArgumentParser(
        description="""
Benchmark Cutlass GEMM.

    To run square GEMMs:
        python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64

    To run constant N and K and sweep M:
        python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384

    To run dimensions from a model:
        python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1

    Output:
        - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
            """,  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter)

    parser.add_argument("--dtype",
                        type=to_torch_dtype,
                        required=True,
                        help="Available options are ['int8', 'fp8']")
    subparsers = parser.add_subparsers(dest="cmd")

    square_parser = subparsers.add_parser("square_bench")
    square_parser.add_argument("--dim-start", type=int, required=True)
    square_parser.add_argument("--dim-end", type=int, required=True)
    square_parser.add_argument("--dim-increment", type=int, required=True)
    square_parser.set_defaults(func=run_square_bench)

    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument("--dim-start", type=int, required=True)
    range_parser.add_argument("--dim-end", type=int, required=True)
    range_parser.add_argument("--dim-increment", type=int, required=True)
    range_parser.add_argument("--m-constant", type=int, default=None)
    range_parser.add_argument("--n-constant", type=int, default=None)
    range_parser.add_argument("--k-constant", type=int, default=None)
    range_parser.set_defaults(func=run_range_bench)

    model_parser = subparsers.add_parser("model_bench")
    model_parser.add_argument("--models",
                              nargs="+",
                              type=str,
                              default=DEFAULT_MODELS,
                              choices=WEIGHT_SHAPES.keys())
    model_parser.add_argument("--tp-sizes",
                              nargs="+",
                              type=int,
                              default=DEFAULT_TP_SIZES)
    model_parser.add_argument("--batch-sizes",
                              nargs="+",
                              type=int,
                              default=DEFAULT_BATCH_SIZES)
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()
    args.func(args)
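The script above pickles a list of `torch.utils.benchmark` `Measurement` objects. A hedged sketch for inspecting such a file afterwards; the file name is illustrative:

```
import pickle

# Load the Measurements pickled by w8a8_benchmarks.py (name illustrative).
with open("model_bench-torch.float8_e4m3fn-1700000000.pkl", "rb") as f:
    measurements = pickle.load(f)

for m in measurements:
    # Measurement.median is in seconds; the description names the impl.
    print(m.task_spec.description, f"{m.median * 1e6:.1f} us")
```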
benchmarks/cutlass_benchmarks/weight_shapes.py (new file, 37 lines)
@@ -0,0 +1,37 @@
# Weight Shapes are in the format
#  ([K, N], TP_SPLIT_DIM)
# Example:
#  A shape of ([14336, 4096], 0) indicates the following GEMM shape,
#   - TP1 : K = 14336, N = 4096
#   - TP2 : K = 7168, N = 4096
#  A shape of ([4096, 6144], 1) indicates the following GEMM shape,
#   - TP1 : K = 4096, N = 6144
#   - TP4 : K = 4096, N = 1536

# TP1 shapes
WEIGHT_SHAPES = {
    "mistralai/Mistral-7B-v0.1": [
        ([4096, 6144], 1),
        ([4096, 4096], 0),
        ([4096, 28672], 1),
        ([14336, 4096], 0),
    ],
    "meta-llama/Llama-2-7b-hf": [
        ([4096, 12288], 1),
        ([4096, 4096], 0),
        ([4096, 22016], 1),
        ([11008, 4096], 0),
    ],
    "meta-llama/Llama-2-13b-hf": [
        ([5120, 15360], 1),
        ([5120, 5120], 0),
        ([5120, 27648], 1),
        ([13824, 5120], 0),
    ],
    "meta-llama/Llama-2-70b-hf": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 57344], 1),
        ([28672, 8192], 0),
    ],
}
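The `TP_SPLIT_DIM` convention above divides one dimension of each `[K, N]` shape by the tensor-parallel size. A minimal sketch of that computation, mirroring `model_shapes()` in `w8a8_benchmarks.py` and assuming `WEIGHT_SHAPES` from this file is in scope:

```
import copy

# Apply the TP split: divide the dimension named by tp_split_dim by tp_size.
def shapes_for_tp(model_name: str, tp_size: int):
    kns = []
    for kn, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
        kn[tp_split_dim] //= tp_size
        kns.append(kn)
    return kns

# e.g. shapes_for_tp("meta-llama/Llama-2-7b-hf", 2)
# -> [[4096, 6144], [2048, 4096], [4096, 11008], [5504, 4096]]
```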
benchmarks/kernels/benchmark_marlin.py (new file, 233 lines)
@@ -0,0 +1,233 @@
import argparse

import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin import (
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
    GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
    GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    MarlinWorkspace, marlin_24_quantize, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    gptq_pack, quantize_weights, sort_weights)

DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]

ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]


def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
              size_m, size_k, size_n):
    label = "Quant Matmul"

    sub_label = ("{}, act={} k_full={}, b={}, g={}, "
                 "MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits,
                                         group_size, size_m, size_k, size_n))

    print(f"Testing: {sub_label}")

    a = torch.randn(size_m, size_k).to(torch.half).cuda()
    b = torch.rand(size_k, size_n).to(torch.half).cuda()

    a_tmp = (torch.zeros(size_m, size_k).to(torch.half).cuda())

    # Marlin quant
    (
        marlin_w_ref,
        marlin_q_w,
        marlin_s,
        marlin_g_idx,
        marlin_sort_indices,
        marlin_rand_perm,
    ) = marlin_quantize(b, num_bits, group_size, act_order)

    # Marlin_24 quant
    (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
     marlin_24_s) = marlin_24_quantize(b, num_bits, group_size)

    # GPTQ quant
    (w_ref, q_w, s, g_idx,
     rand_perm) = quantize_weights(b, num_bits, group_size, act_order)
    q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n)

    # For act_order, sort the "weights" and "g_idx"
    # so that group ids are increasing
    repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
    if act_order:
        (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)

    # Prepare
    marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                                       GPTQ_MARLIN_MAX_PARALLEL)

    marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
                                          GPTQ_MARLIN_24_MAX_PARALLEL)

    globals = {
        # Gen params
        "num_bits": num_bits,
        "group_size": group_size,
        "size_m": size_m,
        "size_n": size_n,
        "size_k": size_k,
        "a": a,
        "a_tmp": a_tmp,
        # Marlin params
        "marlin_w_ref": marlin_w_ref,
        "marlin_q_w": marlin_q_w,
        "marlin_s": marlin_s,
        "marlin_g_idx": marlin_g_idx,
        "marlin_sort_indices": marlin_sort_indices,
        "marlin_rand_perm": marlin_rand_perm,
        "marlin_workspace": marlin_workspace,
        "is_k_full": is_k_full,
        # Marlin_24 params
        "marlin_24_w_ref": marlin_24_w_ref,
        "marlin_24_q_w_comp": marlin_24_q_w_comp,
        "marlin_24_meta": marlin_24_meta,
        "marlin_24_s": marlin_24_s,
        "marlin_24_workspace": marlin_24_workspace,
        # GPTQ params
        "q_w_gptq": q_w_gptq,
        "repack_sort_indices": repack_sort_indices,
        # Kernels
        "gptq_marlin_gemm": ops.gptq_marlin_gemm,
        "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
        "gptq_marlin_repack": ops.gptq_marlin_repack,
    }

    min_run_time = 1

    # Warmup pytorch
    for i in range(5):
        torch.matmul(a, marlin_w_ref)

    results.append(
        benchmark.Timer(
            stmt="torch.matmul(a, marlin_w_ref)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="pytorch_gemm",
        ).blocked_autorange(min_run_time=min_run_time))

    results.append(
        benchmark.Timer(
            stmt=
            "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="gptq_marlin_gemm",
        ).blocked_autorange(min_run_time=min_run_time))

    if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
            and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES):
        results.append(
            benchmark.Timer(
                stmt=
                "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)",  # noqa: E501
                globals=globals,
                label=label,
                sub_label=sub_label,
                description="gptq_marlin_24_gemm",
            ).blocked_autorange(min_run_time=min_run_time))

    results.append(
        benchmark.Timer(
            stmt=
            "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="gptq_marlin_repack",
        ).blocked_autorange(min_run_time=min_run_time))


def main(args):
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}] {model}")

    results = []

    for model in args.models:
        for layer in WEIGHT_SHAPES[model]:
            size_k = layer[0]
            size_n = layer[1]

            if len(args.limit_k) > 0 and size_k not in args.limit_k:
                continue

            if len(args.limit_n) > 0 and size_n not in args.limit_n:
                continue

            for act_order in ACT_ORDER_OPTS:
                if len(args.limit_act_order
                       ) > 0 and act_order not in args.limit_act_order:
                    continue

                for is_k_full in K_FULL_OPTS:
                    if len(args.limit_k_full
                           ) > 0 and is_k_full not in args.limit_k_full:
                        continue

                    for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
                        if len(args.limit_num_bits
                               ) > 0 and num_bits not in args.limit_num_bits:
                            continue

                        for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
                            if len(
                                    args.limit_group_size
                            ) > 0 and group_size not in args.limit_group_size:
                                continue

                            # For act_order, the group_size must be less than
                            # size_k
                            if act_order and (group_size == size_k
                                              or group_size == -1):
                                continue

                            for size_m in args.batch_sizes:
                                bench_run(results, model, act_order, is_k_full,
                                          num_bits, group_size, size_m, size_k,
                                          size_n)

    compare = benchmark.Compare(results)
    compare.print()


# For quick benchmarking use:
#   python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1  # noqa E501
#
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark Marlin across specified models/shapes/batches")
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    parser.add_argument("--batch-sizes",
                        nargs="+",
                        type=int,
                        default=DEFAULT_BATCH_SIZES)
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    parser.add_argument("--limit-group-size", nargs="+", type=int, default=[])
    parser.add_argument("--limit-num-bits", nargs="+", type=int, default=[])
    parser.add_argument("--limit-act-order", nargs="+", type=int, default=[])
    parser.add_argument("--limit-k-full", nargs="+", type=int, default=[])

    args = parser.parse_args()
    main(args)
@@ -1,215 +0,0 @@
import argparse
import json
import os
import sys

import torch
import torch.nn.functional as F
import triton
from tqdm import tqdm

from vllm.model_executor.layers.fused_moe import (fused_moe,
                                                  get_config_file_name)

os.environ['CUDA_VISIBLE_DEVICES'] = '0'


def main(dtype: str):
    method = fused_moe
    for bs in [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
            2048, 3072, 4096
    ]:
        run_grid(bs, method=method, dtype=dtype)


def run_grid(bs, method, dtype: str):
    d_model = 4096
    num_total_experts = 8
    top_k = 2
    tp_size = 2
    model_intermediate_size = 14336
    num_layers = 32
    num_calls = 100

    num_warmup_trials = 1
    num_trials = 1

    configs = []

    for block_size_n in [32, 64, 128, 256]:
        for block_size_m in [16, 32, 64, 128, 256]:
            for block_size_k in [64, 128, 256]:
                for group_size_m in [1, 16, 32, 64]:
                    for num_warps in [4, 8]:
                        for num_stages in [2, 3, 4, 5]:
                            configs.append({
                                "BLOCK_SIZE_M": block_size_m,
                                "BLOCK_SIZE_N": block_size_n,
                                "BLOCK_SIZE_K": block_size_k,
                                "GROUP_SIZE_M": group_size_m,
                                "num_warps": num_warps,
                                "num_stages": num_stages,
                            })

    best_config = None
    best_time_us = 1e20

    print(f'{tp_size=} {bs=}')

    for config in tqdm(configs):
        # warmup
        try:
            for _ in range(num_warmup_trials):
                run_timing(
                    num_calls=num_calls,
                    bs=bs,
                    d_model=d_model,
                    num_total_experts=num_total_experts,
                    top_k=top_k,
                    tp_size=tp_size,
                    model_intermediate_size=model_intermediate_size,
                    method=method,
                    config=config,
                    dtype=dtype,
                )
        except triton.runtime.autotuner.OutOfResources:
            continue

        # trial
        for _ in range(num_trials):
            kernel_dur_ms = run_timing(
                num_calls=num_calls,
                bs=bs,
                d_model=d_model,
                num_total_experts=num_total_experts,
                top_k=top_k,
                tp_size=tp_size,
                model_intermediate_size=model_intermediate_size,
                method=method,
                config=config,
                dtype=dtype,
            )

            kernel_dur_us = 1000 * kernel_dur_ms
            model_dur_ms = kernel_dur_ms * num_layers

            if kernel_dur_us < best_time_us:
                best_config = config
                best_time_us = kernel_dur_us

            tqdm.write(
                f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}'
                f' {bs=} {tp_size=} {top_k=} {num_total_experts=} '
                f'{d_model=} {model_intermediate_size=} {num_layers=}')

    print("best_time_us", best_time_us)
    print("best_config", best_config)

    # holds Dict[str, Dict[str, int]]
    filename = get_config_file_name(num_total_experts,
                                    model_intermediate_size // tp_size,
                                    "float8" if dtype == "float8" else None)
    print(f"writing config to file {filename}")
    existing_content = {}
    if os.path.exists(filename):
        with open(filename, "r") as f:
            existing_content = json.load(f)
    existing_content[str(bs)] = best_config
    with open(filename, "w") as f:
        json.dump(existing_content, f, indent=4)
        f.write("\n")


def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
               top_k: int, tp_size: int, model_intermediate_size: int, method,
               config, dtype: str) -> float:
    shard_intermediate_size = model_intermediate_size // tp_size

    hidden_states = torch.rand(
        (bs, d_model),
        device="cuda:0",
        dtype=torch.float16,
    )

    w1 = torch.rand(
        (num_total_experts, 2 * shard_intermediate_size, d_model),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    w2 = torch.rand(
        (num_total_experts, d_model, shard_intermediate_size),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    w1_scale = None
    w2_scale = None
    a1_scale = None
    a2_scale = None

    if dtype == "float8":
        w1 = w1.to(torch.float8_e4m3fn)
        w2 = w2.to(torch.float8_e4m3fn)
        w1_scale = torch.ones(num_total_experts,
                              device=hidden_states.device,
                              dtype=torch.float32)
        w2_scale = torch.ones(num_total_experts,
                              device=hidden_states.device,
                              dtype=torch.float32)
        a1_scale = torch.ones(1,
                              device=hidden_states.device,
                              dtype=torch.float32)
        a2_scale = torch.ones(1,
                              device=hidden_states.device,
                              dtype=torch.float32)

    gating_output = F.softmax(torch.rand(
        (num_calls, bs, num_total_experts),
        device=hidden_states.device,
        dtype=torch.float32,
    ),
                              dim=-1)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for i in range(num_calls):
        hidden_states = method(
            hidden_states=hidden_states,
            w1=w1,
            w2=w2,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            gating_output=gating_output[i],
            topk=2,
            renormalize=True,
            inplace=True,
            override_config=config,
            use_fp8=dtype == "float8",
        )
    end_event.record()
    end_event.synchronize()

    dur_ms = start_event.elapsed_time(end_event) / num_calls
    return dur_ms


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog='benchmark_mixtral_moe',
        description='Benchmark and tune the fused_moe kernel',
    )
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['float8', 'float16'],
        help='Data type used for fused_moe kernel computations',
    )
    args = parser.parse_args()
    sys.exit(main(args.dtype))
322
benchmarks/kernels/benchmark_moe.py
Normal file
322
benchmarks/kernels/benchmark_moe.py
Normal file
@@ -0,0 +1,322 @@
import argparse
import json
import time
from datetime import datetime
from typing import Any, Dict, List, Tuple

import ray
import torch
import triton
from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig

# Provides fused_moe, get_moe_configs, get_default_config, and
# get_config_file_name used below.
from vllm.model_executor.layers.fused_moe.fused_moe import *


def benchmark_config(
    config: Dict[str, int],
    num_tokens: int,
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8: bool,
    num_iters: int = 100,
) -> float:
    init_dtype = torch.float16 if use_fp8 else dtype
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    w1 = torch.randn(num_experts,
                     shard_intermediate_size,
                     hidden_size,
                     dtype=init_dtype)
    w2 = torch.randn(num_experts,
                     hidden_size,
                     shard_intermediate_size // 2,
                     dtype=init_dtype)
    gating_output = torch.randn(num_iters,
                                num_tokens,
                                num_experts,
                                dtype=torch.float32)

    w1_scale = None
    w2_scale = None
    a1_scale = None
    a2_scale = None
    if use_fp8:
        w1_scale = torch.randn(num_experts, dtype=torch.float32)
        w2_scale = torch.randn(num_experts, dtype=torch.float32)
        a1_scale = torch.randn(1, dtype=torch.float32)
        a2_scale = torch.randn(1, dtype=torch.float32)

        w1 = w1.to(torch.float8_e4m3fn)
        w2 = w2.to(torch.float8_e4m3fn)

    input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)

    def prepare(i: int):
        input_gating.copy_(gating_output[i])

    def run():
        fused_moe(
            x,
            w1,
            w2,
            input_gating,
            topk,
            renormalize=True,
            inplace=True,
            override_config=config,
            use_fp8=use_fp8,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
        )

    # JIT compilation & warmup
    run()
    torch.cuda.synchronize()

    # Capture 10 invocations with CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(10):
            run()
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    latencies = []
    for i in range(num_iters):
        prepare(i)
        torch.cuda.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    avg = sum(latencies) / (num_iters * 10) * 1000  # us
    graph.reset()
    return avg


def get_configs_compute_bound() -> List[Dict[str, int]]:
    # Reduced search space for faster tuning.
    # TODO(woosuk): Increase the search space and use a performance model to
    # prune the search space.
    configs = []
    for num_stages in [2, 3, 4, 5]:
        for block_m in [16, 32, 64, 128, 256]:
            for block_k in [64, 128, 256]:
                for block_n in [32, 64, 128, 256]:
                    for num_warps in [4, 8]:
                        for group_size in [1, 16, 32, 64]:
                            configs.append({
                                "BLOCK_SIZE_M": block_m,
                                "BLOCK_SIZE_N": block_n,
                                "BLOCK_SIZE_K": block_k,
                                "GROUP_SIZE_M": group_size,
                                "num_warps": num_warps,
                                "num_stages": num_stages,
                            })
    return configs


@ray.remote(num_gpus=1)
class BenchmarkWorker:

    def __init__(self, seed: int) -> None:
        torch.set_default_device("cuda")
        torch.cuda.manual_seed_all(seed)
        self.seed = seed

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8: bool,
    ) -> Tuple[Dict[str, int], float]:
        torch.cuda.manual_seed_all(self.seed)

        dtype_str = "float8" if use_fp8 else None
        # NOTE(woosuk): The current naming convention uses w2.shape[2], which
        # is the intermediate size after silu_and_mul.
        op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
                                    dtype_str)
        if op_config is None:
            config = get_default_config(num_tokens, num_experts,
                                        shard_intermediate_size, hidden_size,
                                        topk, dtype_str)
        else:
            config = op_config[min(op_config.keys(),
                                   key=lambda x: abs(x - num_tokens))]
        kernel_time = benchmark_config(config, num_tokens, num_experts,
                                       shard_intermediate_size, hidden_size,
                                       topk, dtype, use_fp8)
        return config, kernel_time

    def tune(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8: bool,
        search_space: List[Dict[str, int]],
    ) -> Dict[str, int]:
        best_config = None
        best_time = float("inf")
        for config in tqdm(search_space):
            try:
                kernel_time = benchmark_config(config,
                                               num_tokens,
                                               num_experts,
                                               shard_intermediate_size,
                                               hidden_size,
                                               topk,
                                               dtype,
                                               use_fp8,
                                               num_iters=10)
            except triton.runtime.autotuner.OutOfResources:
                # Some configurations may be invalid and fail to compile.
                continue

            if kernel_time < best_time:
                best_time = kernel_time
                best_config = config
        now = datetime.now()
        print(f"[{now.ctime()}] Completed tuning for batch_size={num_tokens}")
        return best_config


def sort_config(config: Dict[str, int]) -> Dict[str, int]:
    return {
        "BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
        "BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
        "BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
        "GROUP_SIZE_M": config["GROUP_SIZE_M"],
        "num_warps": config["num_warps"],
        "num_stages": config["num_stages"],
    }


def save_configs(
    configs: Dict[int, Dict[str, int]],
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8: bool,
) -> None:
    dtype_str = "float8" if use_fp8 else None
    # NOTE(woosuk): The current naming convention uses w2.shape[2], which
    # is the intermediate size after silu_and_mul.
    filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
                                    dtype_str)
    print(f"Writing best config to {filename}...")
    with open(filename, "w") as f:
        json.dump(configs, f, indent=4)
        f.write("\n")


def main(args: argparse.Namespace):
    print(args)

    config = AutoConfig.from_pretrained(args.model)
    if config.architectures[0] == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    else:
        # Default: Mixtral.
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size

    hidden_size = config.hidden_size
    dtype = config.torch_dtype
    use_fp8 = args.dtype == "fp8"

    if args.batch_size is None:
        batch_sizes = [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
            2048, 3072, 4096
        ]
    else:
        batch_sizes = [args.batch_size]

    ray.init()
    num_gpus = int(ray.available_resources()["GPU"])
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: List[Any]) -> List[Any]:
        outputs = []
        worker_idx = 0
        for input_args in inputs:
            worker = workers[worker_idx]
            worker_method = getattr(worker, method)
            output = worker_method.remote(*input_args)
            outputs.append(output)
            worker_idx = (worker_idx + 1) % num_gpus
        return ray.get(outputs)

    if args.tune:
        search_space = get_configs_compute_bound()
        print(f"Start tuning over {len(search_space)} configurations...")

        start = time.time()
        configs = _distribute(
            "tune", [(batch_size, E, shard_intermediate_size, hidden_size,
                      topk, dtype, use_fp8, search_space)
                     for batch_size in batch_sizes])
        best_configs = {
            M: sort_config(config)
            for M, config in zip(batch_sizes, configs)
        }
        save_configs(best_configs, E, shard_intermediate_size, hidden_size,
                     topk, dtype, use_fp8)
        end = time.time()
        print(f"Tuning took {end - start:.2f} seconds")
    else:
        outputs = _distribute("benchmark",
                              [(batch_size, E, shard_intermediate_size,
                                hidden_size, topk, dtype, use_fp8)
                               for batch_size in batch_sizes])

        for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
            print(f"Batch size: {batch_size}, config: {config}")
            print(f"Kernel time: {kernel_time:.2f} us")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model",
                        type=str,
                        default="mistralai/Mixtral-8x7B-Instruct-v0.1")
    parser.add_argument("--tp-size", "-tp", type=int, default=2)
    parser.add_argument("--dtype",
                        type=str,
                        choices=["auto", "fp8"],
                        default="auto")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch-size", type=int, required=False)
    parser.add_argument("--tune", action="store_true")
    args = parser.parse_args()

    main(args)
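save_configs above writes a JSON object keyed by batch size, and readers pick the entry whose key is closest to the live token count, which is exactly what benchmark() does through get_moe_configs. A hedged sketch of that lookup against such a file; load_nearest_config and the configs.json path are illustrative, not names from the diff:

import json

def load_nearest_config(path: str, num_tokens: int) -> dict:
    # JSON stores the batch-size keys as strings; convert them back to ints.
    with open(path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    best_m = min(configs.keys(), key=lambda m: abs(m - num_tokens))
    return configs[best_m]

# e.g. load_nearest_config("configs.json", num_tokens=200) returns the tuned
# config for the closest benchmarked batch size (128 or 256).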
@@ -170,7 +170,7 @@ if __name__ == '__main__':
     parser.add_argument("--num-kv-heads", type=int, default=8)
     parser.add_argument("--head-size",
                         type=int,
-                        choices=[64, 80, 96, 112, 128, 256],
+                        choices=[64, 80, 96, 112, 128, 192, 256],
                         default=128)
     parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
     parser.add_argument("--use-alibi", action="store_true")
@@ -183,13 +183,11 @@ if __name__ == '__main__':
     parser.add_argument(
         "--kv-cache-dtype",
         type=str,
-        choices=["auto", "fp8"],
+        choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
         default="auto",
-        help=
-        'Data type for kv cache storage. If "auto", will use model data type. '
-        'FP8_E5M2 (without scaling) is only supported on cuda version greater '
-        'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
-        'common inference criteria.')
+        help="Data type for kv cache storage. If 'auto', will use model "
+        "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
+        "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
     args = parser.parse_args()
     print(args)

@@ -93,7 +93,7 @@ if __name__ == '__main__':
     parser.add_argument("--num-heads", type=int, default=8)
     parser.add_argument("--head-size",
                         type=int,
-                        choices=[64, 80, 96, 112, 128, 256],
+                        choices=[64, 80, 96, 112, 128, 192, 256],
                         default=128)
     parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
     parser.add_argument("--dtype",

benchmarks/kernels/benchmark_shapes.py (new file, 75 lines)
@@ -0,0 +1,75 @@
WEIGHT_SHAPES = {
    "ideal": [[4 * 256 * 32, 256 * 32]],
    "mistralai/Mistral-7B-v0.1/TP1": [
        [4096, 6144],
        [4096, 4096],
        [4096, 28672],
        [14336, 4096],
    ],
    "mistralai/Mistral-7B-v0.1/TP2": [
        [4096, 3072],
        [2048, 4096],
        [4096, 14336],
        [7168, 4096],
    ],
    "mistralai/Mistral-7B-v0.1/TP4": [
        [4096, 1536],
        [1024, 4096],
        [4096, 7168],
        [3584, 4096],
    ],
    "meta-llama/Llama-2-7b-hf/TP1": [
        [4096, 12288],
        [4096, 4096],
        [4096, 22016],
        [11008, 4096],
    ],
    "meta-llama/Llama-2-7b-hf/TP2": [
        [4096, 6144],
        [2048, 4096],
        [4096, 11008],
        [5504, 4096],
    ],
    "meta-llama/Llama-2-7b-hf/TP4": [
        [4096, 3072],
        [1024, 4096],
        [4096, 5504],
        [2752, 4096],
    ],
    "meta-llama/Llama-2-13b-hf/TP1": [
        [5120, 15360],
        [5120, 5120],
        [5120, 27648],
        [13824, 5120],
    ],
    "meta-llama/Llama-2-13b-hf/TP2": [
        [5120, 7680],
        [2560, 5120],
        [5120, 13824],
        [6912, 5120],
    ],
    "meta-llama/Llama-2-13b-hf/TP4": [
        [5120, 3840],
        [1280, 5120],
        [5120, 6912],
        [3456, 5120],
    ],
    "meta-llama/Llama-2-70b-hf/TP1": [
        [8192, 10240],
        [8192, 8192],
        [8192, 57344],
        [28672, 8192],
    ],
    "meta-llama/Llama-2-70b-hf/TP2": [
        [8192, 5120],
        [4096, 8192],
        [8192, 28672],
        [14336, 8192],
    ],
    "meta-llama/Llama-2-70b-hf/TP4": [
        [8192, 2560],
        [2048, 8192],
        [8192, 14336],
        [7168, 8192],
    ],
}
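WEIGHT_SHAPES is a plain data table; a consumer would iterate one model/TP entry and time a GEMM per shape. The sketch below is a hypothetical consumer, treating each pair as (input dim K, output dim N), which is an assumption consistent with the Mistral and Llama projection sizes listed above; bench_shapes is not a name from this diff:

import torch

def bench_shapes(model_tp: str, m: int = 16) -> None:
    # Time one fp16 GEMM per [K, N] weight shape (no warmup; a real
    # benchmark would add one).
    for k, n in WEIGHT_SHAPES[model_tp]:
        a = torch.randn(m, k, device="cuda", dtype=torch.float16)
        b = torch.randn(k, n, device="cuda", dtype=torch.float16)
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        torch.matmul(a, b)
        end.record()
        end.synchronize()
        print(f"{k}x{n}: {start.elapsed_time(end):.3f} ms")

bench_shapes("mistralai/Mistral-7B-v0.1/TP1")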
@@ -4,7 +4,7 @@ PORT=8000
 MODEL=$1
 TOKENS=$2
 
-docker run --gpus all --shm-size 1g -p $PORT:80 \
+docker run -e HF_TOKEN=$HF_TOKEN --gpus all --shm-size 1g -p $PORT:80 \
     -v $PWD/data:/data \
     ghcr.io/huggingface/text-generation-inference:1.4.0 \
     --model-id $MODEL \

benchmarks/overheads/benchmark_hashing.py (new file, 63 lines)
@@ -0,0 +1,63 @@
import argparse
import cProfile
import pstats

from vllm import LLM, SamplingParams

# A very long prompt, total number of tokens is about 15k.
LONG_PROMPT = ["You are an expert in large language models, aren't you?"
               ] * 1000
LONG_PROMPT = ' '.join(LONG_PROMPT)


def main(args):
    llm = LLM(
        model=args.model,
        enforce_eager=True,
        enable_prefix_caching=True,
        tensor_parallel_size=args.tensor_parallel_size,
        use_v2_block_manager=args.use_v2_block_manager,
    )

    sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
    profiler = cProfile.Profile()

    print("------warm up------")
    for i in range(3):
        output = llm.generate(LONG_PROMPT, sampling_params)
        print(output[0].outputs[0].text)

    print("------start generating------")
    for i in range(3):
        profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)',
                        globals(), locals())

    # analyze the runtime of hashing function
    stats = pstats.Stats(profiler)
    stats.sort_stats('cumulative')
    total_time = 0
    total_calls = 0
    for func in stats.stats:
        if 'hash_of_block' in func[2]:
            total_time = stats.stats[func][3]
            total_calls = stats.stats[func][0]
    percentage = (total_time / stats.total_tt) * 100
    print(f"Hashing took {total_time:.2f} seconds, "
          f"{percentage:.2f}% of the total runtime.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Benchmark the performance of hashing function in '
        'automatic prefix caching.')
    parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k')
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
    parser.add_argument('--output-len', type=int, default=10)
    parser.add_argument('--enable-prefix-caching',
                        action='store_true',
                        help='enable prefix caching')
    parser.add_argument('--use-v2-block-manager',
                        action='store_true',
                        help='Use BlockSpaceMangerV2')
    args = parser.parse_args()
    main(args)
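The stats.stats[func][3] and [0] indexing above relies on the row layout pstats uses internally: each entry maps (filename, lineno, funcname) to (primitive_calls, total_calls, own_time, cumulative_time, callers). A standalone illustration of the same pattern; the slow function is a stand-in, not part of the benchmark:

import cProfile
import pstats

def slow():
    return sum(i * i for i in range(100000))

profiler = cProfile.Profile()
profiler.runctx('slow()', globals(), locals())
stats = pstats.Stats(profiler)
for (filename, lineno, funcname), row in stats.stats.items():
    if funcname == 'slow':
        cc, nc, tt, ct, callers = row  # primitive calls, calls, own, cumulative
        print(f"{funcname}: {nc} calls, {ct:.4f}s cumulative")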
@@ -73,7 +73,7 @@ set(VLLM_EXT_SRC
     "csrc/cpu/cache.cpp"
     "csrc/cpu/layernorm.cpp"
     "csrc/cpu/pos_encoding.cpp"
-    "csrc/cpu/pybind.cpp")
+    "csrc/cpu/torch_bindings.cpp")
 
 define_gpu_extension_target(
   _C
@@ -81,10 +81,10 @@ define_gpu_extension_target(
   LANGUAGE CXX
   SOURCES ${VLLM_EXT_SRC}
   COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
+  USE_SABI 3
   WITH_SOABI
 )
 
 add_custom_target(default)
 message(STATUS "Enabling C extension.")
 add_dependencies(default _C)

@@ -5,7 +5,7 @@
 macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS)
   file(REAL_PATH ${EXECUTABLE} EXECUTABLE)
   set(Python_EXECUTABLE ${EXECUTABLE})
-  find_package(Python COMPONENTS Interpreter Development.Module)
+  find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule)
   if (NOT Python_FOUND)
     message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
   endif()
@@ -99,7 +99,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
       "Failed to determine torch nvcc compiler flags")
 
     if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
-      list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2")
+      list(APPEND GPU_FLAGS "-DENABLE_FP8")
     endif()
     if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
       list(REMOVE_ITEM GPU_FLAGS
@@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
 
     list(APPEND GPU_FLAGS
       "-DUSE_ROCM"
-      "-DENABLE_FP8_E4M3"
+      "-DENABLE_FP8"
       "-U__HIP_NO_HALF_CONVERSIONS__"
      "-U__HIP_NO_HALF_OPERATORS__"
       "-fno-gpu-rdc")
@@ -294,6 +294,7 @@ endmacro()
 # INCLUDE_DIRECTORIES <dirs> - Extra include directories.
 # LIBRARIES <libraries> - Extra link libraries.
 # WITH_SOABI - Generate library with python SOABI suffix name.
+# USE_SABI <version> - Use python stable api <version>
 #
 # Note: optimization level/debug info is set via cmake build type.
 #
@@ -301,7 +302,7 @@ function (define_gpu_extension_target GPU_MOD_NAME)
   cmake_parse_arguments(PARSE_ARGV 1
     GPU
     "WITH_SOABI"
-    "DESTINATION;LANGUAGE"
+    "DESTINATION;LANGUAGE;USE_SABI"
     "SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES")
 
   # Add hipify preprocessing step when building with HIP/ROCm.
@@ -315,7 +316,11 @@ function (define_gpu_extension_target GPU_MOD_NAME)
     set(GPU_WITH_SOABI)
   endif()
 
-  Python_add_library(${GPU_MOD_NAME} MODULE "${GPU_SOURCES}" ${GPU_WITH_SOABI})
+  if (GPU_USE_SABI)
+    Python_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}")
+  else()
+    Python_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}")
+  endif()
 
   if (GPU_LANGUAGE STREQUAL "HIP")
     # Make this target dependent on the hipify preprocessor step.

@@ -64,6 +64,7 @@ DEFAULT_CONDA_PATTERNS = {
     "triton",
     "optree",
     "nccl",
+    "transformers",
 }
 
 DEFAULT_PIP_PATTERNS = {
@@ -75,6 +76,7 @@ DEFAULT_PIP_PATTERNS = {
     "optree",
     "onnx",
     "nccl",
+    "transformers",
 }
@@ -601,6 +603,11 @@ Versions of relevant libraries:
 {conda_packages}
 """.strip()
 
+# both the above code and the following code use `strip()` to
+# remove leading/trailing whitespaces, so we need to add a newline
+# in between to separate the two sections
+env_info_fmt += "\n"
+
 env_info_fmt += """
 ROCM Version: {rocm_version}
 Neuron SDK Version: {neuron_sdk_version}

@@ -1,5 +1,5 @@
 #include <ATen/cuda/CUDAContext.h>
-#include <torch/extension.h>
+#include <torch/all.h>
 #include <c10/cuda/CUDAGuard.h>
 
 #include <cmath>
@@ -63,31 +63,25 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
   VLLM_DISPATCH_FLOATING_TYPES( \
-    input.scalar_type(), \
-    "act_and_mul_kernel", \
-    [&] { \
-      vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
-        out.data_ptr<scalar_t>(), \
-        input.data_ptr<scalar_t>(), \
-        d); \
-    });
+      input.scalar_type(), "act_and_mul_kernel", [&] { \
+        vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
+            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
+                                         input.data_ptr<scalar_t>(), d); \
+      });
 
-void silu_and_mul(
-  torch::Tensor& out,    // [..., d]
-  torch::Tensor& input)  // [..., 2 * d]
+void silu_and_mul(torch::Tensor& out,    // [..., d]
+                  torch::Tensor& input)  // [..., 2 * d]
 {
   LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
 }
 
-void gelu_and_mul(
-  torch::Tensor& out,    // [..., d]
-  torch::Tensor& input)  // [..., 2 * d]
+void gelu_and_mul(torch::Tensor& out,    // [..., d]
+                  torch::Tensor& input)  // [..., 2 * d]
 {
   LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
 }
 
-void gelu_tanh_and_mul(
-  torch::Tensor& out,    // [..., d]
-  torch::Tensor& input)  // [..., 2 * d]
+void gelu_tanh_and_mul(torch::Tensor& out,    // [..., d]
+                       torch::Tensor& input)  // [..., 2 * d]
 {
   LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
@@ -118,14 +112,10 @@ __global__ void activation_kernel(
   dim3 block(std::min(d, 1024)); \
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
-  VLLM_DISPATCH_FLOATING_TYPES( \
-    input.scalar_type(), \
-    "activation_kernel", \
-    [&] { \
-      vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
-        out.data_ptr<scalar_t>(), \
-        input.data_ptr<scalar_t>(), \
-        d); \
-    });
+  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
+    vllm::activation_kernel<scalar_t, KERNEL<scalar_t>> \
+        <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
+                                     input.data_ptr<scalar_t>(), d); \
+  });
 
 namespace vllm {
@@ -140,21 +130,20 @@ __device__ __forceinline__ T gelu_new_kernel(const T& x) {
 template <typename T>
 __device__ __forceinline__ T gelu_fast_kernel(const T& x) {
   const float f = (float)x;
-  const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
+  const T t =
+      (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
   return ((T)0.5) * x * (((T)1.0) + t);
 }
 
 } // namespace vllm
 
-void gelu_new(
-  torch::Tensor& out,    // [..., d]
-  torch::Tensor& input)  // [..., d]
+void gelu_new(torch::Tensor& out,    // [..., d]
+              torch::Tensor& input)  // [..., d]
 {
   LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
 }
 
-void gelu_fast(
-  torch::Tensor& out,    // [..., d]
-  torch::Tensor& input)  // [..., d]
+void gelu_fast(torch::Tensor& out,    // [..., d]
+               torch::Tensor& input)  // [..., d]
 {
   LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
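The reformatted gelu_fast_kernel leaves the math unchanged: it computes the tanh approximation of GELU, 0.5 * x * (1 + tanh(0.79788456 * x * (1 + 0.044715 * x^2))), where 0.79788456 is approximately sqrt(2/pi). A quick NumPy cross-check of the same formula, for illustration only:

import numpy as np

def gelu_fast(x: np.ndarray) -> np.ndarray:
    # Same constants as gelu_fast_kernel above.
    t = np.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x))
    return 0.5 * x * (1.0 + t)

x = np.linspace(-3.0, 3.0, 7)
print(gelu_fast(x))  # tracks the exact GELU closely on this range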

@@ -1,5 +1,6 @@
 /*
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
  * Copyright (c) 2023, The vLLM team.
  * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
  *

(File diff suppressed because it is too large.)
@@ -1,5 +1,6 @@
 /*
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
  * Copyright (c) 2023, The vLLM team.
  * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
  *

@@ -1,6 +1,8 @@
 /*
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
- * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * and
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
  * Copyright (c) 2023, The vLLM team.
  * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
  *
@@ -280,7 +282,8 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
 }
 
 // Vector fused multiply-add.
-inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
+inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b,
+                                     __nv_bfloat162 c) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
   assert(false);
 #else
@@ -288,7 +291,8 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bf
 #endif
 }
 
-inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
+inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
+                                     __nv_bfloat162 c) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
   assert(false);
 #else

@@ -1,6 +1,8 @@
 /*
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
- * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * and
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
  * Copyright (c) 2023, The vLLM team.
  * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
  *
@@ -130,7 +132,9 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
   } tmp;
 #ifndef USE_ROCM
   #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-    asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
+    asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n"
+                 : "=r"(tmp.u32)
+                 : "f"(f.y), "f"(f.x));
   #else
     asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
     asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
@@ -327,9 +331,13 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) {
 inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
   uint32_t d;
 #ifndef USE_ROCM
-  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
+  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
+               : "=r"(d)
+               : "r"(a), "r"(b), "r"(c));
 #else
-  asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
+  asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n"
+               : "=v"(d)
+               : "v"(a), "v"(b), "v"(c));
 #endif
   return d;
 }
@@ -470,13 +478,9 @@ inline __device__ void from_float(uint4& dst, Float8_ src) {
 }
 
 // From float16 to float32.
-inline __device__ float to_float(uint16_t u) {
-  return half_to_float(u);
-}
+inline __device__ float to_float(uint16_t u) { return half_to_float(u); }
 
-inline __device__ float2 to_float(uint32_t u) {
-  return half2_to_float2(u);
-}
+inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); }
 
 inline __device__ Float4_ to_float(uint2 u) {
   Float4_ tmp;
@@ -495,8 +499,6 @@ inline __device__ Float8_ to_float(uint4 u) {
 }
 
 // Zero-out a variable.
-inline __device__ void zero(uint16_t& dst) {
-  dst = uint16_t(0);
-}
+inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); }
 
 } // namespace vllm

@@ -1,6 +1,8 @@
 /*
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
- * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * and
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
  * Copyright (c) 2023, The vLLM team.
  * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
  *
@@ -66,9 +68,7 @@ struct FloatVec<float4> {
 };
 
 // Vector addition.
-inline __device__ float add(float a, float b) {
-  return a + b;
-}
+inline __device__ float add(float a, float b) { return a + b; }
 
 inline __device__ float2 add(float2 a, float2 b) {
   float2 c;
@@ -129,9 +129,7 @@ inline __device__ float4 mul(float a, float4 b) {
 }
 
 // Vector fused multiply-add.
-inline __device__ float fma(float a, float b, float c) {
-  return a * b + c;
-}
+inline __device__ float fma(float a, float b, float c) { return a * b + c; }
 
 inline __device__ float2 fma(float2 a, float2 b, float2 c) {
   float2 d;
@@ -208,9 +206,7 @@ inline __device__ float sum(Float8_ v) {
 }
 
 // Vector dot product.
-inline __device__ float dot(float a, float b) {
-  return a * b;
-}
+inline __device__ float dot(float a, float b) { return a * b; }
 
 inline __device__ float dot(float2 a, float2 b) {
   float2 c = mul<float2, float2, float2>(a, b);
@@ -232,42 +228,24 @@ inline __device__ float dot(Float8_ a, Float8_ b) {
 }
 
 // From float to float.
-inline __device__ void from_float(float& dst, float src) {
-  dst = src;
-}
+inline __device__ void from_float(float& dst, float src) { dst = src; }
 
-inline __device__ void from_float(float2& dst, float2 src) {
-  dst = src;
-}
+inline __device__ void from_float(float2& dst, float2 src) { dst = src; }
 
-inline __device__ void from_float(float4& dst, float4 src) {
-  dst = src;
-}
+inline __device__ void from_float(float4& dst, float4 src) { dst = src; }
 
 // From float to float.
-inline __device__ float to_float(float u) {
-  return u;
-}
+inline __device__ float to_float(float u) { return u; }
 
-inline __device__ float2 to_float(float2 u) {
-  return u;
-}
+inline __device__ float2 to_float(float2 u) { return u; }
 
-inline __device__ float4 to_float(float4 u) {
-  return u;
-}
+inline __device__ float4 to_float(float4 u) { return u; }
 
-inline __device__ Float4_ to_float(Float4_ u) {
-  return u;
-}
+inline __device__ Float4_ to_float(Float4_ u) { return u; }
 
-inline __device__ Float8_ to_float(Float8_ u) {
-  return u;
-}
+inline __device__ Float8_ to_float(Float8_ u) { return u; }
 
 // Zero-out a variable.
-inline __device__ void zero(float& dst) {
-  dst = 0.f;
-}
+inline __device__ void zero(float& dst) { dst = 0.f; }
 
 } // namespace vllm

@@ -3,14 +3,21 @@
 #include "attention_generic.cuh"
 
 #include <stdint.h>
-#ifdef ENABLE_FP8_E5M2
+#ifdef ENABLE_FP8
+  #ifndef USE_ROCM
 #include <cuda_fp8.h>
-#endif
+  #endif  // USE_ROCM
+#endif    // ENABLE_FP8
 
 namespace vllm {
-#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
-// fp8 vector types for quantization of kv cache
 
+enum class Fp8KVCacheDataType {
+  kAuto = 0,
+  kFp8E4M3 = 1,
+  kFp8E5M2 = 2,
+};
+
+// fp8 vector types for quantization of kv cache
 template <>
 struct Vec<uint8_t, 1> {
   using Type = uint8_t;
@@ -30,6 +37,5 @@ template<>
 struct Vec<uint8_t, 8> {
   using Type = uint2;
 };
-#endif // ENABLE_FP8_E5M2
 
 } // namespace vllm
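For reference while reading the cache kernels below, the new Fp8KVCacheDataType values line up with the string choices exposed by --kv-cache-dtype earlier in this diff, with "fp8" resolving to e4m3. A Python mirror of the enum, for illustration only (not an API in the diff):

from enum import IntEnum

class Fp8KVCacheDataType(IntEnum):
    # Mirrors the C++ enum introduced above.
    K_AUTO = 0      # store the kv cache in the model's own dtype
    K_FP8_E4M3 = 1  # "fp8" / "fp8_e4m3"
    K_FP8_E5M2 = 2  # "fp8_e5m2"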

csrc/cache.h
@@ -1,38 +1,32 @@
 #pragma once
 
-#include <torch/extension.h>
+#include <torch/all.h>
 
 #include <map>
 #include <vector>
 
-void swap_blocks(
-  torch::Tensor& src,
-  torch::Tensor& dst,
-  const std::map<int64_t, int64_t>& block_mapping);
+void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
+                 const torch::Tensor& block_mapping);
 
-void copy_blocks(
-  std::vector<torch::Tensor>& key_caches,
-  std::vector<torch::Tensor>& value_caches,
-  const std::map<int64_t, std::vector<int64_t>>& block_mapping);
+// Note: the key_caches and value_caches vectors are constant but
+// not the Tensors they contain. The vectors need to be const refs
+// in order to satisfy pytorch's C++ operator registration code.
+void copy_blocks(std::vector<torch::Tensor> const& key_caches,
+                 std::vector<torch::Tensor> const& value_caches,
+                 const torch::Tensor& block_mapping);
 
-void reshape_and_cache(
-  torch::Tensor& key,
-  torch::Tensor& value,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
-  torch::Tensor& slot_mapping,
-  const std::string& kv_cache_dtype,
-  const float kv_scale);
+void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
+                       torch::Tensor& key_cache, torch::Tensor& value_cache,
+                       torch::Tensor& slot_mapping,
+                       const std::string& kv_cache_dtype,
+                       const double kv_scale);
 
-void reshape_and_cache_flash(
-  torch::Tensor& key,
-  torch::Tensor& value,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
-  torch::Tensor& slot_mapping,
-  const std::string& kv_cache_dtype);
+void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
+                             torch::Tensor& key_cache,
+                             torch::Tensor& value_cache,
+                             torch::Tensor& slot_mapping,
+                             const std::string& kv_cache_dtype);
 
 // Just for unittest
-void convert_fp8(
-  torch::Tensor& src_cache,
-  torch::Tensor& dst_cache);
+void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
+                 const double scale, const std::string& kv_cache_dtype);
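The header change above replaces the std::map block mappings with a single int64 tensor of shape (num_pairs, 2); the kernel changes later in the diff add a check that, for swap_blocks, this tensor must live on the CPU so that the per-element item() reads avoid a GPU-CPU synchronization. A hedged sketch of how a caller would now build it, with illustrative mapping values:

import torch

# Old convention: a {src_block: dst_block} dict.
mapping = {0: 4, 1: 5, 7: 2}

# New convention: an int64 tensor of (src, dst) pairs. For swap_blocks it
# should stay on the CPU, per the NOTE in the kernel below.
block_mapping = torch.tensor(sorted(mapping.items()), dtype=torch.int64)
assert block_mapping.shape == (3, 2)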

@@ -1,13 +1,14 @@
-#include <torch/extension.h>
+#include <torch/all.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 
 #include "cuda_compat.h"
 #include "dispatch_utils.h"
-#if defined(ENABLE_FP8_E5M2)
-#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
-#elif defined(ENABLE_FP8_E4M3)
-#include "quantization/fp8/amd_detail/quant_utils.cuh"
+
+#ifdef USE_ROCM
+  #include "quantization/fp8/amd/quant_utils.cuh"
+#else
+  #include "quantization/fp8/nvidia/quant_utils.cuh"
 #endif
 
 #include <algorithm>
@@ -20,16 +21,13 @@
 typedef __hip_bfloat16 __nv_bfloat16;
 #endif
 
-void swap_blocks(
-  torch::Tensor& src,
-  torch::Tensor& dst,
-  const std::map<int64_t, int64_t>& block_mapping) {
+void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
+                 const torch::Tensor& block_mapping) {
   torch::Device src_device = src.device();
   torch::Device dst_device = dst.device();
   cudaMemcpyKind memcpy_type;
   if (src_device.is_cuda() && dst_device.is_cuda()) {
-    TORCH_CHECK(
-      src_device.index() == dst_device.index(),
-      "src and dst must be on the same GPU");
+    TORCH_CHECK(src_device.index() == dst_device.index(),
+                "src and dst must be on the same GPU");
     memcpy_type = cudaMemcpyDeviceToDevice;
   } else if (src_device.is_cuda() && dst_device.is_cpu()) {
@@ -40,24 +38,27 @@ void swap_blocks(
     TORCH_CHECK(false, "Invalid device combination");
   }
 
+  // NOTE(youkaichao): keep in mind that `block_mapping` should be
+  // a cpu tensor, otherwise every `item` call will require a gpu-cpu
+  // synchronization.
+  TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
+
   char* src_ptr = static_cast<char*>(src.data_ptr());
   char* dst_ptr = static_cast<char*>(dst.data_ptr());
 
   const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
-  const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
+  const at::cuda::OptionalCUDAGuard device_guard(
+      src_device.is_cuda() ? src_device : dst_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   // NOTE(woosuk): This can be slow if the number of blocks is large.
-  for (const auto& pair : block_mapping) {
-    int64_t src_block_number = pair.first;
-    int64_t dst_block_number = pair.second;
+  const int64_t num_blocks = block_mapping.size(0);
+  for (size_t i = 0; i < num_blocks; i++) {
+    int64_t src_block_number = block_mapping[i][0].item<int64_t>();
+    int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
     int64_t src_offset = src_block_number * block_size_in_bytes;
     int64_t dst_offset = dst_block_number * block_size_in_bytes;
-    cudaMemcpyAsync(
-      dst_ptr + dst_offset,
-      src_ptr + src_offset,
-      block_size_in_bytes,
-      memcpy_type,
-      stream);
+    cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset,
+                    block_size_in_bytes, memcpy_type, stream);
   }
 }
@@ -65,8 +66,7 @@ namespace vllm {
 
 // Grid: (num_layers, num_pairs)
 template <typename scalar_t>
-__global__ void copy_blocks_kernel(
-  int64_t* key_cache_ptrs,
+__global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
   int64_t* value_cache_ptrs,
   const int64_t* __restrict__ block_mapping,
   const int numel_per_block) {
@@ -74,7 +74,8 @@ __global__ void copy_blocks_kernel(
   const int pair_idx = blockIdx.y;
 
   scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
-  scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
+  scalar_t* value_cache =
+      reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
   int64_t src_block_number = block_mapping[2 * pair_idx];
   int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
@@ -94,10 +95,12 @@ __global__ void copy_blocks_kernel(
 
 } // namespace vllm
 
-void copy_blocks(
-  std::vector<torch::Tensor>& key_caches,
-  std::vector<torch::Tensor>& value_caches,
-  const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
+// Note: the key_caches and value_caches vectors are constant but
+// not the Tensors they contain. The vectors need to be const refs
+// in order to satisfy pytorch's C++ operator registration code.
+void copy_blocks(std::vector<torch::Tensor> const& key_caches,
+                 std::vector<torch::Tensor> const& value_caches,
+                 const torch::Tensor& block_mapping) {
   int num_layers = key_caches.size();
   TORCH_CHECK(num_layers == value_caches.size());
   if (num_layers == 0) {
@@ -111,29 +114,23 @@ void copy_blocks(
   int64_t key_cache_ptrs[num_layers];
   int64_t value_cache_ptrs[num_layers];
   for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
-    key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
-    value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
+    key_cache_ptrs[layer_idx] =
+        reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
+    value_cache_ptrs[layer_idx] =
+        reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
   }
-  // Create block mapping array.
-  std::vector<int64_t> block_mapping_vec;
-  for (const auto& pair : block_mapping) {
-    int64_t src_block_number = pair.first;
-    for (int64_t dst_block_number : pair.second) {
-      block_mapping_vec.push_back(src_block_number);
-      block_mapping_vec.push_back(dst_block_number);
-    }
-  }
-  int64_t* block_mapping_array = block_mapping_vec.data();
-  int num_pairs = block_mapping_vec.size() / 2;
+
+  // block_mapping is a 2D tensor with shape (num_pairs, 2).
+  int num_pairs = block_mapping.size(0);
 
   // Move the data structures to the GPU.
   // NOTE: This synchronizes the CPU and GPU.
-  torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
-    key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
-  torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
-    value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
-  torch::Tensor block_mapping_tensor = torch::from_blob(
-    block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);
+  torch::Tensor key_cache_ptrs_tensor =
+      torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64)
+          .to(cache_device);
+  torch::Tensor value_cache_ptrs_tensor =
+      torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64)
+          .to(cache_device);
 
   // Launch the kernel.
   const int numel_per_block = key_caches[0][0].numel();
@@ -146,26 +143,23 @@ void copy_blocks(
     vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
       key_cache_ptrs_tensor.data_ptr<int64_t>(),
       value_cache_ptrs_tensor.data_ptr<int64_t>(),
-      block_mapping_tensor.data_ptr<int64_t>(),
-      numel_per_block);
+      block_mapping.data_ptr<int64_t>(), numel_per_block);
   }));
 }
 
 namespace vllm {
 
-template<typename scalar_t, typename cache_t, bool is_fp8_kv_cache>
+template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 __global__ void reshape_and_cache_kernel(
   const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
   const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
-  cache_t* __restrict__ key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
-  cache_t* __restrict__ value_cache,   // [num_blocks, num_heads, head_size, block_size]
+  cache_t* __restrict__ key_cache,    // [num_blocks, num_heads, head_size/x,
+                                      // block_size, x]
+  cache_t* __restrict__ value_cache,  // [num_blocks, num_heads, head_size,
+                                      // block_size]
   const int64_t* __restrict__ slot_mapping,  // [num_tokens]
-  const int key_stride,
-  const int value_stride,
-  const int num_heads,
-  const int head_size,
-  const int block_size,
-  const int x,
+  const int key_stride, const int value_stride, const int num_heads,
+  const int head_size, const int block_size, const int x,
   const float kv_scale) {
   const int64_t token_idx = blockIdx.x;
   const int64_t slot_idx = slot_mapping[token_idx];
@@ -187,30 +181,24 @@ __global__ void reshape_and_cache_kernel(
     const int x_idx = head_offset / x;
     const int x_offset = head_offset % x;
 
-    const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
-                                + head_idx * (head_size / x) * block_size * x
-                                + x_idx * block_size * x
-                                + block_offset * x
-                                + x_offset;
-    const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
-                                  + head_idx * head_size * block_size
-                                  + head_offset * block_size
-                                  + block_offset;
+    const int64_t tgt_key_idx =
+        block_idx * num_heads * (head_size / x) * block_size * x +
+        head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
+        block_offset * x + x_offset;
+    const int64_t tgt_value_idx =
+        block_idx * num_heads * head_size * block_size +
+        head_idx * head_size * block_size + head_offset * block_size +
+        block_offset;
     scalar_t tgt_key = key[src_key_idx];
     scalar_t tgt_value = value[src_value_idx];
-    if constexpr (is_fp8_kv_cache) {
-#if defined(ENABLE_FP8_E5M2)
-      key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
-      value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
-#elif defined(ENABLE_FP8_E4M3)
-      key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_key, kv_scale);
-      value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_value, kv_scale);
-#else
-      assert(false);
-#endif
-    } else {
+    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
       key_cache[tgt_key_idx] = tgt_key;
       value_cache[tgt_value_idx] = tgt_value;
+    } else {
+      key_cache[tgt_key_idx] =
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
+      value_cache[tgt_value_idx] =
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
     }
   }
 }
@@ -219,15 +207,13 @@ template<typename scalar_t>
 __global__ void reshape_and_cache_flash_kernel(
   const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
   const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ k_cache,      // [num_blocks, block_size, num_heads, head_size]
-  scalar_t* __restrict__ v_cache,      // [num_blocks, block_size, num_heads, head_size]
+  scalar_t* __restrict__ k_cache,    // [num_blocks, block_size, num_heads,
+                                     // head_size]
+  scalar_t* __restrict__ v_cache,    // [num_blocks, block_size, num_heads,
+                                     // head_size]
   const int64_t* __restrict__ slot_mapping,  // [num_tokens]
-  const int block_stride,
-  const int key_stride,
-  const int value_stride,
-  const int num_heads,
-  const int head_size,
-  const int block_size) {
+  const int block_stride, const int key_stride, const int value_stride,
+  const int num_heads, const int head_size, const int block_size) {
   const int64_t token_idx = blockIdx.x;
   const int64_t slot_idx = slot_mapping[token_idx];
   // NOTE: slot_idx can be -1 if the token is padded
@@ -242,40 +228,37 @@ __global__ void reshape_and_cache_flash_kernel(
     const int64_t src_value_idx = token_idx * value_stride + i;
     const int head_idx = i / head_size;
     const int head_offset = i % head_size;
-    const int64_t tgt_value_idx = block_idx * block_stride
-                                  + block_offset * num_heads * head_size
-                                  + head_idx * head_size
-                                  + head_offset;
+    const int64_t tgt_value_idx = block_idx * block_stride +
+                                  block_offset * num_heads * head_size +
+                                  head_idx * head_size + head_offset;
     k_cache[tgt_value_idx] = key[src_key_idx];
     v_cache[tgt_value_idx] = value[src_value_idx];
   }
 }
 } // namespace vllm
 
-#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \
-  vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_KV_CACHE><<<grid, block, 0, stream>>>( \
-    reinterpret_cast<KV_T*>(key.data_ptr()), \
-    reinterpret_cast<KV_T*>(value.data_ptr()), \
+// KV_T is the stored data type of kv-cache.
+// CACHE_T is the data type of key and value tensors.
+// KV_DTYPE is the real data type of kv-cache.
+#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
+  vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
+      <<<grid, block, 0, stream>>>( \
+          reinterpret_cast<KV_T*>(key.data_ptr()), \
+          reinterpret_cast<KV_T*>(value.data_ptr()), \
|
||||||
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||||
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
||||||
slot_mapping.data_ptr<int64_t>(), \
|
slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
|
||||||
key_stride, \
|
num_heads, head_size, block_size, x, kv_scale);
|
||||||
value_stride, \
|
|
||||||
num_heads, \
|
|
||||||
head_size, \
|
|
||||||
block_size, \
|
|
||||||
x, \
|
|
||||||
kv_scale);
|
|
||||||
|
|
||||||
void reshape_and_cache(
|
void reshape_and_cache(
|
||||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||||
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
||||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
torch::Tensor&
|
||||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||||
|
torch::Tensor&
|
||||||
|
value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||||
torch::Tensor& slot_mapping, // [num_tokens]
|
torch::Tensor& slot_mapping, // [num_tokens]
|
||||||
const std::string& kv_cache_dtype,
|
const std::string& kv_cache_dtype, const double kv_scale) {
|
||||||
const float kv_scale)
|
|
||||||
{
|
|
||||||
int num_tokens = key.size(0);
|
int num_tokens = key.size(0);
|
||||||
int num_heads = key.size(1);
|
int num_heads = key.size(1);
|
||||||
int head_size = key.size(2);
|
int head_size = key.size(2);
|
||||||
@@ -289,25 +272,9 @@ void reshape_and_cache(
|
|||||||
dim3 block(std::min(num_heads * head_size, 512));
|
dim3 block(std::min(num_heads * head_size, 512));
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
if (kv_cache_dtype == "auto") {
|
|
||||||
if (key.dtype() == at::ScalarType::Float) {
|
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
|
||||||
CALL_RESHAPE_AND_CACHE(float, float, false);
|
CALL_RESHAPE_AND_CACHE)
|
||||||
} else if (key.dtype() == at::ScalarType::Half) {
|
|
||||||
CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
|
|
||||||
} else if (key.dtype() == at::ScalarType::BFloat16) {
|
|
||||||
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
|
|
||||||
}
|
|
||||||
} else if (kv_cache_dtype == "fp8") {
|
|
||||||
if (key.dtype() == at::ScalarType::Float) {
|
|
||||||
CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
|
|
||||||
} else if (key.dtype() == at::ScalarType::Half) {
|
|
||||||
CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
|
|
||||||
} else if (key.dtype() == at::ScalarType::BFloat16) {
|
|
||||||
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void reshape_and_cache_flash(
|
void reshape_and_cache_flash(
|
||||||
@@ -316,8 +283,7 @@ void reshape_and_cache_flash(
|
|||||||
torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size]
|
torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||||
torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size]
|
torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||||
torch::Tensor& slot_mapping, // [num_tokens]
|
torch::Tensor& slot_mapping, // [num_tokens]
|
||||||
const std::string& kv_cache_dtype)
|
const std::string& kv_cache_dtype) {
|
||||||
{
|
|
||||||
// FIXME: only support auto datatype, does not support fp8
|
// FIXME: only support auto datatype, does not support fp8
|
||||||
if (kv_cache_dtype != "auto") {
|
if (kv_cache_dtype != "auto") {
|
||||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
|
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
|
||||||
@@ -337,62 +303,46 @@ void reshape_and_cache_flash(
|
|||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
key.scalar_type(),
|
key.scalar_type(), "reshape_and_cache_flash", [&] {
|
||||||
"reshape_and_cache_flash",
|
vllm::reshape_and_cache_flash_kernel<scalar_t>
|
||||||
[&] {
|
<<<grid, block, 0, stream>>>(
|
||||||
vllm::reshape_and_cache_flash_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
|
||||||
key.data_ptr<scalar_t>(),
|
k_cache.data_ptr<scalar_t>(), v_cache.data_ptr<scalar_t>(),
|
||||||
value.data_ptr<scalar_t>(),
|
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,
|
||||||
k_cache.data_ptr<scalar_t>(),
|
value_stride, num_heads, head_size, block_size);
|
||||||
v_cache.data_ptr<scalar_t>(),
|
|
||||||
slot_mapping.data_ptr<int64_t>(),
|
|
||||||
block_stride,
|
|
||||||
key_stride,
|
|
||||||
value_stride,
|
|
||||||
num_heads,
|
|
||||||
head_size,
|
|
||||||
block_size);
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
|
||||||
template<typename Tout, typename Tin>
|
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||||
__global__ void convert_fp8_kernel(
|
__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
|
||||||
const Tin* __restrict__ src_cache,
|
|
||||||
Tout* __restrict__ dst_cache,
|
Tout* __restrict__ dst_cache,
|
||||||
|
const float kv_scale,
|
||||||
const int64_t block_stride) {
|
const int64_t block_stride) {
|
||||||
const int64_t block_idx = blockIdx.x;
|
const int64_t block_idx = blockIdx.x;
|
||||||
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
|
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
|
||||||
int64_t idx = block_idx * block_stride + i;
|
int64_t idx = block_idx * block_stride + i;
|
||||||
#if defined(ENABLE_FP8_E5M2)
|
dst_cache[idx] =
|
||||||
dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
|
fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale);
|
||||||
#elif defined(ENABLE_FP8_E4M3)
|
|
||||||
dst_cache[idx] = fp8_e4m3::vec_conversion<Tout, Tin>(src_cache[idx]);
|
|
||||||
#else
|
|
||||||
assert(false);
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
#define CALL_CONVERT_FP8(Tout, Tin) \
|
#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \
|
||||||
vllm::convert_fp8_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
|
vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
|
||||||
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
|
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
|
||||||
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
|
reinterpret_cast<Tout*>(dst_cache.data_ptr()), kv_scale, block_stride);
|
||||||
block_stride);
|
|
||||||
|
|
||||||
void convert_fp8(
|
// Only for testing.
|
||||||
torch::Tensor& src_cache,
|
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
|
||||||
torch::Tensor& dst_cache)
|
const double kv_scale, const std::string& kv_cache_dtype) {
|
||||||
{
|
|
||||||
torch::Device src_device = src_cache.device();
|
torch::Device src_device = src_cache.device();
|
||||||
torch::Device dst_device = dst_cache.device();
|
torch::Device dst_device = dst_cache.device();
|
||||||
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
|
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
|
||||||
TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
|
TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(src_device.index() == dst_device.index(),
|
||||||
src_device.index() == dst_device.index(),
|
|
||||||
"src and dst must be on the same GPU");
|
"src and dst must be on the same GPU");
|
||||||
at::cuda::OptionalCUDAGuard device_guard(src_device);
|
at::cuda::OptionalCUDAGuard device_guard(src_device);
|
||||||
|
|
||||||
@@ -403,17 +353,37 @@ void convert_fp8(
|
|||||||
dim3 block(std::min(block_stride, int64_t(512)));
|
dim3 block(std::min(block_stride, int64_t(512)));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
|
if (kv_cache_dtype == "auto") {
|
||||||
if (src_cache.dtype() == at::ScalarType::Float) {
|
if (src_cache.dtype() == at::ScalarType::Float) {
|
||||||
CALL_CONVERT_FP8(uint8_t, float);
|
CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto);
|
||||||
} else if (src_cache.dtype() == at::ScalarType::Half) {
|
} else if (src_cache.dtype() == at::ScalarType::Half) {
|
||||||
CALL_CONVERT_FP8(uint8_t, uint16_t);
|
CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);
|
||||||
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
|
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
|
||||||
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16);
|
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);
|
||||||
} else if (dst_cache.dtype() == at::ScalarType::Float) {
|
} else if (dst_cache.dtype() == at::ScalarType::Float) {
|
||||||
CALL_CONVERT_FP8(float, uint8_t);
|
CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
|
||||||
} else if (dst_cache.dtype() == at::ScalarType::Half) {
|
} else if (dst_cache.dtype() == at::ScalarType::Half) {
|
||||||
CALL_CONVERT_FP8(uint16_t, uint8_t);
|
CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
|
||||||
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
|
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
|
||||||
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t);
|
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
|
||||||
|
}
|
||||||
|
} else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
|
||||||
|
if (src_cache.dtype() == at::ScalarType::Float) {
|
||||||
|
CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3);
|
||||||
|
} else if (src_cache.dtype() == at::ScalarType::Half) {
|
||||||
|
CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
|
||||||
|
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
|
||||||
|
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16,
|
||||||
|
vllm::Fp8KVCacheDataType::kFp8E4M3);
|
||||||
|
} else if (dst_cache.dtype() == at::ScalarType::Float) {
|
||||||
|
CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
|
||||||
|
} else if (dst_cache.dtype() == at::ScalarType::Half) {
|
||||||
|
CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
|
||||||
|
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
|
||||||
|
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
|
||||||
|
vllm::Fp8KVCacheDataType::kFp8E4M3);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
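The hunks above swap the old compile-time switches (the ENABLE_FP8_E5M2 / ENABLE_FP8_E4M3 preprocessor blocks plus a hand-written dtype if/else chain) for an Fp8KVCacheDataType template parameter and a single DISPATCH_BY_KV_CACHE_DTYPE macro. Below is a minimal, self-contained sketch of that dispatch pattern; the enum values mirror the diff, but the macro body and scaled_convert here are simplified stand-ins, not vLLM's actual definitions.

    // dispatch_sketch.cpp - simplified stand-in, not vLLM's real header.
    #include <cstdint>
    #include <stdexcept>
    #include <string>

    enum class Fp8KVCacheDataType { kAuto, kFp8E4M3 };

    // The kernels branch on kv_dt with if constexpr, so each template
    // instantiation compiles only the conversion path it needs.
    template <typename CacheT, typename ScalarT, Fp8KVCacheDataType kv_dt>
    CacheT scaled_convert(ScalarT v, float kv_scale) {
      if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
        return static_cast<CacheT>(v);             // plain store, no quantization
      } else {
        return static_cast<CacheT>(v / kv_scale);  // stand-in for fp8 quantization
      }
    }

    // One macro maps the runtime cache-dtype string onto a compile-time
    // (KV_T, CACHE_T, KV_DTYPE) triple and expands FN once per case.
    #define DISPATCH_BY_KV_CACHE_DTYPE_SKETCH(DTYPE_STR, FN)      \
      do {                                                        \
        if ((DTYPE_STR) == std::string("auto")) {                 \
          FN(float, float, Fp8KVCacheDataType::kAuto);            \
        } else if ((DTYPE_STR) == std::string("fp8")) {           \
          FN(float, uint8_t, Fp8KVCacheDataType::kFp8E4M3);       \
        } else {                                                  \
          throw std::runtime_error("unsupported kv cache dtype"); \
        }                                                         \
      } while (0)

In the diff, CALL_RESHAPE_AND_CACHE plays the role of FN, so supporting a new cache dtype means extending the dispatch macro once instead of editing every launcher's if/else ladder.
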
@@ -81,12 +81,10 @@ void silu_and_mul(torch::Tensor &out, torch::Tensor &input) {
   int num_tokens = input.numel() / input.size(-1);
   int d = input.size(-1) / 2;

-  VLLM_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "silu_and_mul_impl", [&] {
-        CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
-        activation_kernel<scalar_t, silu_act, true>(num_tokens, d,
-                                                    input.data_ptr<scalar_t>(),
-                                                    out.data_ptr<scalar_t>());
-        CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
-      });
+  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_impl", [&] {
+    CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
+    activation_kernel<scalar_t, silu_act, true>(
+        num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
+    CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
+  });
 }
@@ -97,12 +95,10 @@ void gelu_and_mul(torch::Tensor &out, // [..., d]
   int num_tokens = input.numel() / input.size(-1);
   int d = input.size(-1) / 2;

-  VLLM_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "gelu_and_mul_impl", [&] {
-        CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
-        activation_kernel<scalar_t, gelu_act, true>(num_tokens, d,
-                                                    input.data_ptr<scalar_t>(),
-                                                    out.data_ptr<scalar_t>());
-        CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
-      });
+  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul_impl", [&] {
+    CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
+    activation_kernel<scalar_t, gelu_act, true>(
+        num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
+    CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
+  });
 }

@@ -2,7 +2,8 @@

 namespace {

-template <typename scalar_t> struct KernelVecType {
+template <typename scalar_t>
+struct KernelVecType {
   using q_load_vec_type = void;
   using q_vec_type = void;
   using k_load_vec_type = void;
@@ -11,7 +12,8 @@ template <typename scalar_t> struct KernelVecType {
   using v_load_vec_type = void;
 };

-template <> struct KernelVecType<float> {
+template <>
+struct KernelVecType<float> {
   using q_load_vec_type = vec_op::FP32Vec4;
   using q_vec_type = vec_op::FP32Vec16;
   using k_load_vec_type = vec_op::FP32Vec16;
@@ -21,7 +23,8 @@ template <> struct KernelVecType<float> {
 };

 #ifdef __AVX512BF16__
-template <> struct KernelVecType<c10::BFloat16> {
+template <>
+struct KernelVecType<c10::BFloat16> {
   using q_load_vec_type = vec_op::BF16Vec8;
   using q_vec_type = vec_op::BF16Vec32;
   using k_load_vec_type = vec_op::BF16Vec32;
@@ -30,7 +33,8 @@ template <> struct KernelVecType<c10::BFloat16> {
   using v_load_vec_type = vec_op::BF16Vec16;
 };
 #else
-template <> struct KernelVecType<c10::BFloat16> {
+template <>
+struct KernelVecType<c10::BFloat16> {
   using q_load_vec_type = vec_op::BF16Vec8;
   using q_vec_type = vec_op::FP32Vec16;
   using k_load_vec_type = vec_op::BF16Vec16;
@@ -67,9 +71,10 @@ FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size,
 }

 template <typename T>
-FORCE_INLINE std::pair<T, T>
-reduceSoftmaxAlibi(T *data, const int size, const int capacity,
-                   const float alibi_slope, const int start_index,
+FORCE_INLINE std::pair<T, T> reduceSoftmaxAlibi(T* data, const int size,
+                                                const int capacity,
+                                                const float alibi_slope,
+                                                const int start_index,
                                                 const int seq_len) {
   data[0] += alibi_slope * (start_index - seq_len + 1);
   T max = data[0];
@@ -215,16 +220,16 @@ FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block,
 namespace {
 template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE>
 struct paged_attention_v1_impl {
-  static void
-  call(scalar_t *__restrict__ out,           // [num_seqs, num_heads, head_size]
+  static void call(
+      scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size]
       const scalar_t* __restrict__ q,        // [num_seqs, num_heads, head_size]
       const scalar_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
                                              // head_size/x, block_size, x]
       const scalar_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
                                              // head_size, block_size]
       const int num_kv_heads, const float scale,
-      const int
-          *__restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
+      const int* __restrict__ block_tables,  // [num_seqs,
+                                             // max_num_blocks_per_seq]
       const int* __restrict__ seq_lens,  // [num_seqs]
       const int max_num_blocks_per_seq,
       const float* __restrict__ alibi_slopes,  // [num_heads]
@@ -257,8 +262,7 @@ struct paged_attention_v1_impl {
       const int64_t kv_head_idx = head_idx / num_queries_per_kv;
       const scalar_t* __restrict__ q_vec_ptr =
           q + seq_idx * q_stride + head_idx * HEAD_SIZE;
-      const int last_block_token_num =
-          seq_len - (block_num - 1) * BLOCK_SIZE;
+      const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE;
       float* __restrict__ thread_block_logits =
           logits + omp_get_thread_num() * max_seq_len_padded;

@@ -282,8 +286,7 @@ struct paged_attention_v1_impl {
                            block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0,
                            seq_len);
       } else {
-        reduceSoftmax(thread_block_logits, seq_len,
-                      block_num * BLOCK_SIZE);
+        reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE);
       }

       // Compute value
@@ -348,8 +351,8 @@ template <typename T, int BLOCK_SIZE>
 void paged_attention_v1_impl_launcher(
     torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int num_kv_heads, float scale,
-    torch::Tensor &block_tables, torch::Tensor &seq_lens,
-    int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) {
+    torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
+    const c10::optional<torch::Tensor>& alibi_slopes) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -387,6 +390,9 @@ void paged_attention_v1_impl_launcher(
     case 128:
       LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
       break;
+    case 192:
+      LAUNCH_V1_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
+      break;
     case 256:
       LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
       break;
@@ -412,15 +418,18 @@ void paged_attention_v1_impl_launcher(
 }
 }  // namespace

-void paged_attention_v1(torch::Tensor &out, torch::Tensor &query,
-                        torch::Tensor &key_cache, torch::Tensor &value_cache,
-                        int num_kv_heads, float scale,
-                        torch::Tensor &block_tables,
-                        torch::Tensor &seq_lens, int block_size,
-                        int max_seq_len,
-                        const c10::optional<torch::Tensor> &alibi_slopes,
-                        const std::string &kv_cache_dtype, float kv_scale) {
+void paged_attention_v1(
+    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
+    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
+    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
+    const int64_t blocksparse_local_blocks,
+    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
+    const int64_t blocksparse_head_sliding_step) {
   TORCH_CHECK(kv_scale == 1.0f);
+  TORCH_CHECK(blocksparse_vert_stride <= 1,
+              "CPU backend does not support blocksparse attention yet.");
   VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
                                [&] {
                                  CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
@@ -435,9 +444,10 @@ template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int PARTITION_SIZE>
 struct paged_attention_v2_impl {
   static void call(
       scalar_t* __restrict__ out,    // [num_seqs, num_heads, head_size]
-      float *__restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
-      float
-          *__restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
+      float* __restrict__ exp_sums,    // [num_seqs, num_heads,
+                                       // max_num_partitions]
+      float* __restrict__ max_logits,  // [num_seqs, num_heads,
+                                       // max_num_partitions]
       scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
                                        // max_num_partitions, head_size]
       const scalar_t* __restrict__ q,  // [num_seqs, num_heads, head_size]
@@ -446,8 +456,8 @@ struct paged_attention_v2_impl {
       const scalar_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
                                              // head_size, block_size]
       const int num_kv_heads, const float scale,
-      const int
-          *__restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
+      const int* __restrict__ block_tables,  // [num_seqs,
+                                             // max_num_blocks_per_seq]
       const int* __restrict__ seq_lens,  // [num_seqs]
       const int max_num_blocks_per_seq,
       const float* __restrict__ alibi_slopes,  // [num_heads]
@@ -468,8 +478,7 @@ struct paged_attention_v2_impl {
         const int seq_len = seq_lens[seq_idx];
         const int start_token_idx = partition_idx * PARTITION_SIZE;

-        if (start_token_idx >= seq_len)
-          continue;
+        if (start_token_idx >= seq_len) continue;

         const int partition_num =
             (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
@@ -477,8 +486,7 @@ struct paged_attention_v2_impl {
         const int token_num =
             (std::min(seq_len, start_token_idx + PARTITION_SIZE) -
              start_token_idx);
-        const int block_num =
-            (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
+        const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
         const int last_block_token_num =
             token_num - (block_num - 1) * BLOCK_SIZE;
         const int* seq_block_table = block_tables +
@@ -510,8 +518,8 @@ struct paged_attention_v2_impl {
                                  logits, token_num, block_num * BLOCK_SIZE,
                                  alibi_slopes[head_idx], start_token_idx, seq_len);
         } else {
-          max_and_sum = reduceSoftmax(logits, token_num,
-                                      block_num * BLOCK_SIZE);
+          max_and_sum =
+              reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE);
         }

         auto&& [max_logit, exp_sum] = max_and_sum;
@@ -587,8 +595,7 @@ struct paged_attention_v2_impl {
         const int partition_num =
             (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;

-        if (partition_num == 1)
-          continue;
+        if (partition_num == 1) continue;

         reducePartitonSoftmax(
             max_logits + seq_idx * num_heads * max_num_partitions +
@@ -603,8 +610,8 @@ struct paged_attention_v2_impl {
     using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
     static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE);
     constexpr int head_elem_num_per_group =
-        16;  // Note: didn't align with the cacheline size, due to some HEAD_SIZE
-             // didn't align with 64 bytes
+        16;  // Note: didn't align with the cacheline size, due to some
+             // HEAD_SIZE didn't align with 64 bytes
     static_assert(HEAD_SIZE % head_elem_num_per_group == 0);
     constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group;
     const float* __restrict__ rescale_factors = exp_sums;
@@ -616,8 +623,7 @@ struct paged_attention_v2_impl {
        const int partition_num =
             (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;

-        if (partition_num == 1)
-          continue;
+        if (partition_num == 1) continue;

         const float* __restrict__ seq_head_rescale_factors =
             rescale_factors + seq_idx * num_heads * max_num_partitions +
@@ -701,6 +707,9 @@ void paged_attention_v2_impl_launcher(
     case 128:
       LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
       break;
+    case 192:
+      LAUNCH_V2_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
+      break;
     case 256:
       LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
       break;
@@ -713,8 +722,8 @@ void paged_attention_v2_impl_launcher(
 #define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE)                              \
   paged_attention_v2_impl_launcher<T, BLOCK_SIZE>(                          \
       out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,    \
-      num_kv_heads, scale, block_tables, seq_lens, block_size,              \
-      max_seq_len, alibi_slopes);
+      num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \
+      alibi_slopes);

 #define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
   switch (block_size) {                       \
@@ -727,16 +736,19 @@ void paged_attention_v2_impl_launcher(
 }
 }  // namespace

-void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums,
-                        torch::Tensor &max_logits, torch::Tensor &tmp_out,
-                        torch::Tensor &query, torch::Tensor &key_cache,
-                        torch::Tensor &value_cache, int num_kv_heads,
-                        float scale, torch::Tensor &block_tables,
-                        torch::Tensor &seq_lens, int block_size,
-                        int max_seq_len,
-                        const c10::optional<torch::Tensor> &alibi_slopes,
-                        const std::string &kv_cache_dtype, float kv_scale) {
+void paged_attention_v2(
+    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
+    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
+    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
+    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
+    const int64_t blocksparse_local_blocks,
+    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
+    const int64_t blocksparse_head_sliding_step) {
   TORCH_CHECK(kv_scale == 1.0f);
+  TORCH_CHECK(blocksparse_vert_stride <= 1,
+              "CPU backend does not support blocksparse attention yet.");
   VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
                                [&] {
                                  CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)

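One detail worth noting in the two launcher signatures above: every scalar argument widens in the new version (int to int64_t, float to double). This matches how Torch custom-op schemas type their arguments; in the schema language, int and float denote 64-bit integers and doubles, so functions registered against a schema (as csrc/cpu/torch_bindings.cpp does later in this diff) must accept the widened types and narrow internally. A trivial illustration, with a hypothetical entry point:

    #include <cstdint>

    // Schema "int"/"float" arguments arrive as int64_t/double; the
    // kernel-facing code narrows once at the boundary.
    void entry_point(int64_t block_size, double scale) {
      const int block_size_i32 = static_cast<int>(block_size);
      const float scale_f32 = static_cast<float>(scale);
      (void)block_size_i32;  // from here, int/float kernels are called
      (void)scale_f32;
    }
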
@@ -5,19 +5,20 @@

 namespace {
 template <typename scalar_t>
-void copy_blocks_cpu_impl(
-    std::vector<torch::Tensor> &key_caches,
-    std::vector<torch::Tensor> &value_caches,
-    const std::vector<std::pair<int64_t, int64_t>> mapping_pairs,
-    const int element_num_per_block, const int layer_num) {
-  const size_t pair_num = mapping_pairs.size();
+void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
+                          std::vector<torch::Tensor> const& value_caches,
+                          const torch::Tensor& mapping_pairs,
+                          const int element_num_per_block,
+                          const int layer_num) {
+  const size_t pair_num = mapping_pairs.size(0);
   const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
 #pragma omp parallel for collapse(2)
   for (int layer = 0; layer < layer_num; ++layer) {
     for (size_t pair = 0; pair < pair_num; ++pair) {
-      int64_t source_offset = element_num_per_block * mapping_pairs[pair].first;
+      int64_t source_offset =
+          element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
       int64_t target_offset =
-          element_num_per_block * mapping_pairs[pair].second;
+          element_num_per_block * mapping_pairs[pair][1].item<int64_t>();
       scalar_t* key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
       scalar_t* source_ptr = key_cache_ptr + source_offset;
       scalar_t* target_ptr = key_cache_ptr + target_offset;
@@ -81,28 +82,23 @@ void reshape_and_cache_cpu_impl(
 }
 };  // namespace

-void copy_blocks(std::vector<torch::Tensor> &key_caches,
-                 std::vector<torch::Tensor> &value_caches,
-                 const std::map<int64_t, std::vector<int64_t>> &block_mapping) {
-  int num_layers = key_caches.size();
+// Note: the key_caches and value_caches vectors are constant but
+// not the Tensors they contain. The vectors need to be const refs
+// in order to satisfy pytorch's C++ operator registration code.
+void copy_blocks(std::vector<torch::Tensor> const& key_caches,
+                 std::vector<torch::Tensor> const& value_caches,
+                 const torch::Tensor& block_mapping) {
+  unsigned num_layers = key_caches.size();
   TORCH_CHECK(num_layers == value_caches.size());
   if (num_layers == 0) {
     return;
   }

-  std::vector<std::pair<int64_t, int64_t>> mapping_pairs;
-  mapping_pairs.reserve(block_mapping.size());
-  for (const auto &pair : block_mapping) {
-    for (const auto &dst : pair.second) {
-      mapping_pairs.emplace_back(pair.first, dst);
-    }
-  }
-
   const int element_num_per_block = key_caches[0][0].numel();
   VLLM_DISPATCH_FLOATING_TYPES(
       key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
         CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
-        copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, mapping_pairs,
+        copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
                                        element_num_per_block, num_layers);
         CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
       });
@@ -111,7 +107,7 @@ void copy_blocks(std::vector<torch::Tensor> &key_caches,
 void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                        torch::Tensor& key_cache, torch::Tensor& value_cache,
                        torch::Tensor& slot_mapping,
-                       const std::string &kv_cache_dtype, float kv_scale) {
+                       const std::string& kv_cache_dtype, double kv_scale) {
   TORCH_CHECK(kv_scale == 1.0f);

   int num_tokens = key.size(0);
@@ -136,6 +132,6 @@ void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
 }

 void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
-                 const std::map<int64_t, int64_t> &block_mapping) {
+                 const torch::Tensor& block_mapping) {
   TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
 }

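copy_blocks and swap_blocks now take the block mapping as a torch::Tensor instead of a std::map, again so the argument fits a Tensor-typed op schema. A short sketch of the layout the new copy_blocks_cpu_impl reads, assuming a [num_pairs, 2] int64 tensor of (src_block, dst_block) rows; the helper names here are illustrative only:

    #include <torch/torch.h>

    // Build a mapping that copies block 0 -> 3 and block 2 -> 5.
    torch::Tensor make_block_mapping() {
      return torch::tensor({{0, 3}, {2, 5}}, torch::dtype(torch::kInt64));
    }

    // Walk the rows the way the new impl does, one .item<int64_t>() per cell.
    void walk_block_mapping(const torch::Tensor& mapping) {
      for (int64_t i = 0; i < mapping.size(0); ++i) {
        const int64_t src = mapping[i][0].item<int64_t>();
        const int64_t dst = mapping[i][1].item<int64_t>();
        (void)src;
        (void)dst;  // a real kernel memcpys block src into block dst here
      }
    }
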
@@ -3,7 +3,7 @@
 #define CPU_TYPES_HPP

 #include <immintrin.h>
-#include <torch/extension.h>
+#include <torch/all.h>

 namespace vec_op {

@@ -87,8 +87,8 @@ void fused_add_rms_norm_impl(scalar_t *__restrict__ input,
 }
 }  // namespace

-void rms_norm(torch::Tensor &out, torch::Tensor &input,
-              torch::Tensor &weight, float epsilon) {
+void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
+              double epsilon) {
   int hidden_size = input.size(-1);
   int num_tokens = input.numel() / hidden_size;

@@ -102,7 +102,7 @@ void rms_norm(torch::Tensor &out, torch::Tensor &input,
 }

 void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
-                        torch::Tensor &weight, float epsilon) {
+                        torch::Tensor& weight, double epsilon) {
   int hidden_size = input.size(-1);
   int num_tokens = input.numel() / hidden_size;

@@ -4,25 +4,74 @@
 namespace {
 template <typename scalar_t>
 void rotary_embedding_impl(
-    const int64_t
-        *__restrict__ positions, // [batch_size, seq_len] or [num_tokens]
-    scalar_t
-        *__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or
-                             /// [num_tokens, num_heads, head_size]
-    scalar_t
-        *__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or
-                           // [num_tokens, num_kv_heads, head_size]
-    const scalar_t
-        *__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
+    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
+                                            // [num_tokens]
+    scalar_t* __restrict__ query,  /// [batch_size, seq_len, num_heads,
+                                   /// head_size] or [num_tokens, num_heads,
+                                   /// head_size]
+    scalar_t* __restrict__ key,  // [batch_size, seq_len, num_kv_heads,
+                                 // head_size] or [num_tokens, num_kv_heads,
+                                 // head_size]
+    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
+                                                 // 2]
     const int rot_dim, const int64_t query_stride, const int64_t key_stride,
     const int num_heads, const int num_kv_heads, const int head_size,
     const int num_tokens) {
   using scalar_vec_t = vec_op::vec_t<scalar_t>;
   constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
-  constexpr int ELEM_SIZE = sizeof(scalar_t);

   const int embed_dim = rot_dim / 2;
-  TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0);
+  bool flag = (embed_dim % VEC_ELEM_NUM == 0);
+  const int loop_upper = flag ? embed_dim : embed_dim - VEC_ELEM_NUM;
+
+  auto compute_loop = [&](const int64_t token_head, const scalar_t* cache_ptr,
+                          scalar_t* qk) {
+    int j = 0;
+    for (; j < loop_upper; j += VEC_ELEM_NUM) {
+      const int rot_offset = j;
+      const int x_index = rot_offset;
+      const int y_index = embed_dim + rot_offset;
+
+      const int64_t out_x = token_head + x_index;
+      const int64_t out_y = token_head + y_index;
+
+      const scalar_vec_t cos(cache_ptr + x_index);
+      const scalar_vec_t sin(cache_ptr + y_index);
+
+      const scalar_vec_t q_x(qk + out_x);
+      const scalar_vec_t q_y(qk + out_y);
+
+      vec_op::FP32Vec8 fp32_cos(cos);
+      vec_op::FP32Vec8 fp32_sin(sin);
+
+      vec_op::FP32Vec8 fp32_q_x(q_x);
+      vec_op::FP32Vec8 fp32_q_y(q_y);
+
+      auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
+      scalar_vec_t(out1).save(qk + out_x);
+
+      auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
+      scalar_vec_t(out2).save(qk + out_y);
+    }
+    if (!flag) {
+      for (; j < embed_dim; ++j) {
+        const int x_index = j;
+        const int y_index = embed_dim + j;
+
+        const int64_t out_x = token_head + x_index;
+        const int64_t out_y = token_head + y_index;
+
+        const float fp32_cos = cache_ptr[x_index];
+        const float fp32_sin = cache_ptr[y_index];
+
+        const float fp32_q_x = qk[out_x];
+        const float fp32_q_y = qk[out_y];
+
+        qk[out_x] = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
+        qk[out_y] = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
+      }
+    }
+  };

 #pragma omp parallel for
   for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
@@ -33,78 +82,29 @@ void rotary_embedding_impl(
       const int head_idx = i;
       const int64_t token_head =
           token_idx * query_stride + head_idx * head_size;
-      for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) {
-        const int rot_offset = j;
-        const int x_index = rot_offset;
-        const int y_index = embed_dim + rot_offset;
-
-        const int64_t out_x = token_head + x_index;
-        const int64_t out_y = token_head + y_index;
-
-        const scalar_vec_t cos(cache_ptr + x_index);
-        const scalar_vec_t sin(cache_ptr + y_index);
-
-        const scalar_vec_t q_x(query + out_x);
-        const scalar_vec_t q_y(query + out_y);
-
-        vec_op::FP32Vec8 fp32_cos(cos);
-        vec_op::FP32Vec8 fp32_sin(sin);
-
-        vec_op::FP32Vec8 fp32_q_x(q_x);
-        vec_op::FP32Vec8 fp32_q_y(q_y);
-
-        auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
-        scalar_vec_t(out1).save(query + out_x);
-
-        auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
-        scalar_vec_t(out2).save(query + out_y);
-      }
+      compute_loop(token_head, cache_ptr, query);
     }

     for (int i = 0; i < num_kv_heads; ++i) {
       const int head_idx = i;
       const int64_t token_head = token_idx * key_stride + head_idx * head_size;
-      for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) {
-        const int rot_offset = j;
-        const int x_index = rot_offset;
-        const int y_index = embed_dim + rot_offset;
-
-        const int64_t out_x = token_head + x_index;
-        const int64_t out_y = token_head + y_index;
-
-        const scalar_vec_t cos(cache_ptr + x_index);
-        const scalar_vec_t sin(cache_ptr + y_index);
-
-        const scalar_vec_t k_x(key + out_x);
-        const scalar_vec_t k_y(key + out_y);
-
-        vec_op::FP32Vec8 fp32_cos(cos);
-        vec_op::FP32Vec8 fp32_sin(sin);
-
-        vec_op::FP32Vec8 fp32_k_x(k_x);
-        vec_op::FP32Vec8 fp32_k_y(k_y);
-
-        auto out1 = fp32_k_x * fp32_cos - fp32_k_y * fp32_sin;
-        scalar_vec_t(out1).save(key + out_x);
-        auto out2 = fp32_k_y * fp32_cos + fp32_k_x * fp32_sin;
-        scalar_vec_t(out2).save(key + out_y);
-      }
+      compute_loop(token_head, cache_ptr, key);
     }
   }
 }

 template <typename scalar_t>
 void rotary_embedding_gptj_impl(
-    const int64_t
-        *__restrict__ positions, // [batch_size, seq_len] or [num_tokens]
-    scalar_t
-        *__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or
-                             /// [num_tokens, num_heads, head_size]
-    scalar_t
-        *__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or
-                           // [num_tokens, num_kv_heads, head_size]
-    const scalar_t
-        *__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
+    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
+                                            // [num_tokens]
+    scalar_t* __restrict__ query,  /// [batch_size, seq_len, num_heads,
+                                   /// head_size] or [num_tokens, num_heads,
+                                   /// head_size]
+    scalar_t* __restrict__ key,  // [batch_size, seq_len, num_kv_heads,
+                                 // head_size] or [num_tokens, num_kv_heads,
+                                 // head_size]
+    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
+                                                 // 2]
     const int rot_dim, const int64_t query_stride, const int64_t key_stride,
     const int num_heads, const int num_kv_heads, const int head_size,
     const int num_tokens) {
@@ -168,7 +168,7 @@ void rotary_embedding_gptj_impl(
 };  // namespace

 void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
-                      torch::Tensor &key, int head_size,
+                      torch::Tensor& key, int64_t head_size,
                       torch::Tensor& cos_sin_cache, bool is_neox) {
   int num_tokens = query.numel() / query.size(-1);
   int rot_dim = cos_sin_cache.size(1);

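The rotary-embedding rewrite above does two things: it hoists the duplicated query/key rotation into one compute_loop lambda, and it drops the old TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0) in favor of a vectorized main loop plus a scalar tail for the remainder. A scalar sketch of that loop shape follows; plain float math stands in for vec_op's SIMD types, and VEC_WIDTH is an assumed vector width:

    // Rotate (x, y) lane pairs by (cos, sin); the tail loop covers widths
    // that are not a multiple of VEC_WIDTH, which the old code rejected.
    constexpr int VEC_WIDTH = 8;

    void rotate_half(float* x, float* y, const float* cos_v,
                     const float* sin_v, int embed_dim) {
      const int vec_end =
          (embed_dim % VEC_WIDTH == 0) ? embed_dim : embed_dim - VEC_WIDTH;
      int j = 0;
      for (; j < vec_end; j += VEC_WIDTH) {
        // a SIMD implementation loads/stores VEC_WIDTH lanes at once here
        for (int k = 0; k < VEC_WIDTH; ++k) {
          const float xv = x[j + k], yv = y[j + k];
          x[j + k] = xv * cos_v[j + k] - yv * sin_v[j + k];
          y[j + k] = yv * cos_v[j + k] + xv * sin_v[j + k];
        }
      }
      for (; j < embed_dim; ++j) {  // scalar tail for the leftover lanes
        const float xv = x[j], yv = y[j];
        x[j] = xv * cos_v[j] - yv * sin_v[j];
        y[j] = yv * cos_v[j] + xv * sin_v[j];
      }
    }
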
@@ -1,73 +0,0 @@
-#include "cache.h"
-#include "cuda_utils.h"
-#include "ops.h"
-#include <torch/extension.h>
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  // vLLM custom ops
-  pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
-
-  // Attention ops
-  ops.def(
-    "paged_attention_v1",
-    &paged_attention_v1,
-    "Compute the attention between an input query and the cached keys/values using PagedAttention.");
-  ops.def(
-    "paged_attention_v2",
-    &paged_attention_v2,
-    "PagedAttention V2.");
-
-  // Activation ops
-  ops.def(
-    "silu_and_mul",
-    &silu_and_mul,
-    "Activation function used in SwiGLU.");
-  ops.def(
-    "gelu_and_mul",
-    &gelu_and_mul,
-    "Activation function used in GeGLU with `none` approximation.");
-  ops.def(
-    "gelu_tanh_and_mul",
-    &gelu_tanh_and_mul,
-    "Activation function used in GeGLU with `tanh` approximation.");
-  ops.def(
-    "gelu_new",
-    &gelu_new,
-    "GELU implementation used in GPT-2.");
-  ops.def(
-    "gelu_fast",
-    &gelu_fast,
-    "Approximate GELU implementation.");
-
-  // Layernorm
-  ops.def(
-    "rms_norm",
-    &rms_norm,
-    "Apply Root Mean Square (RMS) Normalization to the input tensor.");
-
-  ops.def(
-    "fused_add_rms_norm",
-    &fused_add_rms_norm,
-    "In-place fused Add and RMS Normalization");
-
-  // Rotary embedding
-  ops.def(
-    "rotary_embedding",
-    &rotary_embedding,
-    "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
-
-  // Cache ops
-  pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
-  cache_ops.def(
-    "swap_blocks",
-    &swap_blocks,
-    "Swap in (out) the cache blocks from src to dst");
-  cache_ops.def(
-    "copy_blocks",
-    &copy_blocks,
-    "Copy the cache blocks from src to dst");
-  cache_ops.def(
-    "reshape_and_cache",
-    &reshape_and_cache,
-    "Reshape the key and value tensors and cache them");
-}

106
csrc/cpu/torch_bindings.cpp
Normal file
106
csrc/cpu/torch_bindings.cpp
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
#include "cache.h"
|
||||||
|
#include "ops.h"
|
||||||
|
#include "registration.h"
|
||||||
|
|
||||||
|
#include <torch/library.h>
|
||||||
|
|
||||||
|
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||||
|
// vLLM custom ops
|
||||||
|
|
||||||
|
// Attention ops
|
||||||
|
  // Compute the attention between an input query and the cached keys/values
  // using PagedAttention.
  ops.def(
      "paged_attention_v1("
      "    Tensor! out, Tensor query, Tensor key_cache,"
      "    Tensor value_cache, int num_kv_heads, float scale,"
      "    Tensor block_tables, Tensor seq_lens, int block_size,"
      "    int max_seq_len, Tensor? alibi_slopes,"
      "    str kv_cache_dtype, float kv_scale, int tp_rank,"
      "    int blocksparse_local_blocks,"
      "    int blocksparse_vert_stride, int blocksparse_block_size,"
      "    int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);

  // PagedAttention V2.
  ops.def(
      "paged_attention_v2("
      "    Tensor! out, Tensor exp_sums, Tensor max_logits,"
      "    Tensor tmp_out, Tensor query, Tensor key_cache,"
      "    Tensor value_cache, int num_kv_heads, float scale,"
      "    Tensor block_tables, Tensor seq_lens, int block_size,"
      "    int max_seq_len, Tensor? alibi_slopes,"
      "    str kv_cache_dtype, float kv_scale, int tp_rank,"
      "    int blocksparse_local_blocks,"
      "    int blocksparse_vert_stride, int blocksparse_block_size,"
      "    int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2);

  // Activation ops

  // Activation function used in SwiGLU.
  ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("silu_and_mul", torch::kCPU, &silu_and_mul);

  // Activation function used in GeGLU with `none` approximation.
  ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_and_mul", torch::kCPU, &gelu_and_mul);

  // Activation function used in GeGLU with `tanh` approximation.
  ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_tanh_and_mul", torch::kCPU, &gelu_tanh_and_mul);

  // GELU implementation used in GPT-2.
  ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_new", torch::kCPU, &gelu_new);

  // Approximate GELU implementation.
  ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_fast", torch::kCPU, &gelu_fast);

  // Layernorm
  // Apply Root Mean Square (RMS) Normalization to the input tensor.
  ops.def(
      "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
      "()");
  ops.impl("rms_norm", torch::kCPU, &rms_norm);

  // In-place fused Add and RMS Normalization.
  ops.def(
      "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
      "float epsilon) -> ()");
  ops.impl("fused_add_rms_norm", torch::kCPU, &fused_add_rms_norm);

  // Rotary embedding
  // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
  ops.def(
      "rotary_embedding(Tensor positions, Tensor! query,"
      "                 Tensor! key, int head_size,"
      "                 Tensor cos_sin_cache, bool is_neox) -> ()");
  ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
}

TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
  // Cache ops
  // Swap in (out) the cache blocks from src to dst.
  cache_ops.def(
      "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
  cache_ops.impl("swap_blocks", torch::kCPU, &swap_blocks);

  // Copy the cache blocks from src to dst.
  cache_ops.def(
      "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
      "block_mapping) -> ()");
  cache_ops.impl("copy_blocks", torch::kCPU, &copy_blocks);

  // Reshape the key and value tensors and cache them.
  cache_ops.def(
      "reshape_and_cache(Tensor key, Tensor value,"
      "                  Tensor! key_cache, Tensor! value_cache,"
      "                  Tensor slot_mapping,"
      "                  str kv_cache_dtype,"
      "                  float kv_scale) -> ()");
  cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
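Each registration above pairs a schema string (`def`) with a backend binding (`impl`); in the schema, `Tensor!` marks an argument the op mutates in place and `Tensor?` an optional argument. A minimal sketch of the same pattern for a hypothetical op (names here are illustrative and not part of vLLM):

#include <torch/all.h>
#include <torch/library.h>

// Hypothetical in-place op, used only to illustrate the def/impl pattern.
void scale_inplace(torch::Tensor& out, double factor) { out.mul_(factor); }

TORCH_LIBRARY(demo_ext, ops) {
  // "Tensor!" marks the mutated argument; schema "float" binds to C++ double.
  ops.def("scale_inplace(Tensor! out, float factor) -> ()");
  ops.impl("scale_inplace", torch::kCPU, &scale_inplace);
}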
@@ -17,9 +17,14 @@
 #endif

 #ifndef USE_ROCM
-#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
+#define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
+  __shfl_xor_sync(uint32_t(-1), var, lane_mask)
+#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
+  __shfl_xor_sync(uint32_t(-1), var, lane_mask, width)
 #else
 #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
+#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
+  __shfl_xor(var, lane_mask, width)
 #endif

 #ifndef USE_ROCM
@@ -28,6 +33,13 @@
 #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
 #endif

+#ifndef USE_ROCM
+#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
+  __shfl_down_sync(uint32_t(-1), var, lane_delta)
+#else
+#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
+#endif
+
 #ifndef USE_ROCM
 #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
   cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
@@ -35,4 +47,3 @@
 #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
   hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
 #endif
-
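These macros hide the CUDA `__shfl_*_sync` intrinsics and their unsynchronized ROCm counterparts behind one name, so shared kernels can do warp reductions without per-platform `#ifdef`s. A minimal CUDA sketch of a butterfly sum built on the XOR macro (assumes the definitions above and a 32-lane warp):

// Every lane contributes v; after log2(32) XOR exchanges, every lane
// holds the sum of all 32 lane values.
__device__ float warp_sum(float v) {
  for (int mask = 16; mask > 0; mask >>= 1)
    v += VLLM_SHFL_XOR_SYNC(v, mask);
  return v;
}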
@@ -1,10 +1,5 @@
 #pragma once

-#include <torch/extension.h>
-
-int get_device_attribute(
-    int attribute,
-    int device_id);
-
-int get_max_shared_memory_per_block_device_attribute(
-    int device_id);
+int64_t get_device_attribute(int64_t attribute, int64_t device_id);
+
+int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);
@@ -2,26 +2,20 @@
 #include <hip/hip_runtime.h>
 #include <hip/hip_runtime_api.h>
 #endif
-int get_device_attribute(
-    int attribute,
-    int device_id)
-{
+int64_t get_device_attribute(int64_t attribute, int64_t device_id) {
   int device, value;
   if (device_id < 0) {
     cudaGetDevice(&device);
-  }
-  else {
+  } else {
     device = device_id;
   }
-  cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
+  cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute),
+                         device);
   return value;
 }

-int get_max_shared_memory_per_block_device_attribute(
-    int device_id)
-{
-  int attribute;
+int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id) {
+  int64_t attribute;
   // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
   // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
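The helper is a thin wrapper over the runtime attribute query; the hard-coded 97/74 in the comment is the numeric enum value of `cudaDevAttrMaxSharedMemoryPerBlockOptin` on CUDA versus its HIP equivalent. A standalone sketch of the same query on CUDA:

#include <cstdio>
#include <cuda_runtime.h>

int main() {
  int value = 0;
  // Same attribute the helper above resolves by its numeric value (97 on CUDA).
  cudaDeviceGetAttribute(&value, cudaDevAttrMaxSharedMemoryPerBlockOptin, 0);
  std::printf("opt-in shared memory per block: %d bytes\n", value);
  return 0;
}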
@@ -1,17 +1,17 @@
 #include <ATen/cuda/Exceptions.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <c10/cuda/CUDAStream.h>
-#include <torch/extension.h>
+#include <torch/all.h>

 #include "custom_all_reduce.cuh"

-// fake pointer type
-using fptr_t = uint64_t;
+// fake pointer type, must match fptr_t type in ops.h
+using fptr_t = int64_t;
 static_assert(sizeof(void*) == sizeof(fptr_t));

 fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
                       const std::vector<std::string>& handles,
-                      const std::vector<int64_t> &offsets, int rank,
+                      const std::vector<int64_t>& offsets, int64_t rank,
                       bool full_nvlink) {
   int world_size = offsets.size();
   if (world_size > 8)
@@ -55,7 +55,7 @@ bool _is_weak_contiguous(torch::Tensor &t) {
                         t.numel() * t.element_size());
 }

-bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
+bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
                       bool full_nvlink) {
   auto inp_size = inp.numel() * inp.element_size();
   // custom allreduce requires input byte size to be multiples of 16
@@ -80,8 +80,7 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
     }
     case at::ScalarType::Half: {
       fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
-                          reinterpret_cast<half *>(out.data_ptr()),
-                          out.numel());
+                          reinterpret_cast<half*>(out.data_ptr()), out.numel());
       break;
     }
 #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
@@ -126,7 +125,7 @@ void dispose(fptr_t _fa) {
   delete fa;
 }

-int meta_size() { return sizeof(vllm::Signal); }
+int64_t meta_size() { return sizeof(vllm::Signal); }

 void register_buffer(fptr_t _fa, torch::Tensor& t,
                      const std::vector<std::string>& handles,
@@ -135,10 +134,16 @@ void register_buffer(fptr_t _fa, torch::Tensor &t,
   fa->register_buffer(handles, offsets, t.data_ptr());
 }

-std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
+std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
     fptr_t _fa) {
   auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
-  return fa->get_graph_buffer_ipc_meta();
+  auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
+  auto options =
+      torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
+  auto handles =
+      torch::empty({static_cast<int64_t>(handle_bytes.size())}, options);
+  std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size());
+  return {handles, std::move(offsets)};
 }

 void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
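The `uint64_t` to `int64_t` switch for `fptr_t` follows from the torch library schema, which has no pointer type and whose `int` is a signed 64-bit integer: opaque handles ride through the op boundary as `int64_t` and are cast back on entry. A sketch of the round trip under that convention (illustrative, with the same static_assert guard the file uses):

#include <cstdint>

using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t), "handle must fit a pointer");

// Pack a live object pointer into a schema-friendly integer handle...
template <typename T>
fptr_t to_handle(T* p) { return reinterpret_cast<fptr_t>(p); }

// ...and recover it inside the op implementation.
template <typename T>
T* from_handle(fptr_t h) { return reinterpret_cast<T*>(h); }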
@@ -137,8 +137,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
     // Latency = 1 p2p write
     sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
     // wait until we got true from all ranks
-    while (!self_sg->start[blockIdx.x][threadIdx.x])
-      ;
+    while (!self_sg->start[blockIdx.x][threadIdx.x]);
   }
   __syncthreads();
 }
@@ -162,8 +161,7 @@ DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
     // Latency = 1 p2p write
     sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
     // wait until we got true from all ranks
-    while (!self_sg->end[blockIdx.x][threadIdx.x])
-      ;
+    while (!self_sg->end[blockIdx.x][threadIdx.x]);
   }
   if constexpr (!final_sync) __syncthreads();
 }
@@ -192,8 +190,7 @@ __global__ void __launch_bounds__(512, 1)
   // do the actual reduction
   for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
        idx += gridDim.x * blockDim.x) {
-    ((P *)result)[idx] =
-        packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
+    ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
   }
   end_sync<ngpus, true>(sg, self_sg, rank);
 }
@@ -4,7 +4,7 @@
  */
 #pragma once

-#include <torch/extension.h>
+#include <torch/all.h>

 #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)              \
   AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)      \
@@ -12,8 +12,7 @@
   AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

 #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
-  AT_DISPATCH_SWITCH( \
-      TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
+  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

 #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
   AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)  \
@@ -22,8 +21,8 @@
   AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)

 #define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
-  AT_DISPATCH_SWITCH( \
-      TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
+  AT_DISPATCH_SWITCH(TYPE, NAME, \
+                     VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))

 #define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...)         \
   AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)  \
@@ -33,5 +32,4 @@
   AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

 #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
-  AT_DISPATCH_SWITCH( \
-      TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
+  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
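These macros wrap ATen's `AT_DISPATCH_SWITCH`/`AT_DISPATCH_CASE` machinery: the caller passes a runtime dtype plus a lambda, and the lambda body is compiled once per listed type with `scalar_t` bound to that type. A minimal CPU-side sketch of the underlying pattern (assumes libtorch; the function name is illustrative):

#include <ATen/Dispatch.h>
#include <torch/all.h>

// Sums a CPU float/double tensor, dispatching on its runtime dtype.
double sum_as_double(const torch::Tensor& t) {
  double acc = 0.0;
  AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "sum_as_double", [&] {
    const scalar_t* p = t.data_ptr<scalar_t>();
    for (int64_t i = 0; i < t.numel(); ++i) acc += static_cast<double>(p[i]);
  });
  return acc;
}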
@@ -1,4 +1,4 @@
-#include <torch/extension.h>
+#include <torch/all.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>

@@ -23,9 +23,7 @@ __global__ void rms_norm_kernel(
     scalar_t* __restrict__ out,           // [..., hidden_size]
     const scalar_t* __restrict__ input,   // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
-    const float epsilon,
-    const int num_tokens,
-    const int hidden_size) {
+    const float epsilon, const int num_tokens, const int hidden_size) {
   __shared__ float s_variance;
   float variance = 0.0f;

@@ -41,11 +39,11 @@ __global__ void rms_norm_kernel(

   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
     float x = (float)input[blockIdx.x * hidden_size + idx];
-    out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
+    out[blockIdx.x * hidden_size + idx] =
+        ((scalar_t)(x * s_variance)) * weight[idx];
   }
 }

 /* Converter structs for the conversion from torch types to HIP/CUDA types,
    and the associated type conversions within HIP/CUDA. These helpers need
    to be implemented for now because the relevant type conversion
@@ -57,7 +55,9 @@ __global__ void rms_norm_kernel(
    If true, the struct should be fully defined as shown in the examples below.
  */
 template <typename torch_type>
-struct _typeConvert { static constexpr bool exists = false; };
+struct _typeConvert {
+  static constexpr bool exists = false;
+};

 #if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
 // CUDA < 12.0 runs into issues with packed type conversion
@@ -68,9 +68,15 @@ struct _typeConvert<c10::Half> {
   using packed_hip_type = __half2;

   __device__ static inline float convert(hip_type x) { return __half2float(x); }
-  __device__ static inline float2 convert(packed_hip_type x) { return __half22float2(x); }
-  __device__ static inline hip_type convert(float x) { return __float2half_rn(x); }
-  __device__ static inline packed_hip_type convert(float2 x) { return __float22half2_rn(x); }
+  __device__ static inline float2 convert(packed_hip_type x) {
+    return __half22float2(x);
+  }
+  __device__ static inline hip_type convert(float x) {
+    return __float2half_rn(x);
+  }
+  __device__ static inline packed_hip_type convert(float2 x) {
+    return __float22half2_rn(x);
+  }
 };

 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
@@ -82,13 +88,22 @@ struct _typeConvert<c10::BFloat16> {
   using hip_type = __nv_bfloat16;
   using packed_hip_type = __nv_bfloat162;

-  __device__ static inline float convert(hip_type x) { return __bfloat162float(x); }
-  __device__ static inline float2 convert(packed_hip_type x) { return __bfloat1622float2(x); }
-  __device__ static inline hip_type convert(float x) { return __float2bfloat16(x); }
-  __device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); }
+  __device__ static inline float convert(hip_type x) {
+    return __bfloat162float(x);
+  }
+  __device__ static inline float2 convert(packed_hip_type x) {
+    return __bfloat1622float2(x);
+  }
+  __device__ static inline hip_type convert(float x) {
+    return __float2bfloat16(x);
+  }
+  __device__ static inline packed_hip_type convert(float2 x) {
+    return __float22bfloat162_rn(x);
+  }
 };
 #endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-#endif  // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
+#endif  // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
+        // 12000))

 /* Vector POD struct to generate vectorized and packed FP16/BF16 ops
    for appropriate specializations of fused_add_rms_norm_kernel.
@@ -117,8 +132,7 @@ struct alignas(16) _f16Vec {
       }
     } else {
 #pragma unroll
-      for (int i = 0; i < width; ++i)
-        data[i] += other.data[i];
+      for (int i = 0; i < width; ++i) data[i] += other.data[i];
     }
     return *this;
   }
@@ -134,8 +148,7 @@ struct alignas(16) _f16Vec {
       }
     } else {
 #pragma unroll
-      for (int i = 0; i < width; ++i)
-        data[i] *= other.data[i];
+      for (int i = 0; i < width; ++i) data[i] *= other.data[i];
     }
     return *this;
   }
@@ -185,14 +198,12 @@ struct alignas(16) _f16Vec {
    packed and vectorized operations, which help with the
    memory latency bottleneck. */
 template <typename scalar_t, int width>
-__global__ std::enable_if_t<
-    (width > 0) && _typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel(
+__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
+fused_add_rms_norm_kernel(
     scalar_t* __restrict__ input,         // [..., hidden_size]
     scalar_t* __restrict__ residual,      // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
-    const float epsilon,
-    const int num_tokens,
-    const int hidden_size) {
+    const float epsilon, const int num_tokens, const int hidden_size) {
   // Sanity checks on our vector struct and type-punned pointer arithmetic
   static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
   static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
@@ -203,9 +214,12 @@ __global__ std::enable_if_t<
   /* These and the argument pointers are all declared `restrict` as they are
      not aliased in practice. Argument pointers should not be dereferenced
      in this kernel as that would be undefined behavior */
-  auto* __restrict__ input_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
-  auto* __restrict__ residual_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
-  auto* __restrict__ weight_v = reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
+  auto* __restrict__ input_v =
+      reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
+  auto* __restrict__ residual_v =
+      reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
+  auto* __restrict__ weight_v =
+      reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);

   for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
     int id = blockIdx.x * vec_hidden_size + idx;
@@ -218,7 +232,8 @@ __global__ std::enable_if_t<
      calculation of max_block_size in fused_add_rms_norm */
   if (num_tokens < 256) {
     variance = blockReduceSum<float, 1024>(variance);
-  } else variance = blockReduceSum<float, 256>(variance);
+  } else
+    variance = blockReduceSum<float, 256>(variance);
   if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);
   }
@@ -233,19 +248,16 @@ __global__ std::enable_if_t<
   }
 }

 /* Generic fused_add_rms_norm_kernel
    The width field is not used here but necessary for other specializations.
  */
 template <typename scalar_t, int width>
-__global__ std::enable_if_t<
-    (width == 0) || !_typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel(
+__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
+fused_add_rms_norm_kernel(
     scalar_t* __restrict__ input,         // [..., hidden_size]
     scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
-    const float epsilon,
-    const int num_tokens,
-    const int hidden_size) {
+    const float epsilon, const int num_tokens, const int hidden_size) {
   __shared__ float s_variance;
   float variance = 0.0f;

@@ -260,7 +272,8 @@ __global__ std::enable_if_t<
      calculation of max_block_size in fused_add_rms_norm */
   if (num_tokens < 256) {
     variance = blockReduceSum<float, 1024>(variance);
-  } else variance = blockReduceSum<float, 256>(variance);
+  } else
+    variance = blockReduceSum<float, 256>(variance);
   if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);
   }
@@ -268,17 +281,17 @@ __global__ std::enable_if_t<

   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
     float x = (float)residual[blockIdx.x * hidden_size + idx];
-    input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
+    input[blockIdx.x * hidden_size + idx] =
+        ((scalar_t)(x * s_variance)) * weight[idx];
   }
 }

 }  // namespace vllm

-void rms_norm(
-    torch::Tensor& out,      // [..., hidden_size]
-    torch::Tensor& input,    // [..., hidden_size]
-    torch::Tensor& weight,   // [hidden_size]
-    float epsilon) {
+void rms_norm(torch::Tensor& out,     // [..., hidden_size]
+              torch::Tensor& input,   // [..., hidden_size]
+              torch::Tensor& weight,  // [hidden_size]
+              double epsilon) {
   int hidden_size = input.size(-1);
   int num_tokens = input.numel() / hidden_size;

@@ -286,40 +299,27 @@ void rms_norm(
   dim3 block(std::min(hidden_size, 1024));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(),
-      "rms_norm_kernel",
-      [&] {
-        vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            out.data_ptr<scalar_t>(),
-            input.data_ptr<scalar_t>(),
-            weight.data_ptr<scalar_t>(),
-            epsilon,
-            num_tokens,
-            hidden_size);
-      });
+  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
+    vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
+        weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
+  });
 }

-#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
-  VLLM_DISPATCH_FLOATING_TYPES( \
-      input.scalar_type(), \
-      "fused_add_rms_norm_kernel", \
-      [&] { \
-        vllm::fused_add_rms_norm_kernel \
-            <scalar_t, width><<<grid, block, 0, stream>>>( \
-                input.data_ptr<scalar_t>(), \
-                residual.data_ptr<scalar_t>(), \
-                weight.data_ptr<scalar_t>(), \
-                epsilon, \
-                num_tokens, \
-                hidden_size); \
-      });
+#define LAUNCH_FUSED_ADD_RMS_NORM(width)                                       \
+  VLLM_DISPATCH_FLOATING_TYPES(                                                \
+      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {                  \
+        vllm::fused_add_rms_norm_kernel<scalar_t, width>                       \
+            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),           \
+                                         residual.data_ptr<scalar_t>(),        \
+                                         weight.data_ptr<scalar_t>(), epsilon, \
+                                         num_tokens, hidden_size);             \
+      });

-void fused_add_rms_norm(
-    torch::Tensor& input,     // [..., hidden_size]
-    torch::Tensor& residual,  // [..., hidden_size]
-    torch::Tensor& weight,    // [hidden_size]
-    float epsilon) {
+void fused_add_rms_norm(torch::Tensor& input,     // [..., hidden_size]
+                        torch::Tensor& residual,  // [..., hidden_size]
+                        torch::Tensor& weight,    // [hidden_size]
+                        double epsilon) {
   int hidden_size = input.size(-1);
   int num_tokens = input.numel() / hidden_size;

@@ -342,8 +342,8 @@ void fused_add_rms_norm(
   auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
   auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
   auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
-  bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 \
-      && wt_ptr % 16 == 0;
+  bool ptrs_are_aligned =
+      inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
   if (ptrs_are_aligned && hidden_size % 8 == 0) {
     LAUNCH_FUSED_ADD_RMS_NORM(8);
   } else {
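For reference, the per-token math rms_norm_kernel parallelizes is out[i] = x[i] / sqrt(mean(x^2) + epsilon) * weight[i]: block threads partially accumulate x^2, blockReduceSum combines the partials, and lane 0 publishes the reciprocal through shared memory. A scalar CPU sketch of the same computation (illustrative only):

#include <cmath>
#include <vector>

std::vector<float> rms_norm_ref(const std::vector<float>& x,
                                const std::vector<float>& w, float eps) {
  float sum_sq = 0.f;
  for (float v : x) sum_sq += v * v;
  // s_variance in the kernel is this same reciprocal RMS.
  const float inv_rms = 1.f / std::sqrt(sum_sq / x.size() + eps);
  std::vector<float> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) out[i] = x[i] * inv_rms * w[i];
  return out;
}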
@@ -1,7 +0,0 @@
-#include "moe_ops.h"
-
-#include <torch/extension.h>
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs.");
-}
@@ -1,9 +1,7 @@
 #pragma once

-#include <torch/extension.h>
+#include <torch/all.h>

-void topk_softmax(
-    torch::Tensor& topk_weights,
-    torch::Tensor& topk_indices,
-    torch::Tensor& token_expert_indices,
-    torch::Tensor& gating_output);
+void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
+                  torch::Tensor& token_expert_indices,
+                  torch::Tensor& gating_output);
@@ -16,18 +16,25 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include <torch/extension.h>
+#include <torch/all.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
+#include "../cuda_compat.h"

-#include <cub/cub.cuh>
+#ifndef USE_ROCM
 #include <cub/util_type.cuh>
+#include <cub/cub.cuh>
+#else
+#include <hipcub/util_type.hpp>
+#include <hipcub/hipcub.hpp>
+#endif
+
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MIN(a, b) ((a) < (b) ? (a) : (b))

 namespace vllm {
 namespace moe {

-static constexpr int WARP_SIZE = 32;
-
 /// Aligned array type
 template <
     typename T,
@@ -265,7 +272,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
 #pragma unroll
     for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
     {
-        thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW));
+        thread_max = max(thread_max, VLLM_SHFL_XOR_SYNC_WIDTH(thread_max, mask, THREADS_PER_ROW));
    }

     // From this point, thread max in all the threads have the max within the row.
@@ -282,7 +289,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
 #pragma unroll
     for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
     {
-        row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW);
+        row_sum += VLLM_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW);
     }

     // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables
@@ -332,8 +339,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
 #pragma unroll
         for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
         {
-            float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW);
-            int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW);
+            float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW);
+            int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW);

             // We want lower indices to "win" in every thread so we break ties this way
             if (other_max > max_val || (other_max == max_val && other_expert < expert))
@@ -383,7 +390,7 @@ struct TopkConstants
 {
     static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
     static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
-    static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
+    static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
     static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
     static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
     static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
@@ -396,7 +403,7 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
 {
     static constexpr std::size_t MAX_BYTES_PER_LDG = 16;

-    static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
+    static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
     using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
     static constexpr int VPT = Constants::VPT;
     static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
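The gating kernel computes a row max and then a row sum via those butterfly exchanges before normalizing, i.e. a numerically stable softmax over each row of expert logits. The same math as a scalar host-side reference (illustrative only):

#include <algorithm>
#include <cmath>
#include <vector>

// Stable softmax over one gating row: subtract the row max, exponentiate,
// normalize by the row sum.
std::vector<float> softmax_row(const std::vector<float>& logits) {
  float row_max = logits[0];
  for (float v : logits) row_max = std::max(row_max, v);
  float row_sum = 0.f;
  std::vector<float> out(logits.size());
  for (size_t i = 0; i < logits.size(); ++i) {
    out[i] = std::exp(logits[i] - row_max);
    row_sum += out[i];
  }
  for (float& v : out) v /= row_sum;
  return out;
}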
csrc/moe/torch_bindings.cpp (new file, 12 lines)
@@ -0,0 +1,12 @@
+#include "registration.h"
+#include "moe_ops.h"
+
+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
+  // Apply topk softmax to the gating outputs.
+  m.def(
+      "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
+      "token_expert_indices, Tensor gating_output) -> ()");
+  m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
+}
+
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
@@ -1,4 +1,4 @@
-#include <torch/extension.h>
+#include <torch/all.h>
 #include <ATen/cuda/CUDAContext.h>

 #include <ATen/ATen.h>
@@ -12,11 +12,12 @@
 namespace vllm {

 namespace {
-__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) {
+__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
+                                         int32_t col) {
   // don't worry about overflow because num_experts is relatively small
   return row * total_col + col;
 }
-}
+}  // namespace

 template <typename scalar_t>
 __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
@@ -24,15 +25,17 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
                                             int32_t* expert_ids,
                                             int32_t* total_tokens_post_pad,
                                             int32_t num_experts,
-                                            int32_t block_size,
-                                            size_t numel) {
+                                            int32_t block_size, size_t numel) {
   const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
   const size_t start_idx = threadIdx.x * tokens_per_thread;

   extern __shared__ int32_t shared_mem[];

-  int32_t* tokens_cnts = shared_mem;  // 2d tensor with shape (num_experts + 1, num_experts)
-  int32_t* cumsum = shared_mem + (num_experts + 1) * num_experts;  // 1d tensor with shape (num_experts + 1)
+  int32_t* tokens_cnts =
+      shared_mem;  // 2d tensor with shape (num_experts + 1, num_experts)
+  int32_t* cumsum =
+      shared_mem + (num_experts + 1) *
+                       num_experts;  // 1d tensor with shape (num_experts + 1)

   for (int i = 0; i < num_experts; ++i) {
     tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
@@ -40,8 +43,8 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,

   /**
    * In the first step we compute token_cnts[thread_index + 1][expert_index],
-   * which counts how many tokens in the token shard of thread_index are assigned
-   * to expert expert_index.
+   * which counts how many tokens in the token shard of thread_index are
+   * assigned to expert expert_index.
    */
   for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
     ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
@@ -52,7 +55,8 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
   // For each expert we accumulate the token counts from the different threads.
   tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
   for (int i = 1; i <= blockDim.x; ++i) {
-    tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i-1, threadIdx.x)];
+    tokens_cnts[index(num_experts, i, threadIdx.x)] +=
+        tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
   }

   __syncthreads();
@@ -61,7 +65,10 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
   if (threadIdx.x == 0) {
     cumsum[0] = 0;
     for (int i = 1; i <= num_experts; ++i) {
-      cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size;
+      cumsum[i] = cumsum[i - 1] +
+                  CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
+                          block_size) *
+                      block_size;
     }
     *total_tokens_post_pad = cumsum[num_experts];
   }
@@ -69,57 +76,59 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
   __syncthreads();

   /**
-   * For each expert, each thread processes the tokens of the corresponding blocks
-   * and stores the corresponding expert_id for each block.
+   * For each expert, each thread processes the tokens of the corresponding
+   * blocks and stores the corresponding expert_id for each block.
    */
-  for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) {
+  for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
+       i += block_size) {
     expert_ids[i / block_size] = threadIdx.x;
   }

   /**
-   * Each thread processes a token shard, calculating the index of each token after
-   * sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and
-   * block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
-   * where * represents a padding value(preset in python).
+   * Each thread processes a token shard, calculating the index of each token
+   * after sorting by expert number. Given the example topk_ids =
+   * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
+   * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
+   * padding value(preset in python).
    */
   for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
     int32_t expert_id = topk_ids[i];
     /** The cumsum[expert_id] stores the starting index of the tokens that the
-     * expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id]
-     * stores the indices of the tokens processed by the expert with expert_id within
-     * the current thread's token shard.
+     * expert with expert_id needs to process, and
+     * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
+     * processed by the expert with expert_id within the current thread's token
+     * shard.
      */
-    int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id];
+    int32_t rank_post_pad =
+        tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
+        cumsum[expert_id];
     sorted_token_ids[rank_post_pad] = i;
     ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
   }
 }
-}
+}  // namespace vllm

-void moe_align_block_size(
-    torch::Tensor topk_ids,
-    int num_experts,
-    int block_size,
-    torch::Tensor sorted_token_ids,
-    torch::Tensor experts_ids,
-    torch::Tensor num_tokens_post_pad) {
+void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
+                          int64_t block_size, torch::Tensor sorted_token_ids,
+                          torch::Tensor experts_ids,
+                          torch::Tensor num_tokens_post_pad) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_INTEGRAL_TYPES(
       topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
-        // calc needed amount of shared mem for `tokens_cnts` and `cumsum` tensors
-        const int32_t shared_mem = ((num_experts + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
+        // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
+        // tensors
+        const int32_t shared_mem =
+            ((num_experts + 1) * num_experts + (num_experts + 1)) *
+            sizeof(int32_t);

         // set dynamic shared mem
         auto kernel = vllm::moe_align_block_size_kernel<scalar_t>;
-        AT_CUDA_CHECK(
-            VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize((void *)kernel, shared_mem));
+        AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
+            (void*)kernel, shared_mem));
         kernel<<<1, num_experts, shared_mem, stream>>>(
-            topk_ids.data_ptr<scalar_t>(),
-            sorted_token_ids.data_ptr<int32_t>(),
+            topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
             experts_ids.data_ptr<int32_t>(),
-            num_tokens_post_pad.data_ptr<int32_t>(),
-            num_experts,
-            block_size,
-            topk_ids.numel());
+            num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
+            topk_ids.numel());
       });
 }
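The padded offsets implied by the worked example in the kernel comments can be checked with the same CEILDIV arithmetic the kernel uses. A host-side sketch (illustrative; the counts come from the topk_ids = [0,1,2,1,2,3,0,3,4], block_size = 4 example above):

#include <cstdio>

constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }

int main() {
  // Per-expert token counts for topk_ids = [0,1,2,1,2,3,0,3,4].
  const int counts[5] = {2, 2, 2, 2, 1};
  const int block_size = 4;
  int cumsum = 0;
  for (int e = 0; e < 5; ++e) {
    // Each expert's slot range is rounded up to a multiple of block_size.
    cumsum += ceildiv(counts[e], block_size) * block_size;
    std::printf("expert %d ends at padded offset %d\n", e, cumsum);
  }
  // Prints 4, 8, 12, 16, 20 -- matching the 20-slot output in the comment.
  return 0;
}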
|||||||
230
csrc/ops.h
230
csrc/ops.h
@@ -1,206 +1,146 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/extension.h>
|
#include <torch/library.h>
|
||||||
|
|
||||||
void paged_attention_v1(
|
void paged_attention_v1(
|
||||||
torch::Tensor& out,
|
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||||
torch::Tensor& query,
|
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||||
torch::Tensor& key_cache,
|
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||||
torch::Tensor& value_cache,
|
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
int num_kv_heads,
|
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
|
||||||
float scale,
|
const int64_t blocksparse_local_blocks,
|
||||||
torch::Tensor& block_tables,
|
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||||
torch::Tensor& seq_lens,
|
const int64_t blocksparse_head_sliding_step);
|
||||||
int block_size,
|
|
||||||
int max_seq_len,
|
|
||||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
|
||||||
const std::string& kv_cache_dtype,
|
|
||||||
float kv_scale);
|
|
||||||
|
|
||||||
void paged_attention_v2(
|
void paged_attention_v2(
|
||||||
torch::Tensor& out,
|
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
|
||||||
torch::Tensor& exp_sums,
|
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||||
torch::Tensor& max_logits,
|
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||||
torch::Tensor& tmp_out,
|
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||||
torch::Tensor& query,
|
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
torch::Tensor& key_cache,
|
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
|
||||||
torch::Tensor& value_cache,
|
const int64_t blocksparse_local_blocks,
|
||||||
int num_kv_heads,
|
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||||
float scale,
|
const int64_t blocksparse_head_sliding_step);
|
||||||
torch::Tensor& block_tables,
|
|
||||||
torch::Tensor& seq_lens,
|
|
||||||
int block_size,
|
|
||||||
int max_seq_len,
|
|
||||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
|
||||||
const std::string& kv_cache_dtype,
|
|
||||||
float kv_scale);
|
|
||||||
|
|
||||||
void rms_norm(
|
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
||||||
torch::Tensor& out,
|
double epsilon);
|
||||||
torch::Tensor& input,
|
|
||||||
torch::Tensor& weight,
|
|
||||||
float epsilon);
|
|
||||||
|
|
||||||
void fused_add_rms_norm(
|
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
|
||||||
torch::Tensor& input,
|
torch::Tensor& weight, double epsilon);
|
||||||
torch::Tensor& residual,
|
|
||||||
torch::Tensor& weight,
|
|
||||||
float epsilon);
|
|
||||||
|
|
||||||
void rotary_embedding(
|
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||||
torch::Tensor& positions,
|
torch::Tensor& key, int64_t head_size,
|
||||||
torch::Tensor& query,
|
torch::Tensor& cos_sin_cache, bool is_neox);
|
||||||
torch::Tensor& key,
|
|
||||||
int head_size,
|
|
||||||
torch::Tensor& cos_sin_cache,
|
|
||||||
bool is_neox);
|
|
||||||
|
|
||||||
void batched_rotary_embedding(
|
void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||||
torch::Tensor& positions,
|
torch::Tensor& key, int64_t head_size,
|
||||||
torch::Tensor& query,
|
torch::Tensor& cos_sin_cache, bool is_neox,
|
||||||
torch::Tensor& key,
|
int64_t rot_dim,
|
||||||
int head_size,
|
|
||||||
torch::Tensor& cos_sin_cache,
|
|
||||||
bool is_neox,
|
|
||||||
int rot_dim,
|
|
||||||
torch::Tensor& cos_sin_cache_offsets);
|
torch::Tensor& cos_sin_cache_offsets);
|
||||||
|
|
||||||
void silu_and_mul(
|
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||||
torch::Tensor& out,
|
|
||||||
torch::Tensor& input);
|
|
||||||
|
|
||||||
void gelu_and_mul(
|
void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||||
torch::Tensor& out,
|
|
||||||
torch::Tensor& input);
|
|
||||||
|
|
||||||
void gelu_tanh_and_mul(
|
void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||||
torch::Tensor& out,
|
|
||||||
torch::Tensor& input);
|
|
||||||
|
|
||||||
void gelu_new(
|
void gelu_new(torch::Tensor& out, torch::Tensor& input);
|
||||||
torch::Tensor& out,
|
|
||||||
torch::Tensor& input);
|
|
||||||
|
|
||||||
void gelu_fast(
|
void gelu_fast(torch::Tensor& out, torch::Tensor& input);
|
||||||
torch::Tensor& out,
|
|
||||||
torch::Tensor& input);
|
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
torch::Tensor aqlm_gemm(
|
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
|
||||||
const torch::Tensor& input,
|
|
||||||
const torch::Tensor& codes,
|
|
||||||
const torch::Tensor& codebooks,
|
const torch::Tensor& codebooks,
|
||||||
const torch::Tensor& scales,
|
const torch::Tensor& scales,
|
||||||
const torch::Tensor& codebook_partition_sizes,
|
const torch::Tensor& codebook_partition_sizes,
|
||||||
const std::optional<torch::Tensor>& bias
|
const std::optional<torch::Tensor>& bias);
|
||||||
);
|
|
||||||
|
|
||||||
torch::Tensor aqlm_dequant(
|
torch::Tensor aqlm_dequant(const torch::Tensor& codes,
|
||||||
const torch::Tensor& codes,
|
|
||||||
const torch::Tensor& codebooks,
|
const torch::Tensor& codebooks,
|
||||||
const torch::Tensor& codebook_partition_sizes
|
const torch::Tensor& codebook_partition_sizes);
|
||||||
);
|
|
||||||
|
|
||||||
torch::Tensor awq_gemm(
|
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
|
||||||
torch::Tensor _in_feats,
|
torch::Tensor _scaling_factors, torch::Tensor _zeros,
|
||||||
torch::Tensor _kernel,
|
int64_t split_k_iters);
|
||||||
|
|
||||||
|
torch::Tensor awq_dequantize(torch::Tensor _kernel,
|
||||||
torch::Tensor _scaling_factors,
|
torch::Tensor _scaling_factors,
|
||||||
torch::Tensor _zeros,
|
torch::Tensor _zeros, int64_t split_k_iters,
|
||||||
int split_k_iters);
|
int64_t thx, int64_t thy);
|
||||||
|
|
||||||
torch::Tensor awq_dequantize(
|
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||||
torch::Tensor _kernel,
|
torch::Tensor& b_scales, torch::Tensor& workspace,
|
||||||
torch::Tensor _scaling_factors,
|
int64_t size_m, int64_t size_n, int64_t size_k);
|
||||||
torch::Tensor _zeros,
|
|
||||||
int split_k_iters,
|
|
||||||
int thx,
|
|
||||||
int thy);
|
|
||||||
|
|
||||||
torch::Tensor marlin_gemm(
|
-                                  torch::Tensor& a,
-                                  torch::Tensor& b_q_weight,
-                                  torch::Tensor& b_scales,
-                                  torch::Tensor& workspace,
-                                  int64_t size_m,
-                                  int64_t size_n,
-                                  int64_t size_k);
+torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
+                                  torch::Tensor& b_meta,
+                                  torch::Tensor& b_scales,
+                                  torch::Tensor& workspace, int64_t num_bits,
+                                  int64_t size_m, int64_t size_n,
+                                  int64_t size_k);
 
-torch::Tensor gptq_marlin_gemm(
-  torch::Tensor &a,
-  torch::Tensor &b_q_weight,
-  torch::Tensor &b_scales,
-  torch::Tensor &g_idx,
-  torch::Tensor &perm,
-  torch::Tensor &workspace,
-  int64_t num_bits,
-  int64_t size_m,
-  int64_t size_n,
-  int64_t size_k,
-  bool is_k_full);
+torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
+                               torch::Tensor& b_scales, torch::Tensor& g_idx,
+                               torch::Tensor& perm, torch::Tensor& workspace,
+                               int64_t num_bits, int64_t size_m, int64_t size_n,
+                               int64_t size_k, bool is_k_full);
 
-torch::Tensor gptq_marlin_repack(
-  torch::Tensor &b_q_weight,
-  torch::Tensor &perm,
-  int64_t size_k,
-  int64_t size_n,
-  int64_t num_bits);
+torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
+                                 int64_t size_k, int64_t size_n,
+                                 int64_t num_bits);
+
+void cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
+                          torch::Tensor const& b, torch::Tensor const& a_scales,
+                          torch::Tensor const& b_scales);
 
 #endif
 
-void squeezellm_gemm(
-  torch::Tensor vec,
-  torch::Tensor mat,
-  torch::Tensor mul,
+void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
+                              torch::Tensor const& scale);
+
+void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
+                               torch::Tensor& scales);
+
+void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
                      torch::Tensor lookup_table);
 
-torch::Tensor gptq_gemm(
-  torch::Tensor a,
-  torch::Tensor b_q_weight,
-  torch::Tensor b_gptq_qzeros,
-  torch::Tensor b_gptq_scales,
-  torch::Tensor b_g_idx,
-  bool use_exllama,
-  int bit);
+torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
+                        torch::Tensor b_gptq_qzeros,
+                        torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
+                        bool use_exllama, int64_t bit);
 
-void gptq_shuffle(
-  torch::Tensor q_weight,
-  torch::Tensor q_perm,
-  int bit);
+void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
 
-void static_scaled_fp8_quant(
-  torch::Tensor& out,
-  torch::Tensor& input,
-  torch::Tensor& scale);
+void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
+                             torch::Tensor& scale);
 
-void dynamic_scaled_fp8_quant(
-  torch::Tensor& out,
-  torch::Tensor& input,
-  torch::Tensor& scale);
+void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
+                              torch::Tensor& scale);
 
-void moe_align_block_size(
-  torch::Tensor topk_ids,
-  int num_experts,
-  int block_size,
-  torch::Tensor sorted_token_ids,
-  torch::Tensor experts_ids,
-  torch::Tensor num_tokens_post_pad);
+void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
+                          int64_t block_size, torch::Tensor sorted_token_ids,
+                          torch::Tensor experts_ids,
+                          torch::Tensor num_tokens_post_pad);
 
 #ifndef USE_ROCM
-using fptr_t = uint64_t;
+using fptr_t = int64_t;
 fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
                       const std::vector<std::string>& handles,
-                      const std::vector<int64_t> &offsets, int rank,
+                      const std::vector<int64_t>& offsets, int64_t rank,
                       bool full_nvlink);
-bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
+bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
                       bool full_nvlink);
 void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
 void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
                       torch::Tensor& out);
 void dispose(fptr_t _fa);
-int meta_size();
+int64_t meta_size();
 void register_buffer(fptr_t _fa, torch::Tensor& t,
                      const std::vector<std::string>& handles,
                      const std::vector<int64_t>& offsets);
-std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
+std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
+    fptr_t _fa);
 void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
                             const std::vector<std::vector<int64_t>>& offsets);
 #endif
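The signature churn above is not cosmetic: every `int` scalar becomes `int64_t`, the `uint64_t` handle type becomes `int64_t`, and `get_graph_buffer_ipc_meta` returns a tensor rather than a byte vector, because these declarations are about to be registered through `TORCH_LIBRARY`, whose schemas only admit Python-facing scalar types. A minimal sketch of the constraint, using a hypothetical `demo_meta_size` op rather than any of vLLM's real ones:

#include <torch/library.h>

// Hypothetical op, for illustration only: op schemas speak Python types, so
// the C++ signature must use int64_t (schema "int") rather than a plain int.
int64_t demo_meta_size() { return 1024; }

TORCH_LIBRARY(demo, m) {
  m.def("demo_meta_size", &demo_meta_size);  // inferred schema: "() -> int"
}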
@@ -1,4 +1,4 @@
-#include <torch/extension.h>
+#include <torch/all.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 
@@ -9,12 +9,8 @@ namespace vllm {
 
 template <typename scalar_t, bool IS_NEOX>
 inline __device__ void apply_token_rotary_embedding(
-  scalar_t* __restrict__ arr,
-  const scalar_t* __restrict__ cos_ptr,
-  const scalar_t* __restrict__ sin_ptr,
-  int rot_offset,
-  int embed_dim)
-{
+    scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr,
+    const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) {
   int x_index, y_index;
   scalar_t cos, sin;
   if (IS_NEOX) {
@@ -39,17 +35,15 @@ inline __device__ void apply_token_rotary_embedding(
 
 template <typename scalar_t, bool IS_NEOX>
 inline __device__ void apply_rotary_embedding(
-  scalar_t* __restrict__ query,   // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ key,     // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
-  const scalar_t* cache_ptr,
-  const int head_size,
-  const int num_heads,
-  const int num_kv_heads,
-  const int rot_dim,
-  const int token_idx,
-  const int64_t query_stride,
-  const int64_t key_stride)
-{
+    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
+                                   // head_size] or [num_tokens, num_heads,
+                                   // head_size]
+    scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
+                                   // head_size] or [num_tokens, num_kv_heads,
+                                   // head_size]
+    const scalar_t* cache_ptr, const int head_size, const int num_heads,
+    const int num_kv_heads, const int rot_dim, const int token_idx,
+    const int64_t query_stride, const int64_t key_stride) {
   const int embed_dim = rot_dim / 2;
   const scalar_t* cos_ptr = cache_ptr;
   const scalar_t* sin_ptr = cache_ptr + embed_dim;
@@ -59,8 +53,8 @@ inline __device__ void apply_rotary_embedding(
     const int head_idx = i / embed_dim;
     const int64_t token_head = token_idx * query_stride + head_idx * head_size;
     const int rot_offset = i % embed_dim;
-    apply_token_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
-                                                    sin_ptr, rot_offset, embed_dim);
+    apply_token_rotary_embedding<scalar_t, IS_NEOX>(
+        query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
   }
 
   const int nk = num_kv_heads * embed_dim;
@@ -68,60 +62,72 @@ inline __device__ void apply_rotary_embedding(
     const int head_idx = i / embed_dim;
     const int64_t token_head = token_idx * key_stride + head_idx * head_size;
     const int rot_offset = i % embed_dim;
-    apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
-                                                    sin_ptr, rot_offset, embed_dim);
+    apply_token_rotary_embedding<scalar_t, IS_NEOX>(
+        key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
   }
 }
 
 template <typename scalar_t, bool IS_NEOX>
 __global__ void rotary_embedding_kernel(
-  const int64_t* __restrict__ positions,       // [batch_size, seq_len] or [num_tokens]
-  scalar_t* __restrict__ query,                // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ key,                  // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
-  const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim // 2]
-  const int rot_dim,
-  const int64_t query_stride,
-  const int64_t key_stride,
-  const int num_heads,
-  const int num_kv_heads,
-  const int head_size) {
+    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
+                                            // [num_tokens]
+    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
+                                   // head_size] or [num_tokens, num_heads,
+                                   // head_size]
+    scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
+                                   // head_size] or [num_tokens, num_kv_heads,
+                                   // head_size]
+    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
+                                                 // 2]
+    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
+    const int num_heads, const int num_kv_heads, const int head_size) {
   // Each thread block is responsible for one token.
   const int token_idx = blockIdx.x;
   int64_t pos = positions[token_idx];
   const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
 
-  apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
+  apply_rotary_embedding<scalar_t, IS_NEOX>(
+      query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
+      token_idx, query_stride, key_stride);
 }
 
 template <typename scalar_t, bool IS_NEOX>
 __global__ void batched_rotary_embedding_kernel(
-  const int64_t* __restrict__ positions,       // [batch_size, seq_len] or [num_tokens]
-  scalar_t* __restrict__ query,                // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ key,                  // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
-  const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim // 2]
-  const int64_t* __restrict__ cos_sin_cache_offsets,  // [batch_size, seq_len] or [num_tokens]
-  const int rot_dim,
-  const int64_t query_stride,
-  const int64_t key_stride,
-  const int num_heads,
-  const int num_kv_heads,
-  const int head_size) {
+    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
+                                            // [num_tokens]
+    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
+                                   // head_size] or [num_tokens, num_heads,
+                                   // head_size]
+    scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
+                                   // head_size] or [num_tokens, num_kv_heads,
+                                   // head_size]
+    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
+                                                 // 2]
+    const int64_t* __restrict__ cos_sin_cache_offsets,  // [batch_size, seq_len]
+                                                        // or [num_tokens]
+    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
+    const int num_heads, const int num_kv_heads, const int head_size) {
   // Each thread block is responsible for one token.
   const int token_idx = blockIdx.x;
   int64_t pos = positions[token_idx];
   int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
-  const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;
+  const scalar_t* cache_ptr =
+      cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;
 
-  apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
+  apply_rotary_embedding<scalar_t, IS_NEOX>(
+      query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
+      token_idx, query_stride, key_stride);
 }
 
 } // namespace vllm
 
 void rotary_embedding(
   torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
-  torch::Tensor& query,      // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
-  torch::Tensor& key,        // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
-  int head_size,
+    torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
+                           // [num_tokens, num_heads * head_size]
+    torch::Tensor& key,    // [batch_size, seq_len, num_kv_heads * head_size] or
+                           // [num_tokens, num_kv_heads * head_size]
+    int64_t head_size,
   torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
   bool is_neox) {
   int64_t num_tokens = query.numel() / query.size(-1);
@@ -132,36 +138,21 @@ void rotary_embedding(
   int64_t key_stride = key.stride(-2);
 
   dim3 grid(num_tokens);
-  dim3 block(std::min(num_heads * rot_dim / 2, 512));
+  dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
-    query.scalar_type(),
-    "rotary_embedding",
-    [&] {
-      if (is_neox) {
-        vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
-          positions.data_ptr<int64_t>(),
-          query.data_ptr<scalar_t>(),
-          key.data_ptr<scalar_t>(),
-          cos_sin_cache.data_ptr<scalar_t>(),
-          rot_dim,
-          query_stride,
-          key_stride,
-          num_heads,
-          num_kv_heads,
-          head_size);
-      } else {
-        vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
-          positions.data_ptr<int64_t>(),
-          query.data_ptr<scalar_t>(),
-          key.data_ptr<scalar_t>(),
-          cos_sin_cache.data_ptr<scalar_t>(),
-          rot_dim,
-          query_stride,
-          key_stride,
-          num_heads,
-          num_kv_heads,
-          head_size);
-      }
-    });
+  VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
+    if (is_neox) {
+      vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
+          positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
+          key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
+          query_stride, key_stride, num_heads, num_kv_heads, head_size);
+    } else {
+      vllm::rotary_embedding_kernel<scalar_t, false>
+          <<<grid, block, 0, stream>>>(
+              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
+              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
+              rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
+              head_size);
+    }
+  });
@@ -173,12 +164,13 @@ and process in batched manner.
 */
 void batched_rotary_embedding(
   torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
-  torch::Tensor& query,      // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
-  torch::Tensor& key,        // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
-  int head_size,
+    torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
+                           // [num_tokens, num_heads * head_size]
+    torch::Tensor& key,    // [batch_size, seq_len, num_kv_heads * head_size] or
+                           // [num_tokens, num_kv_heads * head_size]
+    int64_t head_size,
   torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
-  bool is_neox,
-  int rot_dim,
+    bool is_neox, int64_t rot_dim,
   torch::Tensor& cos_sin_cache_offsets  // [num_tokens]
 ) {
   int64_t num_tokens = cos_sin_cache_offsets.size(0);
@@ -188,39 +180,24 @@ void batched_rotary_embedding(
   int64_t key_stride = key.stride(-2);
 
   dim3 grid(num_tokens);
-  dim3 block(std::min(num_heads * rot_dim / 2, 512));
+  dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
-    query.scalar_type(),
-    "rotary_embedding",
-    [&] {
-      if (is_neox) {
-        vllm::batched_rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
-          positions.data_ptr<int64_t>(),
-          query.data_ptr<scalar_t>(),
-          key.data_ptr<scalar_t>(),
-          cos_sin_cache.data_ptr<scalar_t>(),
-          cos_sin_cache_offsets.data_ptr<int64_t>(),
-          rot_dim,
-          query_stride,
-          key_stride,
-          num_heads,
-          num_kv_heads,
-          head_size);
-      } else {
-        vllm::batched_rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
-          positions.data_ptr<int64_t>(),
-          query.data_ptr<scalar_t>(),
-          key.data_ptr<scalar_t>(),
-          cos_sin_cache.data_ptr<scalar_t>(),
-          cos_sin_cache_offsets.data_ptr<int64_t>(),
-          rot_dim,
-          query_stride,
-          key_stride,
-          num_heads,
-          num_kv_heads,
-          head_size);
-      }
-    });
+  VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
+    if (is_neox) {
+      vllm::batched_rotary_embedding_kernel<scalar_t, true>
+          <<<grid, block, 0, stream>>>(
+              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
+              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
+              cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
+              key_stride, num_heads, num_kv_heads, head_size);
+    } else {
+      vllm::batched_rotary_embedding_kernel<scalar_t, false>
+          <<<grid, block, 0, stream>>>(
+              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
+              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
+              cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
+              key_stride, num_heads, num_kv_heads, head_size);
+    }
+  });
 }
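For readers tracing the reflowed arguments, the arithmetic itself is unchanged. A host-side reference of the GPT-NeoX-style rotation that `apply_token_rotary_embedding` applies to one head, a reading aid only and not the shipped kernel (the `IS_NEOX == false` GPT-J path rotates interleaved pairs instead of half-offset pairs):

#include <vector>

// cos_v and sin_v each hold embed_dim = rot_dim / 2 entries for one position.
void rotary_neox_reference(std::vector<float>& head,
                           const std::vector<float>& cos_v,
                           const std::vector<float>& sin_v) {
  const int embed_dim = static_cast<int>(cos_v.size());
  for (int i = 0; i < embed_dim; ++i) {
    const float x = head[i];              // first half of the rotated dims
    const float y = head[embed_dim + i];  // second half
    head[i] = x * cos_v[i] - y * sin_v[i];
    head[embed_dim + i] = y * cos_v[i] + x * sin_v[i];
  }
}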
@@ -28,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
     f(in_T, out_T, W_T, narrow, 2752) \
     f(in_T, out_T, W_T, narrow, 2816) \
     f(in_T, out_T, W_T, narrow, 3072) \
+    f(in_T, out_T, W_T, narrow, 3328) \
     f(in_T, out_T, W_T, narrow, 3456) \
     f(in_T, out_T, W_T, narrow, 3584) \
     f(in_T, out_T, W_T, narrow, 4096) \
@@ -36,6 +37,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
     f(in_T, out_T, W_T, narrow, 5504) \
     f(in_T, out_T, W_T, narrow, 5632) \
     f(in_T, out_T, W_T, narrow, 6144) \
+    f(in_T, out_T, W_T, narrow, 6400) \
     f(in_T, out_T, W_T, narrow, 6848) \
     f(in_T, out_T, W_T, narrow, 6912) \
     f(in_T, out_T, W_T, narrow, 7168) \
@@ -53,6 +55,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
     f(in_T, out_T, W_T, narrow, 22016) \
     f(in_T, out_T, W_T, narrow, 24576) \
     f(in_T, out_T, W_T, narrow, 27392) \
+    f(in_T, out_T, W_T, narrow, 27648) \
     f(in_T, out_T, W_T, narrow, 28672) \
     f(in_T, out_T, W_T, narrow, 32000) \
     f(in_T, out_T, W_T, narrow, 32256) \
@@ -96,6 +99,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
     f(in_T, out_T, W_T, 2752, narrow) \
     f(in_T, out_T, W_T, 2816, narrow) \
     f(in_T, out_T, W_T, 3072, narrow) \
+    f(in_T, out_T, W_T, 3328, narrow) \
     f(in_T, out_T, W_T, 3456, narrow) \
     f(in_T, out_T, W_T, 3584, narrow) \
     f(in_T, out_T, W_T, 4096, narrow) \
@@ -104,6 +108,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
     f(in_T, out_T, W_T, 5504, narrow) \
     f(in_T, out_T, W_T, 5632, narrow) \
     f(in_T, out_T, W_T, 6144, narrow) \
+    f(in_T, out_T, W_T, 6400, narrow) \
     f(in_T, out_T, W_T, 6848, narrow) \
     f(in_T, out_T, W_T, 6912, narrow) \
     f(in_T, out_T, W_T, 7168, narrow) \
@@ -121,6 +126,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
     f(in_T, out_T, W_T, 22016, narrow) \
     f(in_T, out_T, W_T, 24576, narrow) \
     f(in_T, out_T, W_T, 27392, narrow) \
+    f(in_T, out_T, W_T, 27648, narrow) \
     f(in_T, out_T, W_T, 28672, narrow) \
     f(in_T, out_T, W_T, 32000, narrow) \
     f(in_T, out_T, W_T, 32256, narrow) \
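Each `f(in_T, out_T, W_T, narrow, dim)` row above is a macro hook: the table is replayed with `f` bound to an instantiation macro, so adding 3328, 6400, and 27648 emits bgmv kernels for those hidden sizes. A toy, self-contained illustration of the expansion pattern (all names here are made up):

#include <cstdio>

// Stand-in for a templated kernel; the dims mirror the three new table rows.
template <int feat> void demo_kernel() { std::printf("feat=%d\n", feat); }

#define FOR_DEMO_FEATS(f) f(3328) f(6400) f(27648)
#define DEMO_CALL(feat) demo_kernel<feat>();

int main() {
  FOR_DEMO_FEATS(DEMO_CALL)  // expands to three explicit calls
  return 0;
}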
@@ -1,8 +1,14 @@
 #pragma once
 
 #include <ATen/cuda/CUDAContext.h>
+#ifndef USE_ROCM
 #include <cooperative_groups.h>
+#else
+#include <hip/hip_cooperative_groups.h>
+#endif
+#ifndef USE_ROCM
 #include <cuda/pipeline>
+#endif
 #include <cuda_runtime.h>
 #include <iostream>
 #include <stdio.h>
@@ -11,6 +17,24 @@
 
 namespace cg = cooperative_groups;
 
+#ifdef USE_ROCM
+template <size_t len>
+__host__ __device__
+inline void* memcpy_blocking(void *dst, const void *src) {
+  // Does not handle the case of long datatypes
+  char *d = reinterpret_cast<char *>(dst);
+  const char *s = reinterpret_cast<const char *>(src);
+  size_t i = 0;
+#pragma unroll
+  for (i = 0; i < len; ++i) {
+    d[i] = s[i];
+  }
+  return dst;
+}
+#endif
+
+#ifndef USE_ROCM
+
 // nthrs = (32, 4)
 template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
           size_t W_copy_size, int tx, int ty, int tz, typename in_T,
@@ -141,6 +165,81 @@ bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
   }
 }
 
+#else
+
+template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
+          size_t W_copy_size, int tx, int ty, int tz, typename in_T,
+          typename out_T, typename W_T>
+__global__ void
+bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
+                   const W_T *__restrict__ W,
+                   const int64_t *__restrict__ indicies, int64_t y_offset,
+                   int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
+                   float scale) {
+  size_t batch_idx = blockIdx.y;
+  int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
+  if (idx < 0) {
+    return;
+  }
+
+  size_t j = blockIdx.x;
+  constexpr size_t tile_size = tx * ty * vec_size;
+  constexpr size_t num_tiles = (feat_in + tile_size - 1) / tile_size;
+  __shared__ float y_warpwise[ty];
+
+  float y = 0;
+  vec_t<in_T, vec_size> x_vec;
+  vec_t<W_T, vec_size> w_vec;
+  size_t tile_idx;
+
+#pragma unroll
+  for (tile_idx = 0; tile_idx < num_tiles; ++tile_idx) {
+    if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) {
+      x_vec.load(X + (batch_idx * feat_in) +
+                 tile_idx * tile_size +
+                 (threadIdx.y * tx + threadIdx.x) * vec_size);
+      w_vec.load(W + (idx * feat_out + j) * feat_in +
+                 tile_idx * tile_size +
+                 (threadIdx.y * tx + threadIdx.x) * vec_size);
+    }
+
+    float sum = 0.f;
+#pragma unroll
+    for (size_t i = 0; i < vec_size; ++i) {
+      sum += convert_type<W_T, float>(w_vec[i]) * convert_type<in_T, float>(x_vec[i]) * scale;
+    }
+#pragma unroll
+    for (size_t offset = tx / 2; offset > 0; offset /= 2) {
+      sum += VLLM_SHFL_DOWN_SYNC(sum, offset);
+    }
+
+    __syncthreads();
+
+    if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) {
+      y += sum;
+    }
+  }
+
+  if (threadIdx.x == 0) {
+    y_warpwise[threadIdx.y] = y;
+  }
+  __syncthreads();
+
+  float y_write = 0.f;
+#pragma unroll
+  for (size_t i = 0; i < ty; ++i) {
+    y_write += y_warpwise[i];
+  }
+
+  // write Y;
+  if (threadIdx.x == 0 && threadIdx.y == 0) {
+    size_t y_idx = batch_idx * full_y_size + y_offset + j;
+    Y[y_idx] = vllm_add<out_T>(Y[y_idx], convert_type<float, out_T>(y_write));
+  }
+}
+
+#endif
+
 // nthrs = (2, 16, 4)
 template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
           typename in_T, typename out_T, typename W_T>
@@ -172,7 +271,11 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
     float sum = 0.f;
 #pragma unroll
     for (size_t i = 0; i < vec_size; ++i) {
+#ifndef USE_ROCM
       sum += float(w_vec[i]) * float(x_vec[i]) * scale;
+#else
+      sum += convert_type<W_T, float>(w_vec[i]) * convert_type<in_T, float>(x_vec[i]) * scale;
+#endif
     }
 
     cg::thread_block_tile g = cg::tiled_partition<tx>(block);
@@ -183,8 +286,14 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
     sum = g.shfl(sum, 0);
 
     if (threadIdx.x == 0) {
+#ifndef USE_ROCM
       Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
         threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum);
+#else
+      size_t y_idx = batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
+                     threadIdx.z * ty + threadIdx.y;
+      Y[y_idx] = vllm_add<out_T>(Y[y_idx], convert_type<float, out_T>(sum));
+#endif
     }
   }
 
@@ -236,6 +345,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                                          scale);
     }
   } else {
+#ifndef USE_ROCM
     static_assert(feat_in % (vec_size * 32) == 0 ||
                   feat_in % (vec_size * 16) == 0 ||
                   feat_in % (vec_size * 8) == 0);
@@ -279,6 +389,50 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                                          full_y_size, num_layers, layer_idx,
                                          scale);
     }
+#else
+    constexpr size_t rocm_warp_size = warpSize;
+
+#define CHECK_INPUT_TILEABLE_BY(vec_size_) \
+    feat_in % (rocm_warp_size * vec_size_) == 0
+
+#define LAUNCH_BGMV_SHRINK_KERNELS_ROCM(factor_, vec_size_, tx_, ty_)        \
+    if constexpr (CHECK_INPUT_TILEABLE_BY(factor_)) {                        \
+      constexpr size_t vec_size_shrink = vec_size_;                          \
+      constexpr int tx = tx_;                                                \
+      constexpr int ty = ty_;                                                \
+      dim3 nblks(feat_out, batch_size);                                      \
+      dim3 nthrs(tx, ty);                                                    \
+      bgmv_shrink_kernel<feat_in, feat_out, vec_size_shrink,                 \
+                         vec_size_shrink * sizeof(in_T),                     \
+                         vec_size_shrink * sizeof(W_T),                      \
+                         tx, ty, tz>                                         \
+          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,         \
+                                        full_y_size, num_layers, layer_idx,  \
+                                        scale);                              \
+    }
+
+    static_assert(CHECK_INPUT_TILEABLE_BY(32) ||
+                  CHECK_INPUT_TILEABLE_BY(16) ||
+                  CHECK_INPUT_TILEABLE_BY( 8) ||
+                  CHECK_INPUT_TILEABLE_BY( 4) ||
+                  CHECK_INPUT_TILEABLE_BY( 2) ||
+                  CHECK_INPUT_TILEABLE_BY( 1));
+
+    LAUNCH_BGMV_SHRINK_KERNELS_ROCM(32, vec_size, rocm_warp_size, 32/vec_size)
+    else
+    LAUNCH_BGMV_SHRINK_KERNELS_ROCM(16, vec_size, rocm_warp_size, 16/vec_size)
+    else
+    LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 8, vec_size, rocm_warp_size,  8/vec_size)
+    else
+    LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 4, vec_size, rocm_warp_size/(vec_size/4), vec_size/4)
+    else
+    LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 2, vec_size, rocm_warp_size/(vec_size/2), vec_size/2)
+    else
+    LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 1, vec_size, rocm_warp_size/(vec_size/1), vec_size/1)
+
+#undef CHECK_INPUT_TILEABLE_BY
+#undef LAUNCH_BGMV_SHRINK_KERNELS_ROCM
+#endif
   }
 }
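The new ROCm shrink kernel reduces across lanes with `VLLM_SHFL_DOWN_SYNC`, a compat macro that maps to `__shfl_down_sync` on NVIDIA and the non-`_sync` HIP shuffle on AMD, where a wavefront is 64 lanes wide. The CUDA form of the reduction it performs, sketched for a full 32-lane warp:

// Sketch only: sum across one 32-lane warp; lane 0 receives the total.
__device__ float warp_sum(float v) {
  for (int offset = 16; offset > 0; offset /= 2)
    v += __shfl_down_sync(0xffffffffu, v, offset);
  return v;
}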
@@ -1,8 +1,6 @@
 #ifndef VEC_DTYPES_CUH_
 #define VEC_DTYPES_CUH_
 
-#include <cuda_bf16.h>
-#include <cuda_fp16.h>
 #ifdef FLASHINFER_USE_FP8
 #include <cuda_fp8.h>
 #endif
@@ -10,6 +8,9 @@
 
 #include <type_traits>
 
+#include "../type_convert.h"
+#include "../../cuda_compat.h"
+
 #define FLASHINFER_INLINE \
   inline __attribute__((always_inline)) __device__ __host__
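The direct `cuda_bf16.h`/`cuda_fp16.h` includes give way to `type_convert.h`, which supplies the same names on both CUDA and HIP. For orientation, this is roughly the surface the bgmv kernels rely on from flashinfer's `vec_t`, written as a minimal stand-in (an assumption for illustration; the real type also performs vectorized 128-bit loads):

// Minimal sketch of the vec_t interface as used by bgmv_shrink_kernel.
template <typename T, size_t N>
struct vec_demo_t {
  T data[N];
  __device__ void load(const T* ptr) {
#pragma unroll
    for (size_t i = 0; i < N; ++i) data[i] = ptr[i];  // contiguous gather
  }
  __device__ T operator[](size_t i) const { return data[i]; }
};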
@@ -1,12 +1,11 @@
-#include <cuda_bf16.h>
-#include <cuda_fp16.h>
-#include <torch/extension.h>
+#include <torch/all.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <cstdint>
 
+#include "type_convert.h"
+#include "../cuda_compat.h"
 #include "bgmv/bgmv_config.h"
 
-namespace {
 
 //====== utils ======
 
@@ -89,7 +88,7 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
 }
 
 void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
-                   torch::Tensor indicies, int64_t layer_idx, float scale) {
+                   torch::Tensor indicies, int64_t layer_idx, double scale) {
   CHECK_INPUT(y);
   CHECK_INPUT(x);
   CHECK_INPUT(w);
@@ -321,7 +320,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
 
 void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
                              torch::Tensor indicies, int64_t layer_idx,
-                             float scale, int64_t h_in, int64_t h_out,
+                             double scale, int64_t h_in, int64_t h_out,
                              int64_t y_offset) {
   CHECK_INPUT(y);
   CHECK_INPUT(x);
@@ -568,15 +567,3 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
   TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
               " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
 }
-
-} // namespace
-
-//====== pybind ======
-
-#define DEFINE_pybind(name) m.def(#name, &name, #name);
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
-  m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level,
-        "dispatch_bgmv_low_level");
-}
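`scale` widens from `float` to `double` on both entry points for the same schema reason as in `ops.h`: the op-schema type `float` denotes a Python float, which binds to C++ `double`. A sketch of the resulting boundary pattern, with a hypothetical launcher name:

#include <cstdint>

void launch_bgmv_demo(float scale_f);  // hypothetical kernel-side launcher

// Binding-facing entry: schema "float" arrives as double; narrow once here.
void dispatch_bgmv_demo(int64_t layer_idx, double scale) {
  (void)layer_idx;  // unused in this sketch
  launch_bgmv_demo(static_cast<float>(scale));
}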
csrc/punica/punica_ops.h (new file, 11 lines)
@@ -0,0 +1,11 @@
+#pragma once
+
+#include <torch/all.h>
+
+void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
+                   torch::Tensor indicies, int64_t layer_idx, double scale);
+
+void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
+                             torch::Tensor indicies, int64_t layer_idx,
+                             double scale, int64_t h_in, int64_t h_out,
+                             int64_t y_offset);
csrc/punica/torch_bindings.cpp (new file, 18 lines)
@@ -0,0 +1,18 @@
+#include "registration.h"
+#include "punica_ops.h"
+
+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
+  m.def(
+      "dispatch_bgmv(Tensor! y, Tensor x, Tensor w, Tensor indicies, int "
+      "layer_idx, float scale) -> ()");
+  m.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv);
+
+  m.def(
+      "dispatch_bgmv_low_level(Tensor! y, Tensor x, Tensor w,"
+      "Tensor indicies, int layer_idx,"
+      "float scale, int h_in, int h_out,"
+      "int y_offset) -> ()");
+  m.impl("dispatch_bgmv_low_level", torch::kCUDA, &dispatch_bgmv_low_level);
+}
+
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
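`registration.h` itself is outside this diff; a plausible minimal version of the two macros it must provide is sketched below (an assumption, not the verbatim header). `TORCH_LIBRARY_EXPAND` exists so that `TORCH_EXTENSION_NAME`, a macro set by the build system, is expanded before `TORCH_LIBRARY` pastes it into identifiers, and `REGISTER_EXTENSION` supplies the bare `PyInit_*` hook that `import` still needs once pybind11 is gone:

#include <Python.h>
#include <torch/library.h>

#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)
#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)

// Force expansion of the NAME argument before TORCH_LIBRARY consumes it.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)

// Bare module object so `import` succeeds; ops are reached via torch.ops.
#define REGISTER_EXTENSION(NAME)                                      \
  PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                            \
    static struct PyModuleDef module = {PyModuleDef_HEAD_INIT,        \
                                        STRINGIFY(NAME), nullptr, 0,  \
                                        nullptr};                     \
    return PyModule_Create(&module);                                  \
  }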
csrc/punica/type_convert.h (new file, 82 lines)
@@ -0,0 +1,82 @@
+#ifndef CSRC__PUNICA__TYPE_CONVERT_H__
+#define CSRC__PUNICA__TYPE_CONVERT_H__
+
+#ifndef USE_ROCM
+
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+
+#else
+
+#include <hip/hip_bf16.h>
+#include <hip/hip_fp16.h>
+
+#define __TYPE_CONVERT__HOST_DEVICE__ __host__ __device__
+
+typedef __half nv_half;
+typedef __hip_bfloat16 nv_bfloat16;
+typedef __hip_bfloat162 nv_bfloat162;
+
+__TYPE_CONVERT__HOST_DEVICE__
+inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 val) {
+  return __hip_bfloat162{val, val};
+}
+
+__TYPE_CONVERT__HOST_DEVICE__
+inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 vall, __hip_bfloat16 valr) {
+  return __hip_bfloat162{vall, valr};
+}
+
+template <typename T_src, typename T_dst>
+__TYPE_CONVERT__HOST_DEVICE__
+inline T_dst convert_type(T_src val) {
+  return static_cast<T_dst>(val);
+}
+
+template <>
+__TYPE_CONVERT__HOST_DEVICE__
+inline float convert_type<__half, float>(__half val) {
+  return __half2float(val);
+}
+
+template <>
+__TYPE_CONVERT__HOST_DEVICE__
+inline __half convert_type<float, __half>(float val) {
+  return __float2half(val);
+}
+
+template <>
+__TYPE_CONVERT__HOST_DEVICE__
+inline float convert_type<__hip_bfloat16, float>(__hip_bfloat16 val) {
+  return __bfloat162float(val);
+}
+
+template <>
+__TYPE_CONVERT__HOST_DEVICE__
+inline __hip_bfloat16 convert_type<float, __hip_bfloat16>(float val) {
+  return __float2bfloat16(val);
+}
+
+template <typename T>
+__TYPE_CONVERT__HOST_DEVICE__
+inline T vllm_add(T a, T b) {
+  return a + b;
+}
+
+template <>
+__TYPE_CONVERT__HOST_DEVICE__
+inline __half vllm_add<__half>(__half a, __half b) {
+  return __hadd(a, b);
+}
+
+template <>
+__TYPE_CONVERT__HOST_DEVICE__
+inline __hip_bfloat16 vllm_add<__hip_bfloat16>(__hip_bfloat16 a, __hip_bfloat16 b) {
+  return __hadd(a, b);
+}
+
+#undef __TYPE_CONVERT__HOST_DEVICE__
+
+#endif  // USE_ROCM
+
+#endif  // CSRC__PUNICA__TYPE_CONVERT_H__
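The intended call pattern for these helpers on the HIP build, as the bgmv kernels use them: convert to `float` at the edges, do the arithmetic in `float`, and go back through `vllm_add` where the storage type's `operator+` is not device-friendly. A small sketch:

#include "type_convert.h"

// Sketch: float math in the middle, storage types at the edges.
__device__ __hip_bfloat16 fma_bf16(__hip_bfloat16 acc, __hip_bfloat16 w,
                                   __hip_bfloat16 x) {
  float prod = convert_type<__hip_bfloat16, float>(w) *
               convert_type<__hip_bfloat16, float>(x);
  return vllm_add<__hip_bfloat16>(acc,
                                  convert_type<float, __hip_bfloat16>(prod));
}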
csrc/pybind.cpp (deleted, 136 lines)
@@ -1,136 +0,0 @@
-#include "cache.h"
-#include "cuda_utils.h"
-#include "ops.h"
-#include <torch/extension.h>
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  // vLLM custom ops
-  pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
-
-  // Attention ops
-  ops.def(
-    "paged_attention_v1",
-    &paged_attention_v1,
-    "Compute the attention between an input query and the cached keys/values using PagedAttention.");
-  ops.def(
-    "paged_attention_v2",
-    &paged_attention_v2,
-    "PagedAttention V2.");
-
-  // Activation ops
-  ops.def(
-    "silu_and_mul",
-    &silu_and_mul,
-    "Activation function used in SwiGLU.");
-  ops.def(
-    "gelu_and_mul",
-    &gelu_and_mul,
-    "Activation function used in GeGLU with `none` approximation.");
-  ops.def(
-    "gelu_tanh_and_mul",
-    &gelu_tanh_and_mul,
-    "Activation function used in GeGLU with `tanh` approximation.");
-  ops.def(
-    "gelu_new",
-    &gelu_new,
-    "GELU implementation used in GPT-2.");
-  ops.def(
-    "gelu_fast",
-    &gelu_fast,
-    "Approximate GELU implementation.");
-
-  // Layernorm
-  ops.def(
-    "rms_norm",
-    &rms_norm,
-    "Apply Root Mean Square (RMS) Normalization to the input tensor.");
-
-  ops.def(
-    "fused_add_rms_norm",
-    &fused_add_rms_norm,
-    "In-place fused Add and RMS Normalization");
-
-  // Rotary embedding
-  ops.def(
-    "rotary_embedding",
-    &rotary_embedding,
-    "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
-
-  ops.def(
-    "batched_rotary_embedding",
-    &batched_rotary_embedding,
-    "Apply GPT-NeoX or GPT-J style rotary embedding to query and key (supports multiple loras)");
-
-  // Quantization ops
-#ifndef USE_ROCM
-  ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM");
-  ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM");
-  ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
-  ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
-  ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ");
-  ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ");
-  ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
-#endif
-
-  ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
-  ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
-  ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
-  ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, "Compute FP8 quantized tensor for given scaling factor");
-  ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
-  ops.def(
-    "moe_align_block_size",
-    &moe_align_block_size,
-    "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");
-
-  // Cache ops
-  pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
-  cache_ops.def(
-    "swap_blocks",
-    &swap_blocks,
-    "Swap in (out) the cache blocks from src to dst");
-  cache_ops.def(
-    "copy_blocks",
-    &copy_blocks,
-    "Copy the cache blocks from src to dst");
-  cache_ops.def(
-    "reshape_and_cache",
-    &reshape_and_cache,
-    "Reshape the key and value tensors and cache them");
-  cache_ops.def(
-    "reshape_and_cache_flash",
-    &reshape_and_cache_flash,
-    "Reshape the key and value tensors and cache them");
-  cache_ops.def(
-    "convert_fp8",
-    &convert_fp8,
-    "Convert the key and value cache to fp8 data type");
-
-  // Cuda utils
-  pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");
-  cuda_utils.def(
-    "get_device_attribute",
-    &get_device_attribute,
-    "Gets the specified device attribute.");
-
-  cuda_utils.def(
-    "get_max_shared_memory_per_block_device_attribute",
-    &get_max_shared_memory_per_block_device_attribute,
-    "Gets the maximum shared memory per block device attribute.");
-
-#ifndef USE_ROCM
-  // Custom all-reduce kernels
-  pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce");
-  custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar");
-  custom_ar.def("should_custom_ar", &should_custom_ar, "should_custom_ar");
-  custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg");
-  custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg");
-  custom_ar.def("dispose", &dispose, "dispose");
-  custom_ar.def("meta_size", &meta_size, "meta_size");
-  custom_ar.def("register_buffer", &register_buffer, "register_buffer");
-  custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta,
-                "get_graph_buffer_ipc_meta");
-  custom_ar.def("register_graph_buffers", &register_graph_buffers,
-                "register_graph_buffers");
-#endif
-
-}
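With the monolithic pybind module gone, each op family registers through its own `torch_bindings.cpp` (the punica one appears above; the core `_C` bindings are added elsewhere in this commit range), so the same schemas work under TorchScript and torch.compile. Roughly, an entry like `gptq_shuffle` moves from the `ops.def(...)` style above to a schema-plus-impl pair; a sketch, not the verbatim new file:

#include <torch/all.h>
#include <torch/library.h>

void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);

TORCH_LIBRARY(_C, ops) {
  ops.def("gptq_shuffle(Tensor q_weight, Tensor q_perm, int bit) -> ()");
  ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
}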
@@ -18,37 +18,33 @@
|
|||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
#include <torch/extension.h>
|
#include <torch/all.h>
|
||||||
#include <c10/cuda/CUDAStream.h>
|
#include <c10/cuda/CUDAStream.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
|
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
namespace aqlm {
|
namespace aqlm {
|
||||||
|
|
||||||
__global__ void Code1x16MatVec(
|
__global__ void Code1x16MatVec(
|
||||||
const int4* __restrict__ A,
|
const int4* __restrict__ A, const int4* __restrict__ B,
|
||||||
const int4* __restrict__ B,
|
int4* __restrict__ C, const int4* __restrict__ codebook, const int prob_m,
|
||||||
int4* __restrict__ C,
|
|
||||||
const int4* __restrict__ codebook,
|
|
||||||
const int prob_m,
|
|
||||||
const int prob_k,
|
const int prob_k,
|
||||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
|
const int4 codebook_a_sizes, // cumulative sizes of A spanning each
|
||||||
|
// codebook, at most 3 long.
|
||||||
const int codebook_stride // as int4.
|
const int codebook_stride // as int4.
|
||||||
) {
|
) {
|
||||||
int a_gl_stride = prob_k / 8 / 8;
|
int a_gl_stride = prob_k / 8 / 8;
|
||||||
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||||
bool pred = a_gl_rd < prob_m;
|
bool pred = a_gl_rd < prob_m;
|
||||||
|
|
||||||
if (pred)
|
if (pred) {
|
||||||
{
|
// advance to the correct codebook, this easy because we only multiply one
|
||||||
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
|
// column of the codebook.
|
||||||
auto codebook_size = &codebook_a_sizes.x;
|
auto codebook_size = &codebook_a_sizes.x;
|
||||||
while (a_gl_rd >= *codebook_size)
|
while (a_gl_rd >= *codebook_size) {
|
||||||
{
|
|
||||||
codebook += codebook_stride;
|
codebook += codebook_stride;
|
||||||
++codebook_size;
|
++codebook_size;
|
||||||
}
|
}
|
||||||
@@ -67,8 +63,7 @@ __global__ void Code1x16MatVec(
|
|||||||
// We pad shared memory to avoid bank conflicts during reads
|
// We pad shared memory to avoid bank conflicts during reads
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
|
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
|
||||||
if (b_gl_rd + i < prob_k / 8)
|
if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
|
||||||
sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
|
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
b_gl_rd += 32 * 8;
|
b_gl_rd += 32 * 8;
|
||||||
@@ -79,19 +74,16 @@ __global__ void Code1x16MatVec(
|
|||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < 8; i++) {
|
for (int i = 0; i < 8; i++) {
|
||||||
uint32_t dec[4];
|
uint32_t dec[4];
|
||||||
// We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't
|
// We bypass the L1 cache to avoid massive amounts of memory streaming
|
||||||
// actually help us; this brings > 2x speedup.
|
// that doesn't actually help us; this brings > 2x speedup.
|
||||||
asm volatile (
|
asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
|
||||||
"ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
|
|
||||||
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
|
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
|
||||||
: "l"((void*) &codebook[enc[i]])
|
: "l"((void*)&codebook[enc[i]]));
|
||||||
);
|
|
||||||
half2* a = reinterpret_cast<half2*>(&dec);
|
half2* a = reinterpret_cast<half2*>(&dec);
|
||||||
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
|
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
|
||||||
half2 res2 = {};
|
half2 res2 = {};
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < 4; j++)
|
for (int j = 0; j < 4; j++) res2 = __hfma2(a[j], b[j], res2);
|
||||||
res2 = __hfma2(a[j], b[j], res2);
|
|
||||||
res += __half2float(res2.x) + __half2float(res2.y);
|
res += __half2float(res2.x) + __half2float(res2.y);
|
||||||
b_sh_rd++;
|
b_sh_rd++;
|
||||||
}
|
}
|
||||||
@@ -101,21 +93,18 @@ __global__ void Code1x16MatVec(
|
|||||||
|
|
||||||
if (pred) {
|
if (pred) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 16; i > 0; i /= 2)
|
for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i);
|
||||||
res += __shfl_down_sync(0xffffffff, res, i);
|
|
||||||
if (threadIdx.x % 32 == 0)
|
if (threadIdx.x % 32 == 0)
|
||||||
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
|
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void Code2x8MatVec(
|
__global__ void Code2x8MatVec(
|
||||||
const int4* __restrict__ A,
|
const int4* __restrict__ A, const int4* __restrict__ B,
|
||||||
const int4* __restrict__ B,
|
int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m,
|
||||||
int4* __restrict__ C,
|
|
||||||
const int4* __restrict__ codebook,
|
|
||||||
int prob_m,
|
|
||||||
int prob_k,
|
int prob_k,
|
||||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
|
const int4 codebook_a_sizes, // cumulative sizes of A spanning each
|
||||||
|
// codebook, at most 3 long.
|
||||||
const int codebook_stride // as int4.
|
const int codebook_stride // as int4.
|
||||||
|
|
||||||
) {
|
) {
|
||||||
@@ -123,12 +112,11 @@ __global__ void Code2x8MatVec(
|
|||||||
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||||
bool pred = a_gl_rd < prob_m;
|
bool pred = a_gl_rd < prob_m;
|
||||||
|
|
||||||
if (pred)
|
if (pred) {
|
||||||
{
|
// advance to the correct codebook, this easy because we only multiply one
|
||||||
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
|
// column of the codebook.
|
||||||
auto codebook_size = &codebook_a_sizes.x;
|
auto codebook_size = &codebook_a_sizes.x;
|
||||||
while (a_gl_rd >= *codebook_size)
|
while (a_gl_rd >= *codebook_size) {
|
||||||
{
|
|
||||||
codebook += codebook_stride;
|
codebook += codebook_stride;
|
||||||
++codebook_size;
|
++codebook_size;
|
||||||
}
|
}
|
||||||
@@ -149,8 +137,7 @@ __global__ void Code2x8MatVec(
|
|||||||
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
|
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
|
||||||
int4 dec = codebook[i];
|
int4 dec = codebook[i];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < 8; j++)
|
for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec;
|
||||||
sh_code[8 * i + (j + lane) % 8] = dec;
|
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
@@ -161,8 +148,7 @@ __global__ void Code2x8MatVec(
|
|||||||
// We pad shared memory to avoid bank conflicts during reads
|
// We pad shared memory to avoid bank conflicts during reads
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
|
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
|
||||||
if (b_gl_rd + i < prob_k / 8)
|
if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
|
||||||
sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
|
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
b_gl_rd += 32 * 8;
|
b_gl_rd += 32 * 8;
|
||||||
@@ -172,8 +158,10 @@ __global__ void Code2x8MatVec(
|
|||||||
const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
|
const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < 8; i++) {
|
for (int i = 0; i < 8; i++) {
|
||||||
half2* a0 = reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
|
half2* a0 =
|
||||||
half2* a1 = reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
|
reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
|
||||||
|
half2* a1 =
|
||||||
|
reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
|
||||||
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
|
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
|
||||||
half2 res2 = {};
|
half2 res2 = {};
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
@@ -188,33 +176,28 @@ __global__ void Code2x8MatVec(
|
|||||||
|
|
||||||
if (pred) {
|
if (pred) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 16; i > 0; i /= 2)
|
for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i);
|
||||||
res += __shfl_down_sync(0xffffffff, res, i);
|
|
||||||
if (threadIdx.x % 32 == 0)
|
if (threadIdx.x % 32 == 0)
|
||||||
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
|
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
__global__ void Code1x16Dequant(
|
__global__ void Code1x16Dequant(
|
||||||
const int4* __restrict__ A,
|
const int4* __restrict__ A, int4* __restrict__ C,
|
||||||
int4* __restrict__ C,
|
const int4* __restrict__ codebook, int prob_m, int prob_k,
|
||||||
const int4* __restrict__ codebook,
|
const int4 codebook_a_sizes, // cumulative sizes of A spanning each
|
||||||
int prob_m,
|
// codebook, at most 3 long, sums to m.
|
||||||
int prob_k,
|
|
||||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, sums to m.
|
|
||||||
const int codebook_stride // as int4
|
const int codebook_stride // as int4
|
||||||
) {
|
) {
|
||||||
int a_gl_stride = prob_k / 8 / 8;
|
int a_gl_stride = prob_k / 8 / 8;
|
||||||
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||||
bool pred = a_gl_rd < prob_m;
|
bool pred = a_gl_rd < prob_m;
|
||||||
|
|
||||||
if (pred)
|
if (pred) {
|
||||||
{
|
// advance to the correct codebook, this easy because we only multiply one
|
||||||
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
|
// column of the codebook.
|
||||||
auto codebook_size = &codebook_a_sizes.x;
|
auto codebook_size = &codebook_a_sizes.x;
|
||||||
while (a_gl_rd >= *codebook_size)
|
while (a_gl_rd >= *codebook_size) {
|
||||||
{
|
|
||||||
codebook += codebook_stride;
|
codebook += codebook_stride;
|
||||||
++codebook_size;
|
++codebook_size;
|
||||||
}
|
}
|
||||||
@@ -235,13 +218,11 @@ __global__ void Code1x16Dequant(
|
|||||||
for (int i = 0; i < 8; i++) {
|
for (int i = 0; i < 8; i++) {
|
||||||
int4 chunk;
|
int4 chunk;
|
||||||
auto dec = reinterpret_cast<uint32_t*>(&chunk);
|
auto dec = reinterpret_cast<uint32_t*>(&chunk);
|
||||||
// We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't
|
// We bypass the L1 cache to avoid massive amounts of memory streaming
|
||||||
// actually help us; this brings > 2x speedup.
|
// that doesn't actually help us; this brings > 2x speedup.
|
||||||
asm volatile (
|
asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
|
||||||
"ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
|
|
||||||
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
|
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
|
||||||
: "l"((void*) &codebook[enc[i]])
|
: "l"((void*)&codebook[enc[i]]));
|
||||||
);
|
|
||||||
|
|
||||||
C[a_gl_rd * 8 + i] = chunk;
|
C[a_gl_rd * 8 + i] = chunk;
|
||||||
}
|
}
|
||||||
@@ -250,26 +231,23 @@ __global__ void Code1x16Dequant(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
__global__ void Code2x8Dequant(
    const int4* __restrict__ A, int4* __restrict__ C,
    const int4* __restrict__ codebook, int prob_m, int prob_k,
    const int4
        codebook_a_sizes,  // cumulative sizes of A spanning each codebook, at
                           // most 3 long, corresponds to cols.
    const int codebook_stride  // as int4
) {
  int a_gl_stride = prob_k / 8 / 8;
  int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
  bool pred = a_gl_rd < prob_m;

  if (pred) {
    // advance to the correct codebook; this is easy because we only multiply
    // one column of the codebook.
    auto codebook_size = &codebook_a_sizes.x;
    while (a_gl_rd >= *codebook_size) {
      codebook += codebook_stride;
      ++codebook_size;
    }

@@ -291,8 +269,7 @@ __global__ void Code2x8Dequant(
  for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
    int4 dec = codebook[i];
#pragma unroll
    for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec;
  }
  __syncthreads();

@@ -305,8 +282,10 @@ __global__ void Code2x8Dequant(
#pragma unroll
  for (int i = 0; i < 8; i++) {
    int4 chunk;
    half2* a0 =
        reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
    half2* a1 =
        reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
#pragma unroll
    for (int j = 0; j < 4; j++)
      reinterpret_cast<half2*>(&chunk)[j] = __hadd2(a0[j], a1[j]);

@@ -317,22 +296,15 @@ __global__ void Code2x8Dequant(
  }
}

inline int ceildiv(int a, int b) { return (a + b - 1) / b; }

const int THREAD_M = 16;

void code1x16_matvec_cuda(const void* __restrict__ A,
                          const void* __restrict__ B, void* __restrict__ C,
                          const void* __restrict__ codebook, int prob_m,
                          int prob_k, const int4 codebook_a_sizes,
                          const int codebook_stride) {
  int sms;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
  int waves = 0;

@@ -346,27 +318,15 @@ void code1x16_matvec_cuda(
  int threads = 32 * thread_m;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  Code1x16MatVec<<<blocks, threads, 16 * 32 * 9, stream>>>(
      (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m,
      prob_k, codebook_a_sizes, codebook_stride);
}

void code2x8_matvec_cuda(const void* __restrict__ A, const void* __restrict__ B,
                         void* __restrict__ C,
                         const void* __restrict__ codebook, int prob_m,
                         int prob_k, const int4 codebook_a_sizes,
                         const int codebook_stride) {
  int sms;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
  int waves = 0;

@@ -379,29 +339,19 @@ void code2x8_matvec_cuda(
  int blocks = ceildiv(prob_m, thread_m);
  int threads = 32 * thread_m;
  int shared = 16 * (2 * 256 * 8 + 32 * 9);
  cudaFuncSetAttribute(Code2x8MatVec,
                       cudaFuncAttributeMaxDynamicSharedMemorySize, shared);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  Code2x8MatVec<<<blocks, threads, shared, stream>>>(
      (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m,
      prob_k, codebook_a_sizes, codebook_stride);
}

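A quick note (not part of the diff) on why the cudaFuncSetAttribute call is needed: the launch requests 16 * (2 * 256 * 8 + 32 * 9) = 70,144 bytes of dynamic shared memory (two 256-entry codebooks replicated as eight int4 lines per entry to sidestep bank conflicts, as the sh_code fill loop above suggests, plus reduction scratch), which exceeds the default 48 KB per-block limit, so the kernel has to opt in before launch:

  int shared = 16 * (2 * 256 * 8 + 32 * 9);  // 70,144 bytes > 48 KB default
  cudaFuncSetAttribute(Code2x8MatVec,
                       cudaFuncAttributeMaxDynamicSharedMemorySize, shared);
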
void code1x16_dequant_cuda(
    const void* __restrict__ A, void* __restrict__ C,
    const void* __restrict__ codebook, int prob_m, int prob_k,
    const int4 codebook_a_sizes,  // cumulative sizes of A spanning each
                                  // codebook, at most 3 long.
    const int codebook_stride     // as int4.
) {
  int sms;

@@ -417,24 +367,20 @@ void code1x16_dequant_cuda(
  int threads = 32 * thread_m;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  Code1x16Dequant<<<blocks, threads, 0, stream>>>(
      (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k,
      codebook_a_sizes,  // cumulative sizes of A spanning each codebook, at
                         // most 3 long.
      codebook_stride    // as int4.
  );
}

// Dequantizes the code and codebook into weights.
void code2x8_dequant_cuda(
    const void* __restrict__ A, void* __restrict__ C,
    const void* __restrict__ codebook, int prob_m, int prob_k,
    const int4
        codebook_a_sizes,  // cumulative sizes of A spanning each codebook, at
                           // most 3 long, corresponds to cols.
    const int codebook_stride  // as int4
) {
  int sms;

@@ -451,50 +397,33 @@ void code2x8_dequant_cuda(
  int shared = 16 * (2 * 256 * 8 + 32 * 9);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

  cudaFuncSetAttribute(Code2x8Dequant,
                       cudaFuncAttributeMaxDynamicSharedMemorySize, shared);
  Code2x8Dequant<<<blocks, threads, shared, stream>>>(
      (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k,
      codebook_a_sizes, codebook_stride);
}

int codebook_stride(const torch::Tensor& codebooks) {
  return codebooks.stride(0) * codebooks.element_size() / sizeof(int4);
}

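A worked example (not part of the diff; the shape is illustrative): for a contiguous fp16 codebook tensor of shape (nbooks, 65536, 1, 8), the arithmetic in codebook_stride comes out as

  int64_t elems = codebooks.stride(0);               // 65536 * 8 = 524288 halfs
  int64_t bytes = elems * codebooks.element_size();  // * 2   = 1048576 bytes
  int stride = bytes / (int64_t)sizeof(int4);        // / 16  = 65536 int4

so `codebook += codebook_stride` in the kernels above advances exactly one codebook.
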
void code1x16_matvec(
    const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& C,
    const torch::Tensor& codebook,
    const int4 codebook_a_sizes  // cumulative sizes of A spanning each
                                 // codebook, at most 3 long.
) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
  int prob_m = C.size(0);
  int prob_k = B.size(0);

  code1x16_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(),
                       codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes,
                       codebook_stride(codebook));
}

torch::Tensor code1x16_matmat(const torch::Tensor& input,
                              const torch::Tensor& codes,
                              const torch::Tensor& codebooks,
                              const torch::Tensor& scales,
@@ -503,22 +432,15 @@ torch::Tensor code1x16_matmat(
  auto input_sizes = input.sizes();
  auto out_features = codes.size(0) * codebooks.size(2);
  auto flat_input = input.reshape({-1, input.size(-1)});
  auto flat_output = torch::empty(
      {flat_input.size(0), out_features},
      torch::TensorOptions().dtype(input.dtype()).device(input.device()));

  for (int i = 0; i < flat_input.size(0); ++i) {
    auto input_vec = flat_input.index({i});
    auto output_vec = flat_output.index({i});
    code1x16_matvec(codes.squeeze(2), input_vec, output_vec, codebooks,
                    codebook_a_sizes);
  }
  flat_output *= scales.flatten().unsqueeze(0);

@@ -533,55 +455,35 @@ torch::Tensor code1x16_matmat(
  return output;
}

void code2x8_matvec(const torch::Tensor& A, const torch::Tensor& B,
                    torch::Tensor& C, const torch::Tensor& codebook,
                    const int4 codebook_a_sizes) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
  int prob_m = C.size(0);
  int prob_k = B.size(0);
  code2x8_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(),
                      codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes,
                      2 * codebook_stride(codebook));
}

torch::Tensor code2x8_matmat(const torch::Tensor& input,
                             const torch::Tensor& codes,
                             const torch::Tensor& codebooks,
                             const torch::Tensor& scales,
                             const int4 codebook_a_sizes,
                             const std::optional<torch::Tensor>& bias) {
  auto input_sizes = input.sizes();
  auto out_features = codes.size(0) * codebooks.size(2);
  auto flat_input = input.reshape({-1, input.size(-1)});
  auto flat_output = torch::empty(
      {flat_input.size(0), out_features},
      torch::TensorOptions().dtype(input.dtype()).device(input.device()));

  for (int i = 0; i < flat_input.size(0); ++i) {
    auto input_vec = flat_input.index({i});
    auto output_vec = flat_output.index({i});
    code2x8_matvec(codes.squeeze(2), input_vec, output_vec, codebooks,
                   codebook_a_sizes);
  }
  flat_output *= scales.flatten().unsqueeze(0);
  if (bias.has_value()) {

@@ -596,21 +498,18 @@ torch::Tensor code2x8_matmat(
}

// Accumulate the partition sizes.
int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) {
  int4 cumulative_sizes;
  auto cumulative_size = &cumulative_sizes.x;
  int i = 0;
  int last = 0;
  assert(codebook_partition_sizes.size(0) <= 4);
  for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) {
    *cumulative_size = codebook_partition_sizes[i].item<int>() + last;
    last = *cumulative_size;
  }
  // fill in the rest with unreachable.
  for (; i < 4; ++i, ++cumulative_size) {
    *cumulative_size = last * 10;
  }
  return cumulative_sizes;
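A worked example (not part of the diff): with partition sizes {1024, 2048, 4096}, the first loop yields the cumulative sums {1024, 3072, 7168} and the filler pass sets the unused fourth slot to 7168 * 10 = 71680, a row index no thread can reach, so the codebook-advance loops in the kernels above always stop before walking off the end:

  // hypothetical partition tensor holding {1024, 2048, 4096}
  int4 sizes = accumulate_sizes(partitions);  // -> {1024, 3072, 7168, 71680}
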
@@ -619,41 +518,36 @@ int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes)
}  // namespace aqlm
}  // namespace vllm

torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                        const torch::Tensor& codebooks,
                        const torch::Tensor& scales,
                        const torch::Tensor& codebook_partition_sizes,
                        const std::optional<torch::Tensor>& bias) {
  int4 cumulative_sizes =
      vllm::aqlm::accumulate_sizes(codebook_partition_sizes);

  int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
  int const entries = codebooks.size(1);

  if (nbooks == 1 && entries == (1 << 16)) {
    return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales,
                                       cumulative_sizes, bias);
  }
  if (nbooks == 2 && entries == (1 << 8)) {
    return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales,
                                      cumulative_sizes, bias);
  }

  TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries,
              " entries is not currently supported.")
  return {};
}

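A usage sketch (not part of the diff; all shapes are illustrative assumptions): the dispatch keys on how many codebooks cover each weight partition (nbooks) and how many entries each codebook has, so AQLM "1x16" (one codebook, 2^16 entries) and "2x8" (two codebooks, 2^8 entries each) are the only supported layouts; anything else trips the TORCH_CHECK.

  // hypothetical single-partition 1x16 call: codebooks (1, 65536, 1, 8),
  // so nbooks == 1 and entries == 65536 selects code1x16_matmat
  auto out = aqlm_gemm(input, codes, codebooks, scales,
                       codebook_partition_sizes, std::nullopt);
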
torch::Tensor aqlm_dequant(const torch::Tensor& codes,
                           const torch::Tensor& codebooks,
                           const torch::Tensor& codebook_partition_sizes) {
  int4 cumulative_sizes =
      vllm::aqlm::accumulate_sizes(codebook_partition_sizes);

  int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
  int const entries = codebooks.size(1);

@@ -670,43 +564,35 @@ torch::Tensor aqlm_dequant(
  auto weights = torch::empty({out_features, in_features},
                              torch::TensorOptions()
                                  .dtype(codebooks.dtype())
                                  .device(codebooks.device()));

  if (nbooks == 1 && entries == (1 << 16)) {
    vllm::aqlm::code1x16_dequant_cuda(codes.data_ptr(), weights.data_ptr(),
                                      codebooks.data_ptr(), out_features,
                                      in_features, cumulative_sizes,
                                      vllm::aqlm::codebook_stride(codebooks));

    // If you wanted to flip to scaling the weights instead (though it's
    // 30%-ish slower and not consistent with the gemv implementation):
    // weights *= scales.index({"...", 0, 0});

    return weights;
  }

  if (nbooks == 2 && entries == (1 << 8)) {
    vllm::aqlm::code2x8_dequant_cuda(codes.data_ptr(), weights.data_ptr(),
                                     codebooks.data_ptr(), out_features,
                                     in_features, cumulative_sizes,
                                     vllm::aqlm::codebook_stride(codebooks));

    // If you wanted to flip to scaling the weights instead (though it's
    // 30%-ish slower and not consistent with the gemv implementation):
    // weights *= scales.index({"...", 0, 0});

    return weights;
  }

  TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries,
              " entries is not currently supported.")
  return {};
}

@@ -1,11 +1,11 @@
/*
Adapted from https://github.com/mit-han-lab/llm-awq
Modified from NVIDIA FasterTransformer:
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
@article{lin2023awq,
  title={AWQ: Activation-aware Weight Quantization for LLM Compression and
         Acceleration},
  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and
          Dang, Xingyu and Han, Song},
  journal={arXiv},
  year={2023}
}
*/

@@ -14,8 +14,7 @@ Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransfor
namespace vllm {
namespace awq {

__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
  assert(false);
#else
@@ -30,33 +29,40 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
  static constexpr uint32_t TOP_MASK = 0x00f000f0;
  static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

  // Note that the entire sequence only requires 1 shift instruction. This is
  // thanks to the register packing format and the fact that we force our
  // integers to be unsigned, and account for this in the fp16 subtractions. In
  // addition, I exploit the fact that sub and fma have the same throughput in
  // order to convert elt_23 and elt_67 to fp16 without having to shift them to
  // the bottom bits beforehand.

  // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
  // dependency if we issue immediately before required.
  const uint32_t top_i4s = i4s >> 8;
  // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[0])
               : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
                 "n"(immLut));
  // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[1])
               : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
                 "n"(immLut));
  // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[2])
               : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
                 "n"(immLut));
  // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[3])
               : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
                 "n"(immLut));

  // I use inline PTX below because I am not sure if the compiler will emit
  // float2half instructions if I use the half2 ctor. In this case, I chose
  // performance reliability over code readability.

  // This is the half2 {1032, 1032} represented as an integer.
  // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;

@@ -71,13 +77,21 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
  // Finally, we construct the output numbers.
  // Convert elt_01
  asm volatile("sub.f16x2 %0, %1, %2;\n"
               : "=r"(h[0])
               : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
  // Convert elt_23
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(h[1])
               : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
  // Convert elt_45
  asm volatile("sub.f16x2 %0, %1, %2;\n"
               : "=r"(h[2])
               : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
  // Convert elt_67
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(h[3])
               : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));

  return result;
#endif

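A short derivation (not part of the diff) of why these magic constants work: the fp16 bit pattern 0x6400 is 1024.0, whose mantissa ULP is exactly 1.0, so ORing a 4-bit value q into the low mantissa bits yields the half 1024 + q, and a nibble sitting in bits 4..7 yields 1024 + 16*q. The two instruction forms then recover q directly:

  (1024 + q)    - 1024           = q   // sub.f16x2 with a {1024, 1024} constant (0x64006400)
  (1024 + 16*q) * (1/16) + (-64) = q   // fma.rn.f16x2 with ONE_SIXTEENTH and NEG_64

which is why the top nibbles never need a shift before conversion.
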
@@ -1,15 +1,13 @@
/*
Adapted from https://github.com/mit-han-lab/llm-awq
@article{lin2023awq,
  title={AWQ: Activation-aware Weight Quantization for LLM Compression and
         Acceleration},
  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and
          Dang, Xingyu and Han, Song},
  journal={arXiv},
  year={2023}
}
*/

#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#include "dequantize.cuh"

@@ -20,26 +18,20 @@ namespace vllm {
namespace awq {

// Pack two half values.
static inline __device__ __host__ unsigned __pack_half2(const half x,
                                                        const half y) {
  unsigned v0 = *((unsigned short*)&x);
  unsigned v1 = *((unsigned short*)&y);
  return (v1 << 16) | v0;
}

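A tiny example (not part of the diff) of what __pack_half2 produces: it splices the raw bit patterns of two halves into one 32-bit register lane, low half first, which is the in-register layout of a half2.

  unsigned packed = __pack_half2(__float2half(1.0f), __float2half(2.0f));
  // half(1.0) is 0x3C00 and half(2.0) is 0x4000, so packed == 0x40003C00
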
template <int N>
__global__ void __launch_bounds__(64)
    gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters,
                                    half* __restrict__ A, int* __restrict__ B,
                                    half* __restrict__ scaling_factors,
                                    int* __restrict__ zeros, int M, int IC,
                                    int OC, half* __restrict__ C) {
  // Only support matrix n = 64 or 128
  assert(N == 64 || N == 128);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
@@ -70,43 +62,46 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
  static constexpr int row_stride = 2 * 32 * 8 / N;
  bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N;
  // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
  bool ld_A_flag =
      (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp +
       threadIdx.x * 8 / 32) < M;  // threadIdx.y is warp_id
  // bool wb_C_flag = (threadIdx.x / 4) < M;

  half* A_ptr =
      A +
      (((int)blockIdx_y) / j_factors1 * 16 +
       (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) *
          IC +
      (((int)threadIdx.x) % (32 / 8)) * 8;

  int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * (256 / N) +
               (((int)threadIdx.x) / (N / 8)) * (OC / 8) +
               (((int)blockIdx_y) % j_factors1) * (N / 8) +
               (((int)threadIdx.x) % (N / 8)) * 1;
  // Why * 1 in the above line?

  half* A_shared_ptr = A_shared +
                       ((int)threadIdx.y) * row_stride_warp * (32 + 8) +
                       (((int)threadIdx.x) / (32 / 8)) * (32 + 8) +
                       (((int)threadIdx.x) % (32 / 8)) * 8;

  half* B_shared_ptr = B_shared +
                       ((int)threadIdx.y) * (row_stride / 2) * (N + 8) +
                       (((int)threadIdx.x) / (N / 8)) * (N + 8) +
                       (((int)threadIdx.x) % (N / 8)) * 8;

  int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (N / 8) +
                   ((int)threadIdx.x) % (N / 8);

  half* scaling_factors_ptr = scaling_factors +
                              (((int)blockIdx_y) % j_factors1) * N +
                              (((int)threadIdx.x) % (N / 8)) * 8;

  half* C_ptr =
      C +
      static_cast<long long>(blockIdx_z) * M * OC  // blockIdz.x -> split_k dim
      + (((int)blockIdx_y) % j_factors1) * N + ((int)threadIdx.y) * (N / 2) +
      (((int)threadIdx.x) % 4) * 2;

  // preload s.f. and zeros
  int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;

@@ -115,57 +110,83 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
    int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
    __syncthreads();
    // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
    if (ld_A_flag) {
      *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
    } else {
      *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
    }

    // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
    uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
    uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
    uint4 B_loaded_scale =
        *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
    /*
    if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 &&
    threadIdx.y == 0){ printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x,
    B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x,
    B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
    }
    */
    // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
    int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);

    for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) {
      // B: 32 x 136 (128+8) float16
      // each warp: 32 x 4
      // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus
      // zero -> WB UINT4
      // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) *
      // 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15)
      // * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 *
      // 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) *
      // 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) *
      // 8)));
      // row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
      uint32_t B_loaded =
          *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
      uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
      // uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N /
      // 8)) * 8);
      // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x
      // % (cta_N / 8)) * 8);
      // - zero and * scale
      // TODO (Haotian): can save 4 assembly instructions if formulated as deq =
      // q * scale - zero * scale.
      asm volatile("sub.f16x2 %0, %1, %2;\n"
                   : "=r"(B_loaded_fp16.x)
                   : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
                   : "=r"(B_loaded_fp16.x)
                   : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
      asm volatile("sub.f16x2 %0, %1, %2;\n"
                   : "=r"(B_loaded_fp16.y)
                   : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
                   : "=r"(B_loaded_fp16.y)
                   : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
      asm volatile("sub.f16x2 %0, %1, %2;\n"
                   : "=r"(B_loaded_fp16.z)
                   : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
                   : "=r"(B_loaded_fp16.z)
                   : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
      asm volatile("sub.f16x2 %0, %1, %2;\n"
                   : "=r"(B_loaded_fp16.w)
                   : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
                   : "=r"(B_loaded_fp16.w)
                   : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
      /*
      if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 ==
      0 && threadIdx.x == 17 && threadIdx.y == 0){ printf("[x] %X %X %X %X\n",
      B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
      }
      */

      // write back
      *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) =
          B_loaded_fp16;
    }
    __syncthreads();

@@ -173,34 +194,43 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
    {
      unsigned int addr;
      __asm__ __volatile__(
          "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
          "addr; }\n"
          : "=r"(addr)
          : "l"((void*)((&(A_shared[(k_0_1 * 16)])) +
                        (((((int)threadIdx.x) & 15) * 40) +
                         ((((int)threadIdx.x) >> 4) * 8)))));

      __asm__ __volatile__(
          "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
          "{%0, %1, %2, %3}, [%4];\n"
          : "=r"(((unsigned*)(A_shared_warp + 0))[0]),
            "=r"(((unsigned*)(A_shared_warp + 0))[1]),
            "=r"(((unsigned*)(A_shared_warp + 0))[2]),
            "=r"(((unsigned*)(A_shared_warp + 0))[3])
          : "r"(addr));
    }

    for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
      {
        unsigned int addr;
        __asm__ __volatile__(
            "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
            "addr; }\n"
            : "=r"(addr)
            : "l"((void*)((&(B_shared[(((k_0_1 * (N * 16 + 128)) +
                                        (((int)threadIdx.y) * (N / 2))) +
                                       (ax1_0 * 16))])) +
                          (((((int)threadIdx.x) & 15) * (N + 8)) +
                           ((((int)threadIdx.x) >> 4) * 8)))));
        __asm__ __volatile__(
            "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
            "{%0, %1, %2, %3}, [%4];\n"
            : "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]),
              "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[1]),
              "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[2]),
              "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[3])
            : "r"(addr));
      }
    }
    for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {

@@ -209,48 +239,110 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
__asm__ __volatile__(
|
__asm__ __volatile__(
|
||||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
: "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
|
||||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
"=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
|
||||||
|
"=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
|
||||||
|
"=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
|
||||||
|
: "r"(((unsigned*)(A_shared_warp + 0))[0]),
|
||||||
|
"r"(((unsigned*)(A_shared_warp + 0))[1]),
|
||||||
|
"r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]),
|
||||||
|
"f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
|
||||||
|
"f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
|
||||||
|
"f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
|
||||||
|
"f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
__asm__ __volatile__(
|
__asm__ __volatile__(
|
||||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
: "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||||
|
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
|
||||||
|
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||||
|
: "r"(((unsigned*)(A_shared_warp + 0))[0]),
|
||||||
|
"r"(((unsigned*)(A_shared_warp + 0))[1]),
|
||||||
|
"r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||||
|
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||||
|
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||||
|
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
|
||||||
|
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
__asm__ __volatile__(
|
__asm__ __volatile__(
|
||||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
: "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
|
||||||
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
"=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
|
||||||
|
"=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
|
||||||
|
"=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
|
||||||
|
: "r"(((unsigned*)(A_shared_warp + 0))[2]),
|
||||||
|
"r"(((unsigned*)(A_shared_warp + 0))[3]),
|
||||||
|
"r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]),
|
||||||
|
"f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
|
||||||
|
"f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
|
||||||
|
"f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
|
||||||
|
"f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
__asm__ __volatile__(
|
__asm__ __volatile__(
|
||||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
: "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||||
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||||
|
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
|
||||||
|
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||||
|
: "r"(((unsigned*)(A_shared_warp + 0))[2]),
|
||||||
|
"r"(((unsigned*)(A_shared_warp + 0))[3]),
|
||||||
|
"r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||||
|
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||||
|
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||||
|
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
|
||||||
|
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
{
|
{
|
||||||
__asm__ __volatile__(
|
__asm__ __volatile__(
|
||||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
|
||||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
"%13};\n"
|
||||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
: "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
|
||||||
|
"=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
|
||||||
|
"=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
|
||||||
|
"=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
|
||||||
|
: "r"(((unsigned*)(A_shared_warp + 0))[0]),
|
||||||
|
"r"(((unsigned*)(A_shared_warp + 0))[1]),
|
||||||
|
"r"(((unsigned*)(A_shared_warp + 0))[2]),
|
||||||
|
"r"(((unsigned*)(A_shared_warp + 0))[3]),
|
||||||
|
"r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]),
|
||||||
|
"r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]),
|
||||||
|
"f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
|
||||||
|
"f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
|
||||||
|
"f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
|
||||||
|
"f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
__asm__ __volatile__(
|
__asm__ __volatile__(
|
||||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
|
||||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
"%13};\n"
|
||||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
: "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||||
|
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||||
|
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
|
||||||
|
"=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||||
|
: "r"(((unsigned*)(A_shared_warp + 0))[0]),
|
||||||
|
"r"(((unsigned*)(A_shared_warp + 0))[1]),
|
||||||
|
"r"(((unsigned*)(A_shared_warp + 0))[2]),
|
||||||
|
"r"(((unsigned*)(A_shared_warp + 0))[3]),
|
||||||
|
"r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||||
|
"r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||||
|
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
|
||||||
|
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
|
||||||
|
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
|
||||||
|
"f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
@@ -261,24 +353,20 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
  // TODO: Shang: Hoist loop invariance.
  for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
    for (int local_id = 0; local_id < 8; ++local_id) {
      int row_offset = (((int)blockIdx_y) / j_factors1) * 16 +
                       ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
      if (row_offset < M) {
        *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 +
          local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
      }
    }
  }
#endif
}

__global__ void __launch_bounds__(64)
    dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors,
                       int* __restrict__ zeros, half* __restrict__ C, int G) {
  int j_factors1 = 4;
  int row_stride2 = 4;
  int split_k_iters = 1;
@@ -310,14 +398,30 @@ __global__ void __launch_bounds__(64) dequantize_weights(

  uint32_t B_loaded = *(uint32_t*)B_ptr2;
  uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
  asm volatile("sub.f16x2 %0, %1, %2;\n"
               : "=r"(B_loaded_fp16.x)
               : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(B_loaded_fp16.x)
               : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
  asm volatile("sub.f16x2 %0, %1, %2;\n"
               : "=r"(B_loaded_fp16.y)
               : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(B_loaded_fp16.y)
               : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
  asm volatile("sub.f16x2 %0, %1, %2;\n"
               : "=r"(B_loaded_fp16.z)
               : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(B_loaded_fp16.z)
               : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
  asm volatile("sub.f16x2 %0, %1, %2;\n"
               : "=r"(B_loaded_fp16.w)
               : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(B_loaded_fp16.w)
               : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));

  *(uint4*)B_shared_ptr2 = B_loaded_fp16;
@@ -329,14 +433,10 @@ __global__ void __launch_bounds__(64) dequantize_weights(
}  // namespace awq
}  // namespace vllm

torch::Tensor awq_dequantize(torch::Tensor _kernel,
                             torch::Tensor _scaling_factors,
                             torch::Tensor _zeros, int64_t split_k_iters,
                             int64_t thx, int64_t thy) {
  int in_c = _kernel.size(0);
  int qout_c = _kernel.size(1);
  int out_c = qout_c * 8;
@@ -362,12 +462,15 @@ torch::Tensor awq_dequantize(

  const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));

  auto options = torch::TensorOptions()
                     .dtype(_scaling_factors.dtype())
                     .device(_scaling_factors.device());
  at::Tensor _de_kernel = torch::empty({in_c, out_c}, options);

  auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
  auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
  auto scaling_factors =
      reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
  auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());

  dim3 num_blocks(x_blocks, y_blocks);
@@ -386,26 +489,26 @@ torch::Tensor awq_dequantize(
  // zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
  // assume that batch_size < 16 for now

torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                       torch::Tensor _scaling_factors, torch::Tensor _zeros,
                       int64_t split_k_iters) {
  int num_in_feats = _in_feats.size(0);
  int num_in_channels = _in_feats.size(1);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));

  auto options = torch::TensorOptions()
                     .dtype(_in_feats.dtype())
                     .device(_in_feats.device());
  at::Tensor _out_feats =
      torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
  int num_out_feats = _out_feats.size(-2);
  int num_out_channels = _out_feats.size(-1);

  auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
  auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
  auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
  auto scaling_factors =
      reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
  auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
  int group_size = num_in_channels / _scaling_factors.size(0);

@@ -419,28 +522,28 @@ torch::Tensor awq_gemm(
    throw std::invalid_argument("OC is not multiple of Group size");

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (num_out_channels % 128 == 0) {
    int j_factors1 = num_out_channels / 128 / 1;
    dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
    // threadIdx.x: 32
    // threadIdx.y: i_factors[2] * j_factors[2]
    dim3 threads_per_block(32, 2);
    vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128>
        <<<num_blocks, threads_per_block, 0, stream>>>(
            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
            num_in_feats, num_in_channels, num_out_channels, out_feats);
  } else if (num_out_channels % 64 == 0) {
    int j_factors1 = num_out_channels / 64 / 1;
    dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 *
                    split_k_iters);

    // threadIdx.x: 32
    // threadIdx.y: i_factors[2] * j_factors[2]
    dim3 threads_per_block(32, 2);
    vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64>
        <<<num_blocks, threads_per_block, 0, stream>>>(
            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
            num_in_feats, num_in_channels, num_out_channels, out_feats);
  }
  return _out_feats.sum(0);
}
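One detail worth calling out in awq_gemm above: with split-K enabled, each grid slice writes its partial GEMM result into its own [num_in_feats, num_out_channels] slice of _out_feats, and the final _out_feats.sum(0) reduces over the split-K dimension. A minimal host-side C++ sketch of that reduction (the partials layout and function name are illustrative only; the real reduction runs inside PyTorch on the GPU):

#include <vector>

std::vector<float> reduce_split_k(const std::vector<float>& partials,
                                  int split_k_iters, int out_elems) {
  std::vector<float> out(out_elems, 0.0f);
  for (int s = 0; s < split_k_iters; ++s) {    // one slice per split-K chunk
    for (int i = 0; i < out_elems; ++i) {
      out[i] += partials[s * out_elems + i];   // elementwise sum over slices
    }
  }
  return out;
}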
csrc/quantization/compressed_tensors/int8_quant_kernels.cu (new file, 115 lines)
@@ -0,0 +1,115 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <cmath>

#include "../../dispatch_utils.h"
#include "../../reduction_utils.cuh"

static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
  static const float i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  static const float i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());
  // round
  float dst = std::nearbyint(x);
  // saturate
  dst = std::clamp(dst, i8_min, i8_max);
  return static_cast<int8_t>(dst);
#else
  // CUDA path
  uint32_t dst;
  asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
  return reinterpret_cast<const int8_t&>(dst);
#endif
}

namespace vllm {

template <typename scalar_t, typename scale_type>
__global__ void static_scaled_int8_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type const* scale_ptr, const int hidden_size) {
  int const tid = threadIdx.x;
  int const token_idx = blockIdx.x;
  scale_type const scale = *scale_ptr;

  for (int i = tid; i < hidden_size; i += blockDim.x) {
    out[token_idx * hidden_size + i] = float_to_int8_rn(
        static_cast<float>(input[token_idx * hidden_size + i]) / scale);
  }
}

template <typename scalar_t, typename scale_type>
__global__ void dynamic_scaled_int8_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type* scale, const int hidden_size) {
  int const tid = threadIdx.x;
  int const token_idx = blockIdx.x;
  float absmax_val = 0.0f;
  float const zero = 0.0f;

  for (int i = tid; i < hidden_size; i += blockDim.x) {
    float val = static_cast<float>(input[token_idx * hidden_size + i]);
    val = val > zero ? val : -val;
    absmax_val = val > absmax_val ? val : absmax_val;
  }

  float const block_absmax_val_maybe = blockReduceMax(absmax_val);
  __shared__ float block_absmax_val;
  if (tid == 0) {
    block_absmax_val = block_absmax_val_maybe;
    scale[token_idx] = block_absmax_val / 127.0f;
  }
  __syncthreads();

  float const tmp_scale = 127.0f / block_absmax_val;
  for (int i = tid; i < hidden_size; i += blockDim.x) {
    out[token_idx * hidden_size + i] = float_to_int8_rn(
        static_cast<float>(input[token_idx * hidden_size + i]) * tmp_scale);
  }
}

}  // namespace vllm

void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
                              torch::Tensor const& input,  // [..., hidden_size]
                              torch::Tensor const& scale) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(scale.numel() == 1);

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
        vllm::static_scaled_int8_quant_kernel<scalar_t, float>
            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
                                         out.data_ptr<int8_t>(),
                                         scale.data_ptr<float>(), hidden_size);
      });
}

void dynamic_scaled_int8_quant(
    torch::Tensor& out,          // [..., hidden_size]
    torch::Tensor const& input,  // [..., hidden_size]
    torch::Tensor& scales) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
        vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
                                         out.data_ptr<int8_t>(),
                                         scales.data_ptr<float>(), hidden_size);
      });
}
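For reference, the per-token dynamic quantization above computes scale = absmax / 127 and then rounds x / scale to the nearest int8 with saturation. A single-threaded scalar sketch of the same math (illustrative only, not vLLM code; like the kernel, it assumes the token has at least one nonzero element):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

void quantize_token_reference(const std::vector<float>& x,
                              std::vector<int8_t>& q, float& scale) {
  float absmax = 0.0f;
  for (float v : x) absmax = std::max(absmax, std::fabs(v));
  scale = absmax / 127.0f;  // stored per token, used later to dequantize
  q.resize(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    float r = std::nearbyint(x[i] * (127.0f / absmax));           // round
    q[i] = static_cast<int8_t>(std::clamp(r, -128.0f, 127.0f));   // saturate
  }
}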
csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp (new file, 346 lines)
@@ -0,0 +1,346 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/visitor_load.hpp from
// https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either
// row/column or scalar broadcasting where the tensor being loaded from is
// always passed in via a device pointer. This lets one compiled kernel handle
// all cases of per-tensor or per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graph
// breaks when moving scales to the CPU.
//
#pragma once

// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off

#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
#include "cute/tensor.hpp"

namespace cutlass::epilogue::threadblock {

using namespace cute;
using namespace detail;

template<
  class ThreadMap,
  class Element,
  class StrideMNL
>
struct VisitorRowOrScalarBroadcast {

  // This struct has been modified to have a bool indicating that ptr_row is a
  // scalar that must be broadcast.
  struct Arguments {
    Element const* ptr_row = nullptr;
    bool row_broadcast = true;
    StrideMNL dRow = {};
  };

  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  struct SharedStorage {};

  // Global load type
  static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
  using VecType = uint_bit_t<cute::min(128, vec_bits)>;
  static int constexpr VecLength = sizeof(VecType) / sizeof(Element);

  CUTLASS_HOST_DEVICE
  VisitorRowOrScalarBroadcast() { }

  CUTLASS_HOST_DEVICE
  VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
    : params_ptr(&params) { }

  Params const* params_ptr;

  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
  struct Callbacks : EmptyCallbacks {
    CUTLASS_DEVICE
    Callbacks(
      GTensor&& tC_gRow,
      RTensor&& tC_rRow,
      CTensor&& tC_cRow,
      ProblemShape problem_shape,
      Params const* params_ptr
    ):
      tC_gRow(cute::forward<GTensor>(tC_gRow)),
      tC_rRow(cute::forward<RTensor>(tC_rRow)),
      tC_cRow(cute::forward<CTensor>(tC_cRow)),
      n(get<1>(problem_shape)),
      params_ptr(params_ptr) { }

    GTensor tC_gRow;
    RTensor tC_rRow;
    CTensor tC_cRow;
    Params const* params_ptr;
    int n;

    // This function is modified from VisitorRowBroadcast
    CUTLASS_DEVICE void
    begin_epilogue() {
      clear(tC_rRow);
      auto src_v = filter(tC_gRow);
      auto coord_v = filter(tC_cRow);
      auto dst_v = filter(tC_rRow);

      if (params_ptr->row_broadcast) {
        // In this case we are loading from a row vector and broadcasting
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(src_v); ++i) {
          bool guard = get<1>(coord_v(i)) < n;
          cutlass::arch::global_load<VecType, sizeof(VecType)>(
              dst_v(i), (void const*)&src_v(i), guard);
        }
      } else {
        // In this case we are loading from a scalar and broadcasting
        VecType filled_vec;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < VecLength; i++) {
          reinterpret_cast<Element*>(&filled_vec)[i] = *(params_ptr->ptr_row);
        }

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(src_v); ++i) {
          if (get<1>(coord_v(i)) < n) {
            dst_v(i) = filled_vec;
          }
        }
      }
    }

    template <class ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE auto // returns an Array
    visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
          Array<ElementAccumulator, FragmentSize> const& frg_acc) {
      Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
      return rRow_frg(column_idx);
    }
  };

  template <class ProblemShape>
  CUTLASS_DEVICE auto
  get_callbacks(
    gemm::GemmCoord threadblock_tile_offset,
    int thread_idx,
    ProblemShape problem_shape
  ) {
    Tensor mRow = make_tensor(
      make_gmem_ptr(params_ptr->ptr_row),
      problem_shape,
      params_ptr->dRow);

    // VECTOR, FRAGMENT_COLUMN
    Tensor tC_gRow = recast<VecType>(
      ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
    )(_,_,_0{},_0{},_0{},_0{});
    Tensor tC_rRow = make_tensor_like(tC_gRow);

    // Generate the pred tensor
    Tensor cRow = make_identity_tensor(mRow.shape());
    Tensor tC_cRow = outer_partition(
      ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
      Shape<Int<VecLength>>{},
      (_0{})
    );

    return Callbacks<
      decltype(tC_gRow), decltype(tC_rRow),
      decltype(tC_cRow), ProblemShape>(
      cute::move(tC_gRow),
      cute::move(tC_rRow),
      cute::move(tC_cRow),
      problem_shape,
      params_ptr
    );
  }

};

/////////////////////////////////////////////////////////////////////////////////////////////////

// Column vector broadcast
template<
  class ThreadMap,
  class Element,
  class StrideMNL = Stride<_1,_0,_0>
>
struct VisitorColOrScalarBroadcast {

  // This struct has been modified to have a bool indicating that ptr_col is a
  // scalar that must be broadcast.
  struct Arguments {
    Element const* ptr_col = nullptr;
    bool col_broadcast = true;
    StrideMNL dCol = {};
  };

  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  struct SharedStorage { };

  CUTLASS_HOST_DEVICE
  VisitorColOrScalarBroadcast() { }

  CUTLASS_HOST_DEVICE
  VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
    : params_ptr(&params) { }

  Params const* params_ptr;

  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
  struct Callbacks : EmptyCallbacks {
    CUTLASS_DEVICE
    Callbacks(
      GTensor&& tC_gCol,
      RTensor&& tC_rCol,
      CTensor&& tC_cCol,
      ProblemShape problem_shape,
      Params const* params_ptr
    ):
      tC_gCol(cute::forward<GTensor>(tC_gCol)),
      tC_rCol(cute::forward<RTensor>(tC_rCol)),
      tC_cCol(cute::forward<CTensor>(tC_cCol)),
      m(get<0>(problem_shape)),
      params_ptr(params_ptr) { }

    GTensor tC_gCol;
    RTensor tC_rCol;
    CTensor tC_cCol;
    Params const* params_ptr;
    int m;

    // This function is modified from VisitorColBroadcast
    CUTLASS_DEVICE void
    begin_epilogue() {
      clear(tC_rCol);

      Tensor pred = make_tensor<bool>(shape(tC_gCol));
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(pred); ++i) {
        pred(i) = get<0>(tC_cCol(i)) < m;
      }

      if (params_ptr->col_broadcast) {
        // In this case we are loading from a column vector and broadcasting
        copy_if(pred, tC_gCol, tC_rCol);
      } else {
        // In this case we are loading from a scalar and broadcasting
        auto dst_v = filter(tC_rCol);

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(dst_v); ++i) {
          if (pred(i)) {
            dst_v(i) = *(params_ptr->ptr_col);
          }
        }
      }
    }

    template <class ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE auto // returns an Array
    visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
          Array<ElementAccumulator, FragmentSize> const& frg_acc) {
      Array<Element, FragmentSize> frg_col;
      frg_col.fill(tC_rCol(row_idx,iter_idx));
      return frg_col;
    }
  };

  template <class ProblemShape>
  CUTLASS_DEVICE auto
  get_callbacks(
    gemm::GemmCoord threadblock_tile_offset,
    int thread_idx,
    ProblemShape problem_shape
  ) {
    Tensor mCol = make_tensor(
      make_gmem_ptr(params_ptr->ptr_col),
      problem_shape,
      params_ptr->dCol);

    // VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER
    Tensor tC_gCol = group_modes<1,4>(
      ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));
    Tensor tC_rCol = make_tensor_like(tC_gCol);

    // Generate the pred tensor
    Tensor cCol = make_identity_tensor(mCol.shape());
    Tensor tC_cCol = group_modes<1,4>(
      ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));

    return Callbacks<
      decltype(tC_gCol), decltype(tC_rCol),
      decltype(tC_cCol), ProblemShape>(
      cute::move(tC_gCol),
      cute::move(tC_rCol),
      cute::move(tC_cCol),
      problem_shape,
      params_ptr
    );
  }
};

}
csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp (new file, 389 lines)
@@ -0,0 +1,389 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
// from https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either row/column or scalar broadcasting
// where the tensor being loaded from is always passed in via a device pointer.
// This lets one compiled kernel handle all cases of per-tensor or
// per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graph
// breaks when moving scales to the CPU.
//
#pragma once

// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off

#include "cutlass/cutlass.h"
#include "cutlass/arch/barrier.h"

#include "cute/tensor.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"

namespace cutlass::epilogue::fusion {

using namespace cute;
using namespace detail;

// Row vector broadcast
template<
  // Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least
  // ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races
  int Stages,
  class CtaTileShapeMNK,
  class Element,
  class StrideMNL = Stride<_0,_1,_0>,
  int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90RowOrScalarBroadcast {
  static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
  static_assert(
    (cute::is_same_v<StrideMNL, Stride<_0,_1, _0>>) || // row vector broadcast, e.g. per-col alpha/bias
    (cute::is_same_v<StrideMNL, Stride<_0,_1,int>>)); // batched row vector broadcast

  // Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem
  struct SharedStorage {
    alignas(16) array_aligned<Element, size<1>(CtaTileShapeMNK{}) * Stages> smem_row;
  };

  // This struct has been modified to have a bool indicating that ptr_row is a
  // scalar that must be broadcast, instead of containing a scalar that is
  // valid if ptr_row is null.
  struct Arguments {
    Element const* ptr_row = nullptr;
    bool row_broadcast = true;
    StrideMNL dRow = {};
  };

  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  template <class ProblemShape>
  static cutlass::Status
  initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
    CudaHostAdapter* cuda_adapter = nullptr) {
    return cutlass::Status::kSuccess;
  }

  CUTLASS_HOST_DEVICE
  Sm90RowOrScalarBroadcast() { }

  CUTLASS_HOST_DEVICE
  Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
    : params(params),
      smem_row(const_cast<Element*>(shared_storage.smem_row.data())) { }

  Params params;
  Element* smem_row;

  CUTLASS_DEVICE bool
  is_producer_load_needed() const {
    return true;
  }

  CUTLASS_DEVICE bool
  is_C_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_zero() const {
    return (!params.row_broadcast && *(params.ptr_row) == Element(0));
  }

  template <int EpiTiles, class GTensor, class STensor>
  struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks {
    CUTLASS_DEVICE
    ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params)
      : gRow(cute::forward<GTensor>(gRow)),
        sRow(cute::forward<STensor>(sRow)),
        params(params) {}

    GTensor gRow; // (CTA_M,CTA_N)
    STensor sRow; // (CTA_M,CTA_N,PIPE)
    Params const& params;

    CUTLASS_DEVICE void
    begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) {
      if (params.ptr_row == nullptr) {
        return;
      }

      if (issue_tma_load) {
        // Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size
        constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v<Element> / 8;
        cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes);
        // Issue the TMA bulk copy
        auto bulk_copy = Copy_Atom<SM90_BULK_COPY_AUTO, Element>{}.with(*full_mbarrier_ptr);
        // Filter so we don't issue redundant copies over stride-0 modes
        int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
        copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index)));
      }
    }
  };

  template <class... Args>
  CUTLASS_DEVICE auto
  get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {

    auto [M, N, K, L] = args.problem_shape_mnkl;
    auto [m, n, k, l] = args.tile_coord_mnkl;
    Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
    Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N)
    Tensor sRow = make_tensor(make_smem_ptr(smem_row),                                 // (CTA_M,CTA_N,PIPE)
        make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
        make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));

    constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
    return ProducerLoadCallbacks<EpiTiles, decltype(gRow), decltype(sRow)>(
      cute::move(gRow), cute::move(sRow), params);
  }

  template <int EpiTiles, class RTensor, class STensor>
  struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
    CUTLASS_DEVICE
    ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params)
      : tCrRow(cute::forward<RTensor>(tCrRow)),
        tCsRow(cute::forward<STensor>(tCsRow)),
        params(params) {}

    RTensor tCrRow; // (CPY,CPY_M,CPY_N)
    STensor tCsRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
    Params const& params;

    CUTLASS_DEVICE void
    previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) {
      if (!params.row_broadcast) {
        fill(tCrRow, *(params.ptr_row));
        return;
      }

      if (epi_m == 0) { // Assumes M-major subtile loop
        // Filter so we don't issue redundant copies over stride-0 modes
        // (only works if 0-strides are in same location, which is by construction)
        int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
        copy_aligned(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow));
      }
    }

    template <typename ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE Array<Element, FragmentSize>
    visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
      Array<Element, FragmentSize> frg_row;

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < FragmentSize; ++i) {
        frg_row[i] = tCrRow(epi_v * FragmentSize + i);
      }

      return frg_row;
    }
  };

  template <
    bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
    class... Args
  >
  CUTLASS_DEVICE auto
  get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {

    Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE)
        make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
        make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
    Tensor tCsRow = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
        sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
    Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow)); // (CPY,CPY_M,CPY_N)

    constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
    return ConsumerStoreCallbacks<EpiTiles, decltype(tCrRow), decltype(tCsRow)>(
      cute::move(tCrRow), cute::move(tCsRow), params);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

// Column vector broadcast
template<
  int Stages,
  class CtaTileShapeMNK,
  class Element,
  class StrideMNL = Stride<_1,_0,_0>,
  int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90ColOrScalarBroadcast {
  static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
  static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
  static_assert(
    (cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) || // col vector broadcast, e.g. per-row alpha/bias
    (cute::is_same_v<StrideMNL, Stride<_1,_0,int>>)); // batched col vector broadcast, e.g. batched per-row bias

  // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
  struct SharedStorage { };

  // This struct has been modified to have a bool indicating that ptr_col is a
  // scalar that must be broadcast, instead of containing a scalar that is
  // valid if ptr_col is null.
  struct Arguments {
    Element const* ptr_col = nullptr;
    bool col_broadcast = true;
    StrideMNL dCol = {};
  };

  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  template <class ProblemShape>
  static cutlass::Status
  initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
    CudaHostAdapter* cuda_adapter = nullptr) {
    return cutlass::Status::kSuccess;
  }

  CUTLASS_DEVICE bool
  is_producer_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_C_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_zero() const {
    return (!params.col_broadcast && *(params.ptr_col) == Element(0));
  }

  CUTLASS_HOST_DEVICE
  Sm90ColOrScalarBroadcast() { }

  CUTLASS_HOST_DEVICE
  Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
    : params(params) { }

  Params params;

  template <class... Args>
  CUTLASS_DEVICE auto
  get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
    return EmptyProducerLoadCallbacks{};
  }

  template<class GTensor, class RTensor>
  struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
    CUTLASS_DEVICE
    ConsumerStoreCallbacks(GTensor&& tCgCol, RTensor&& tCrCol, Params const& params)
      : tCgCol(cute::forward<GTensor>(tCgCol)),
        tCrCol(cute::forward<RTensor>(tCrCol)),
        params(params) {}

    GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
    RTensor tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
    Params const& params;

    CUTLASS_DEVICE void
    begin() {
      if (!params.col_broadcast) {
        fill(tCrCol, *(params.ptr_col));
        return;
      }

      // Filter so we don't issue redundant copies over stride-0 modes
      // (only works if 0-strides are in same location, which is by construction)
      copy_aligned(filter(tCgCol), filter(tCrCol));
    }

    template <typename ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE Array<Element, FragmentSize>
    visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
      Array<Element, FragmentSize> frg_col;
      Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < FragmentSize; ++i) {
        frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
      }

      return frg_col;
    }

  };

  template <
    bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
    class... Args
  >
  CUTLASS_DEVICE auto
  get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {

    auto [M, N, K, L] = args.problem_shape_mnkl;
    Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
    Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
      mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
    Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)

    return ConsumerStoreCallbacks<decltype(tCgCol), decltype(tCrCol)>(
      cute::move(tCgCol), cute::move(tCrCol), params);
  }
};

}
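A note on how these visitors are meant to be driven: the caller selects broadcast vs. scalar behavior with the bool in Arguments, not with a null pointer, so one compiled kernel covers per-tensor and per-channel/per-token scales. A hedged sketch of building such arguments for the column broadcast above (CtaTileShape and the device pointer are assumed for illustration; the numel != 1 test mirrors the CUTLASS 2.x dispatcher later in this diff):

#include <cstdint>
#include "broadcast_load_epilogue_c3x.hpp"

using CtaTileShape = cute::Shape<cute::_128, cute::_128, cute::_64>;  // assumed CTA tile
using ColScale =
    cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<0, CtaTileShape, float>;

// One scale element means a per-tensor scalar; more means a per-row vector.
ColScale::Arguments make_a_scale_args(float const* d_scales, int64_t numel) {
  return ColScale::Arguments{d_scales, /*col_broadcast=*/numel != 1, {}};
}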
csrc/quantization/cutlass_w8a8/common.hpp (new file, 12 lines)
@@ -0,0 +1,12 @@
#pragma once

#include "cutlass/cutlass.h"

/**
 * Helper function for checking CUTLASS errors
 */
#define CUTLASS_CHECK(status)                        \
  {                                                  \
    TORCH_CHECK(status == cutlass::Status::kSuccess, \
                cutlassGetStatusString(status))      \
  }
329
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
Normal file
329
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
Normal file
@@ -0,0 +1,329 @@
|
|||||||
|
#include <stddef.h>
|
||||||
|
#include <torch/all.h>
|
||||||
|
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
|
||||||
|
// clang-format will break include orders
|
||||||
|
// clang-format off
|
||||||
|
#include "cute/tensor.hpp"
|
||||||
|
#include "cute/atom/mma_atom.hpp"
|
||||||
|
#include "cutlass/numeric_types.h"
|
||||||
|
|
||||||
|
#include "cutlass/util/device_memory.h"
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cutlass/gemm_coord.h"
|
||||||
|
#include "cutlass/arch/mma_sm75.h"
|
||||||
|
#include "cutlass/arch/arch.h"
|
||||||
|
#include "cutlass/arch/mma.h"
|
||||||
|
#include "cutlass/gemm/device/gemm.h"
|
||||||
|
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||||
|
|
||||||
|
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
|
||||||
|
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
|
||||||
|
|
||||||
|
#include "broadcast_load_epilogue_c2x.hpp"
|
||||||
|
#include "common.hpp"
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
using namespace cute;
|
||||||
|
|
||||||
|
/*
|
||||||
|
This defines a quantized GEMM operation with dequantized output, similar to
|
||||||
|
torch._scaled_mm. It is defined using the CUTLASS 2.x API, and is used for
|
||||||
|
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
|
||||||
|
|
||||||
|
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
|
||||||
|
per-row. B can be quantized per-tensor or per-column.
|
||||||
|
Any combination of per-tensor and per-row or column is supported.
|
||||||
|
A and B must have symmetric quantization (zero point == 0).
|
||||||
|
|
||||||
|
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
|
||||||
|
scales are applied elementwise with numpy-style broadcasting.
|
||||||
|
|
||||||
|
ScaleA and ScaleB define the epilogue functions that apply the scales for
|
||||||
|
the A and B operands respectively. These scales may be either per-tensor or
|
||||||
|
per row or column.
|
||||||
|
*/
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Wrappers for the GEMM kernel that is used to guard against compilation on
|
||||||
|
// architectures that will never use the kernel. The purpose of this is to
|
||||||
|
// reduce the size of the compiled binary.
|
||||||
|
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
|
||||||
|
// into code that will be executed on the device where it is defined.
|
||||||
|
template <typename Kernel>
|
||||||
|
struct enable_sm75_to_sm80 : Kernel {
|
||||||
|
template <typename... Args>
|
||||||
|
CUTLASS_DEVICE static void invoke(Args&&... args) {
|
||||||
|
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
|
||||||
|
Kernel::invoke(std::forward<Args>(args)...);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Kernel>
|
||||||
|
struct enable_sm80_to_sm89 : Kernel {
|
||||||
|
template <typename... Args>
|
||||||
|
CUTLASS_DEVICE static void invoke(Args&&... args) {
|
||||||
|
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
|
||||||
|
Kernel::invoke(std::forward<Args>(args)...);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Kernel>
|
||||||
|
struct enable_sm89_to_sm90 : Kernel {
|
||||||
|
template <typename... Args>
|
||||||
|
CUTLASS_DEVICE static void invoke(Args&&... args) {
|
||||||
|
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
|
||||||
|
Kernel::invoke(std::forward<Args>(args)...);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Arch, template <typename> typename ArchGuard,
|
||||||
|
typename ElementAB_, typename ElementD_, typename TileShape,
|
||||||
|
typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
|
||||||
|
struct cutlass_2x_gemm {
|
||||||
|
using ElementAB = ElementAB_;
|
||||||
|
using ElementD = ElementD_;
|
||||||
|
|
||||||
|
using ElementAcc =
|
||||||
|
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||||
|
float>::type;
|
||||||
|
|
||||||
|
using Operator =
|
||||||
|
typename std::conditional<std::is_same_v<ElementAB, int8_t>,
|
||||||
|
cutlass::arch::OpMultiplyAddSaturate,
|
||||||
|
cutlass::arch::OpMultiplyAdd>::type;
|
||||||
|
|
||||||
|
using OutputTileThreadMap =
|
||||||
|
cutlass::epilogue::threadblock::OutputTileThreadLayout<
|
||||||
|
TileShape, WarpShape, float, 4, 1 /* epilogue stages */
|
||||||
|
>;
|
||||||
|
|
||||||
|
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
|
||||||
|
|
||||||
|
using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
|
||||||
|
OutputTileThreadMap, float, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||||
|
|
||||||
|
using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
|
||||||
|
OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||||
|
|
||||||
|
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||||
|
cutlass::multiplies, float, float,
|
||||||
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
|
||||||
|
using EVTCompute0 =
|
||||||
|
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
|
||||||
|
|
||||||
|
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||||
|
cutlass::multiplies, ElementD, float,
|
||||||
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
|
||||||
|
using EVTCompute1 =
|
||||||
|
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
|
||||||
|
|
||||||
|
using D = cutlass::epilogue::threadblock::VisitorAuxStore<
|
||||||
|
OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
|
||||||
|
Stride<int64_t, Int<1>, Int<0>>>;
|
||||||
|
|
||||||
|
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute1>;
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
using RowMajor = typename cutlass::layout::RowMajor;
|
||||||
|
using ColumnMajor = typename cutlass::layout::ColumnMajor;
|
||||||
|
using KernelType =
|
||||||
|
ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
|
||||||
|
ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
|
||||||
|
ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
|
||||||
|
float, cutlass::layout::RowMajor, 4,
|
||||||
|
ElementAcc, float, cutlass::arch::OpClassTensorOp,
|
||||||
|
Arch,
|
||||||
|
TileShape, WarpShape, InstructionShape,
|
||||||
|
EVTD,
|
||||||
|
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
|
||||||
|
MainLoopStages, Operator,
|
||||||
|
1 /* epilogue stages */
|
||||||
|
>::GemmKernel>;
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
|
||||||
|
};

template <typename Gemm>
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     torch::Tensor const& a_scales,
                                     torch::Tensor const& b_scales) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);
  cutlass::gemm::GemmCoord problem_size{m, n, k};

  int64_t lda = a.stride(0);
  int64_t ldb = b.stride(1);
  int64_t ldc = out.stride(0);

  using StrideC = Stride<int64_t, Int<1>, Int<0>>;
  StrideC c_stride{ldc, Int<1>{}, Int<0>{}};

  auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
  auto c_ptr = static_cast<ElementD*>(out.data_ptr());

  auto a_scales_ptr = a_scales.data_ptr<float>();
  auto b_scales_ptr = b_scales.data_ptr<float>();

  using ScaleAArgs = typename Gemm::ScaleA::Arguments;
  using ScaleBArgs = typename Gemm::ScaleB::Arguments;

  ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
  ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};

  typename Gemm::EVTCompute0::Arguments evt0_compute_args{b_args};

  typename Gemm::EVTCompute1::Arguments evt1_compute_args{a_args,
                                                          evt0_compute_args};
  typename Gemm::D::Arguments d_args{c_ptr, c_stride};

  typename Gemm::EVTD::Arguments epilogue_args{
      evt1_compute_args,
      d_args,
  };

  typename Gemm::Op::Arguments args{
      cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel,  // universal mode
      problem_size,                                           // problem size
      1,                                                      // batch count
      epilogue_args,
      a_ptr,
      b_ptr,
      nullptr,
      nullptr,
      0,
      0,
      0,
      0,
      lda,
      ldb,
      ldc,
      ldc};

  // Launch the CUTLASS GEMM kernel.
  typename Gemm::Op gemm_op;
  size_t workspace_size = gemm_op.get_workspace_size(args);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  CUTLASS_CHECK(gemm_op.can_implement(args));
  cutlass::Status status = gemm_op(args, workspace.get(), stream);
  CUTLASS_CHECK(status);
}
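One detail worth noting in the dispatcher above: the second argument of each Scale*::Arguments is a flag derived from `numel() != 1`, which presumably selects between a per-tensor scalar broadcast and a per-row/per-column vector broadcast inside the epilogue visitors. A hypothetical helper expressing that selection (not part of the file):

// Hypothetical illustration of the broadcast flag's effect: one scale per
// output row/column when the flag is set, a single scalar otherwise.
inline float select_scale(float const* scales, bool per_channel, int64_t idx) {
  return per_channel ? scales[idx] : scales[0];
}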

}  // namespace

void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
                               torch::Tensor const& b,
                               torch::Tensor const& a_scales,
                               torch::Tensor const& b_scales) {
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;

  if (out.dtype() == torch::kBFloat16) {
    return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
        cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t,
        TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
                                                    b_scales);
  } else {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
        cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t,
        TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
                                                    b_scales);
  }
}

void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
                               torch::Tensor const& b,
                               torch::Tensor const& a_scales,
                               torch::Tensor const& b_scales) {
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

  if (out.dtype() == torch::kBFloat16) {
    return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
        cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::bfloat16_t,
        TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
                                                    b_scales);
  } else {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
        cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::half_t,
        TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
                                                    b_scales);
  }
}

void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
                               torch::Tensor const& b,
                               torch::Tensor const& a_scales,
                               torch::Tensor const& b_scales) {
  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (a.dtype() == torch::kInt8) {
    TORCH_CHECK(b.dtype() == torch::kInt8);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
          cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::bfloat16_t,
          TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
                                                      b_scales);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
          cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::half_t,
          TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
                                                      b_scales);
    }
  } else {
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
    TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
          cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
          cutlass::bfloat16_t, TileShape, WarpShape, InstructionShape, 5>>(
          out, a, b, a_scales, b_scales);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
          cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
          cutlass::half_t, TileShape, WarpShape, InstructionShape, 5>>(
          out, a, b, a_scales, b_scales);
    }
  }
}

csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu (new file, 340 lines)
@@ -0,0 +1,340 @@
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>

#if defined CUDA_VERSION && CUDA_VERSION >= 12000

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>

#include <iostream>
#include <sstream>
#include <vector>

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"

#include "cutlass/util/device_memory.h"

#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "broadcast_load_epilogue_c3x.hpp"
#include "common.hpp"
// clang-format on

using namespace cute;

/*
   This defines a quantized GEMM operation with dequantized output, similar to
   torch._scaled_mm. It is defined using the CUTLASS 3.x API, and is used for
   NVIDIA GPUs with sm90a (Hopper) or later.

   A and B may both be either int8 or fp8_e4m3. A can be quantized per-tensor
   or per-row. B can be quantized per-tensor or per-column.
   Any combination of per-tensor and per-row or per-column scales is supported.
   A and B must use symmetric quantization (zero point == 0).

   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
   scales are applied elementwise with numpy-style broadcasting.

   ScaleA and ScaleB define the epilogue functions that apply the scales for
   the A and B operands respectively. These scales may be either per-tensor or
   per-row or per-column.
*/
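As a concrete reading of the formula in the comment above, here is a scalar reference loop for the int8 case, assuming row-major A/D and column-major B (a sketch for clarity, not the kernel itself):

// Reference for D = (a_scales * A) (b_scales * B) with numpy-style
// broadcasting; each scale array holds either 1 element (per-tensor) or
// m / n elements (per-row / per-column).
void scaled_mm_dq_reference(float* d, int8_t const* a, int8_t const* b,
                            float const* a_scales, int64_t a_scales_len,
                            float const* b_scales, int64_t b_scales_len,
                            int m, int n, int k) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      int32_t acc = 0;  // int8 inputs accumulate in int32
      for (int l = 0; l < k; ++l) {
        acc += int32_t(a[i * k + l]) * int32_t(b[j * k + l]);  // B column-major
      }
      float sa = a_scales[a_scales_len == 1 ? 0 : i];
      float sb = b_scales[b_scales_len == 1 ? 0 : j];
      d[i * n + j] = sa * (sb * float(acc));
    }
  }
}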

namespace {

uint32_t next_pow_2(uint32_t const num) {
  if (num <= 1) return num;
  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}
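A quick trace of the helper above, with worked values of my own:

// Worked values (not from the file); CHAR_BIT * sizeof(num) is 32 here:
//   next_pow_2(1)  == 1    (early return)
//   next_pow_2(64) == 64   (1 << (32 - __builtin_clz(63)) == 1 << 6)
//   next_pow_2(65) == 128  (1 << (32 - __builtin_clz(64)) == 1 << 7)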

// A wrapper for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm90_or_later : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
    Kernel::operator()(std::forward<Args>(args)...);
#endif
  }
};

template <typename ElementAB_, typename ElementD_, typename TileShape,
          typename ClusterShape, typename KernelSchedule,
          typename EpilogueSchedule>
struct cutlass_3x_gemm {
  using ElementAB = ElementAB_;
  using ElementD = ElementD_;
  using ElementAcc =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
                                float>::type;

  using EpilogueDescriptor =
      cutlass::epilogue::collective::detail::EpilogueDescriptor<
          TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
          ElementD, EpilogueSchedule>;

  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
      0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
      Stride<Int<1>, Int<0>, Int<0>>>;

  using ScaleBDescriptor =
      cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
          EpilogueDescriptor, float>;

  using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
      ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape,
      typename ScaleBDescriptor::Element, Stride<Int<0>, Int<1>, Int<0>>>;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute1 =
      cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;

  using StrideD = Stride<int64_t, Int<1>, Int<0>>;
  using ElementC = void;
  using StrideC = StrideD;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
          ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
          EpilogueSchedule, EVTCompute1>::CollectiveOp;

  static constexpr size_t CEStorageSize =
      sizeof(typename CollectiveEpilogue::SharedStorage);
  using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
      static_cast<int>(CEStorageSize)>;

  // clang-format off
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
          ElementAB, cutlass::layout::RowMajor, 16,
          ElementAB, cutlass::layout::ColumnMajor, 16,
          ElementAcc, TileShape, ClusterShape,
          Stages,
          KernelSchedule>::CollectiveOp;
  // clang-format on

  using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
      cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
      cutlass::gemm::PersistentScheduler>>;

  struct GemmKernel : public KernelType {};
};

template <typename Gemm>
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     torch::Tensor const& a_scales,
                                     torch::Tensor const& b_scales) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);

  int64_t lda = a.stride(0);
  int64_t ldb = b.stride(1);
  int64_t ldc = out.stride(0);

  using StrideA = Stride<int64_t, Int<1>, Int<0>>;
  using StrideB = Stride<int64_t, Int<1>, Int<0>>;
  using StrideC = typename Gemm::StrideC;

  StrideA a_stride{lda, Int<1>{}, Int<0>{}};
  StrideB b_stride{ldb, Int<1>{}, Int<0>{}};
  StrideC c_stride{ldc, Int<1>{}, Int<0>{}};

  using GemmKernel = typename Gemm::GemmKernel;
  typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};

  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
  typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
                                                       b_stride};

  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
  typename GemmKernel::EpilogueArguments epilogue_args{
      {}, c_ptr, c_stride, c_ptr, c_stride};

  typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
                                      prob_shape, mainloop_args, epilogue_args};

  using ScaleA_Args = typename Gemm::ScaleA::Arguments;
  using ScaleB_Args = typename Gemm::ScaleB::Arguments;

  ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
  ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};

  args.epilogue.thread = {a_args, {b_args}};

  // Launch the CUTLASS GEMM kernel.
  using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  GemmOp gemm_op;
  CUTLASS_CHECK(gemm_op.can_implement(args));

  size_t workspace_size = gemm_op.get_workspace_size(args);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  cutlass::Status status = gemm_op.run(args, workspace.get(), stream);
  CUTLASS_CHECK(status);
}
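The `args.epilogue.thread = {a_args, {b_args}}` assignment above mirrors the EVT nesting declared in cutlass_3x_gemm. An equivalent, more explicit spelling, by analogy with the 2.x dispatcher earlier in this diff (a sketch, not from the file):

// EVTCompute1 = Compute1(ScaleA, EVTCompute0); EVTCompute0 = Compute0(ScaleB, Accum)
// typename Gemm::EVTCompute0::Arguments evt0_args{b_args};             // inner node
// typename Gemm::EVTCompute1::Arguments evt1_args{a_args, evt0_args};  // outer node
// args.epilogue.thread = evt1_args;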

template <typename InType, typename OutType, int32_t M>
struct sm90_fp8_config {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;

  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
                      EpilogueSchedule>;
};

template <typename InType, typename OutType>
struct sm90_fp8_config<InType, OutType, 128> {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;

  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
                      EpilogueSchedule>;
};

template <typename InType, typename OutType>
struct sm90_fp8_config<InType, OutType, 64> {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _128>;
  using ClusterShape = Shape<_1, _8, _1>;

  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
                      EpilogueSchedule>;
};

}  // namespace

template <typename InType, typename OutType>
void cutlass_scaled_mm_dq_sm90_fp8_dispatch(torch::Tensor& out,
                                            torch::Tensor const& a,
                                            torch::Tensor const& b,
                                            torch::Tensor const& a_scales,
                                            torch::Tensor const& b_scales) {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  using Cutlass3xGemmDefault =
      typename sm90_fp8_config<InType, OutType, 0>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_fp8_config<InType, OutType, 64>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_fp8_config<InType, OutType, 128>::Cutlass3xGemm;

  uint32_t const m = a.size(0);
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2

  if (mp2 <= 64) {
    // m in [1, 64]
    return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmM64>(
        out, a, b, a_scales, b_scales);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmM128>(
        out, a, b, a_scales, b_scales);
  } else {
    // m in (128, inf)
    return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmDefault>(
        out, a, b, a_scales, b_scales);
  }
}
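For reference, a quick trace of the bucketing above with the tile shapes from the sm90_fp8_config specializations:

// m = 16   -> mp2 = 64   -> Cutlass3xGemmM64     (TileShape 64x64x128)
// m = 100  -> mp2 = 128  -> Cutlass3xGemmM128    (TileShape 64x128x128)
// m = 4096 -> mp2 = 4096 -> Cutlass3xGemmDefault (TileShape 128x128x128)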

void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
                               torch::Tensor const& b,
                               torch::Tensor const& a_scales,
                               torch::Tensor const& b_scales) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (a.dtype() == torch::kInt8) {
    TORCH_CHECK(b.dtype() == torch::kInt8);

    using TileShape = Shape<_128, _128, _128>;
    using ClusterShape = Shape<_1, _2, _1>;
    using KernelSchedule =
        typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
    using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_scaled_mm_dq_dispatcher<
          cutlass_3x_gemm<int8_t, cutlass::bfloat16_t, TileShape, ClusterShape,
                          KernelSchedule, EpilogueSchedule>>(
          out, a, b, a_scales, b_scales);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);

      return cutlass_scaled_mm_dq_dispatcher<
          cutlass_3x_gemm<int8_t, cutlass::half_t, TileShape, ClusterShape,
                          KernelSchedule, EpilogueSchedule>>(
          out, a, b, a_scales, b_scales);
    }
  } else {
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
    TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_scaled_mm_dq_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                                    cutlass::bfloat16_t>(
          out, a, b, a_scales, b_scales);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_scaled_mm_dq_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                                    cutlass::half_t>(
          out, a, b, a_scales, b_scales);
    }
  }
}

#endif

csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu (new file, 75 lines)
@@ -0,0 +1,75 @@
#include <cudaTypedefs.h>

#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a,
                               torch::Tensor const& b,
                               torch::Tensor const& a_scales,
                               torch::Tensor const& b_scales);

void cutlass_scaled_mm_dq_sm80(torch::Tensor& c, torch::Tensor const& a,
                               torch::Tensor const& b,
                               torch::Tensor const& a_scales,
                               torch::Tensor const& b_scales);

void cutlass_scaled_mm_dq_sm89(torch::Tensor& c, torch::Tensor const& a,
                               torch::Tensor const& b,
                               torch::Tensor const& a_scales,
                               torch::Tensor const& b_scales);

#if defined CUDA_VERSION && CUDA_VERSION >= 12000
void cutlass_scaled_mm_dq_sm90(torch::Tensor& c, torch::Tensor const& a,
                               torch::Tensor const& b,
                               torch::Tensor const& a_scales,
                               torch::Tensor const& b_scales);
#endif

void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
                          torch::Tensor const& b, torch::Tensor const& a_scales,
                          torch::Tensor const& b_scales) {
  int32_t major_capability;
  int32_t minor_capability;
  cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
                         0);
  cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
                         0);
  int32_t version_num = major_capability * 10 + minor_capability;

  // Checks for conformality
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));

  if (version_num >= 90) {
    // Hopper

    // Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
    cutlass_scaled_mm_dq_sm90(c, a, b, a_scales, b_scales);
#else
    cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales);
#endif
  } else if (version_num == 89) {
    // Ada Lovelace
    cutlass_scaled_mm_dq_sm89(c, a, b, a_scales, b_scales);
  } else if (version_num >= 80) {
    // Ampere
    cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales);
  } else {
    // Turing
    TORCH_CHECK(version_num >= 75);
    cutlass_scaled_mm_dq_sm75(c, a, b, a_scales, b_scales);
  }
}
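A hypothetical host-side caller that satisfies the checks above (a sketch; names and sizes are illustrative, not from the file):

void example_caller() {
  int64_t m = 256, n = 512, k = 1024;  // multiples of 16 for the alignment checks
  auto i8 = torch::dtype(torch::kInt8).device(torch::kCUDA);
  auto f32 = torch::dtype(torch::kFloat32).device(torch::kCUDA);
  torch::Tensor a = torch::randint(-64, 64, {m, k}, i8);      // row-major
  torch::Tensor b = torch::randint(-64, 64, {n, k}, i8).t();  // column-major: b.stride(0) == 1
  torch::Tensor out =
      torch::empty({m, n}, torch::dtype(torch::kFloat16).device(torch::kCUDA));
  torch::Tensor a_scales = torch::rand({m}, f32);  // per-row, or {1} for per-tensor
  torch::Tensor b_scales = torch::rand({n}, f32);  // per-column, or {1} for per-tensor
  cutlass_scaled_mm_dq(out, a, b, a_scales, b_scales);
}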

csrc/quantization/fp8/amd/hip_float8.h (new file, 137 lines)
@@ -0,0 +1,137 @@
#pragma once

#ifdef __HIPCC__
#include <hip/hip_runtime.h>
#else
#include <type_traits>
#include <stdint.h>
#include <math.h>
#include <iostream>
#endif

#include "hip_float8_impl.h"

struct alignas(1) hip_fp8 {
  struct from_bits_t {};
  HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() {
    return from_bits_t();
  }
  uint8_t data;

  hip_fp8() = default;
  HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default;
  HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete;
  explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t)
      : data(v) {}

#ifdef __HIP__MI300__
  // NOTE: ON-DEVICE... always optimal bias
  explicit HIP_FP8_DEVICE hip_fp8(float v)
      : data(hip_fp8_impl::to_fp8_from_fp32(v)) {}

  explicit HIP_FP8_DEVICE hip_fp8(_Float16 v)
      : hip_fp8(static_cast<float>(v)) {}

  // Host only implementation using s/w simulation
  explicit HIP_FP8_HOST
#else   // __HIP__MI300__
  // both Host and DEVICE for non-MI300 using s/w simulation
  explicit HIP_FP8_HOST_DEVICE
#endif  // __HIP__MI300__
  hip_fp8(float v) {
    data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/,
                                   true /*clip*/>(v);
  }

  explicit HIP_FP8_HOST_DEVICE hip_fp8(double v)
      : hip_fp8(static_cast<float>(v)) {}

#ifdef __HIP__MI300__
  // upcast using device specific intrinsic
  explicit inline HIP_FP8_DEVICE operator float() const {
    float fval;
    uint32_t i32val = static_cast<uint32_t>(data);

    // upcast
    asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0"
                 : "=v"(fval)
                 : "v"(i32val));

    return fval;
  }

  explicit inline HIP_FP8_HOST operator float() const
#else   // __HIP__MI300__
  explicit inline HIP_FP8_HOST_DEVICE operator float() const
#endif  // __HIP__MI300__
  {
    return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(
        data);
  }
};
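A small host-side round trip through the software path (illustrative; the encoded value is my own arithmetic):

hip_fp8 q(0.3f);        // encodes to the nearest e4m3 (fnuz) value, 0.3125
float back = float(q);  // decodes back: back == 0.3125f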

namespace std {
inline hip_fp8 sin(hip_fp8 a) { return hip_fp8(sinf(float(a))); }
inline hip_fp8 cos(hip_fp8 a) { return hip_fp8(cosf(float(a))); }
HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) { return a; }
}  // namespace std

// Special operator overloading
inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) {
  return os << float(f8);
}

// all + operator overloading with mixed types
// mixed types, always converts to f32, does computation in f32, and returns
// float
inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) {
  return (fa + float(b));
}

inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) {
  return (float(a) + fb);
}

inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) {
  return hip_fp8(float(a) + float(b));
}

inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) {
  return a = hip_fp8(float(a) + float(b));
}

// overloading multiplication, always returns float
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) {
  return float(a) * float(b);
}

inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) {
  return (a * float(b));
}

inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) {
  return (float(a) * b);
}

inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) {
  return ((float)a * float(b));
}

inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) {
  return ((float)a * float(b));
}

// overloading for compare
inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) {
  return (a.data == b.data);
}
inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) {
  return (a.data != b.data);
}

inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) {
  return static_cast<float>(a) >= static_cast<float>(b);
}
inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) {
  return static_cast<float>(a) > static_cast<float>(b);
}

csrc/quantization/fp8/amd/hip_float8_impl.h (new file, 316 lines)
@@ -0,0 +1,316 @@
#pragma once

#if defined(__HIPCC__) && \
    (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#define __HIP__MI300__
#endif

#ifdef __HIPCC__
#define HIP_FP8_HOST_DEVICE __host__ __device__
#define HIP_FP8_HOST __host__
#define HIP_FP8_DEVICE __device__
#else
#define HIP_FP8_HOST_DEVICE
#define HIP_FP8_HOST
#define HIP_FP8_DEVICE
#endif

namespace hip_fp8_impl {

#ifdef __HIP__MI300__
HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) {
  uint8_t i8data;
  union {
    float fval;
    uint32_t i32val;
    uint8_t i8val[4];  // NOTE: not endian independent
  } val;

  uint32_t ival = 0;
  val.fval = v;

  if ((val.i32val & 0x7F800000) !=
      0x7F800000) {  /// propagate NAN/INF, no clipping
    val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
  }

  ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival,
                                         false);  // false -> WORD0
  val.i32val = ival;
  i8data = val.i8val[0];

  return i8data;
}
#endif  // __HIP__MI300__

HIP_FP8_HOST inline int clz(uint32_t x) { return __builtin_clz(x); }
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
HIP_FP8_DEVICE inline int clz(uint32_t x) { return __clz(x); }
#endif

template <int we, int wm, typename T, bool negative_zero_nan, bool clip>
HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false,
                                      uint32_t rng = 0) {
#ifdef __HIPCC__
  constexpr bool is_half = std::is_same<T, _Float16>::value;
#else
  constexpr bool is_half = false;
#endif
  constexpr bool is_float = std::is_same<T, float>::value;
  static_assert(wm + we == 7, "wm+we==7");
  static_assert(is_half || is_float, "Only half and float can be cast to f8");

  const int mfmt = (sizeof(T) == 4) ? 23 : 10;
  uint32_t x;
  if (sizeof(T) == 4) {
    x = reinterpret_cast<uint32_t&>(_x);
  } else {
    x = reinterpret_cast<uint16_t&>(_x);
  }

  uint32_t head, mantissa;
  int exponent, bias;
  uint32_t sign;

  if (sizeof(T) == 4) {
    head = x & 0xFF800000;
    mantissa = x & 0x7FFFFF;
    exponent = (head >> 23) & 0xFF;
    sign = head >> 31;
    bias = 127;
  } else {
    head = x & 0xFC00;
    mantissa = x & 0x3FF;
    exponent = (head >> 10) & 0x1F;
    sign = head >> 15;
    bias = 15;
  }

  uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm);

  // Deal with inf and NaNs
  if (negative_zero_nan) {
    if (sizeof(T) == 4) {
      if ((x & 0x7F800000) == 0x7F800000) {
        return 0x80;
      }
    } else {
      // if(__hisinf(x) || __hisnan(x))
      if ((x & 0x7C00) == 0x7C00) {
        return 0x80;
      }
    }
  } else {
    if (sizeof(T) == 4) {
      if ((x & 0x7F800000) == 0x7F800000) {
        return signed_inf + (mantissa != 0 ? 1 : 0);
      }
    } else {
      if ((x & 0x7C00) == 0x7C00) {
        return signed_inf + (mantissa != 0 ? 1 : 0);
      }
    }
  }
  if (x == 0) {
    return 0;
  }

  // First check whether the input is normal or denormal, since the two differ
  // by an implicit 1. Then adjust the exponent to align with the f8 exponent,
  // shifting the mantissa in the process. For stochastic rounding, add rng to
  // the mantissa and truncate; for round-to-nearest-even no rng is added.
  // Finally, check whether there is a carry and, if so, adjust the exponent
  // and mantissa again.

  // For the IEEE bias mode, the bias is 2^(k-1) - 1 where k is the width of
  // the exponent bits
  const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
  const int f8_denormal_act_exponent =
      1 - f8_bias;  // actual exponent of f8 denormal
  // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
  // f8_exponent is the converted f8 exponent with bias encoding
  // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent;
  // the difference needs to be adjusted and the mantissa shifted
  int act_exponent, f8_exponent, exponent_diff;

  if (exponent == 0) {  // fp32/fp16 is in denormal.
    /* fp32 denormals are below 2^-127, so they are usually not a concern here;
we mostly care about fp16. In this case, f8 is usually also denormal, but there
are exceptions. fp16 denormals have exponent bias 15 while bf8 with NANOO has
exponent bias 16, which means some fp16 denormals are bf8 (NANOO) normals - the
smallest bf8 (NANOO) normal is 2^-15. fp16 numbers where exponent==0 (actual
exponent -14) and the highest mantissa bit is 1 are bf8 (NANOO) normals; in
that case, the fp16 mantissa should be shifted left by 1 */
    act_exponent = exponent - bias + 1;
    exponent_diff =
        f8_denormal_act_exponent -
        act_exponent;  // actual exponent is exponent-bias+1 as it is denormal
  } else {  // fp32/fp16 is normal with implicit 1
    act_exponent = exponent - bias;
    if (act_exponent <= f8_denormal_act_exponent) {
      /* This is the case where fp32/fp16 is normal but falls in the f8 denormal
range. For example, in fp8 nanoo mode the denormal exponent is -7, but if the
fp32/fp16 actual exponent is -7 the value is actually larger due to the
implicit 1, so it needs to be adjusted to -6 and the mantissa shifted right by
1. So for fp32/fp16, exponent -8 is the cut point for conversion to fp8 nanoo */
      exponent_diff = f8_denormal_act_exponent - act_exponent;
    } else {  // both fp32/fp16 and f8 are in normal range
      exponent_diff = 0;  // exponent_diff=0 does not mean there is no
                          // difference for this case; act_exponent could be
                          // larger. It just needs no mantissa shift
    }
    mantissa += (1 << mfmt);  // Add the implicit 1 into mantissa
  }

  bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) ==
                  static_cast<uint32_t>(1 << (mfmt - wm + exponent_diff - 1));
  /* This part is a bit tricky. The judgment of whether the value is a tie needs
to happen before we shift right, since shifting right can strip off residual
bits and make something that is not a midpoint look like one. For example, the
fp16 number 0x1002 (0 00100 0000000010) is larger than the midpoint, but after
a right shift by 4 bits it would look like a midpoint.
*/

  if (exponent_diff > 0) {
    mantissa >>= exponent_diff;
  } else if (exponent_diff == -1) {
    mantissa <<= -exponent_diff;
  }
  bool implicit_one = mantissa & (1 << mfmt);
  // if there is no implicit 1, the f8 is denormal and needs to be adjusted
  // to the denormal exponent
  f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ +
                f8_bias - (implicit_one ? 0 : 1);

  // Now we have the exponent and mantissa adjusted
  uint32_t drop_mask = (1 << (mfmt - wm)) - 1;
  bool odd = mantissa & (1 << (mfmt - wm));  // if the least significant bit
                                             // that is not truncated is 1
  mantissa +=
      (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) &
      drop_mask;

  // Now we deal with overflow
  if (f8_exponent == 0) {
    if ((1 << mfmt) & mantissa) {
      f8_exponent = 1;  // denormal overflow to become normal, promote exponent
    }
  } else {
    if ((1 << (mfmt + 1)) & mantissa) {
      mantissa >>= 1;
      f8_exponent++;
    }
  }

  mantissa >>= (mfmt - wm);

  // above range: quantize to maximum possible float of the same sign
  const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2);
  if (f8_exponent > max_exp) {
    if (clip) {
      mantissa = (1 << wm) - 1;
      f8_exponent = max_exp;
    } else {
      return signed_inf;
    }
  }

  if (f8_exponent == 0 && mantissa == 0) {
    return negative_zero_nan ? 0 : (sign << 7);
  }
  mantissa &= (1 << wm) - 1;
  return (sign << 7) | (f8_exponent << wm) | mantissa;
}
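A worked trace of the encoding above, using the we=4, wm=3, negative_zero_nan=true parameters that hip_fp8 passes in (my own arithmetic, not from the file):

// Encoding 1.0f: sign = 0, act_exponent = 0, mantissa bits = 0.
//   f8_bias = (1 << 3) - 1 + 1 = 8, so f8_exponent = 0 + 8 = 8
//   result = (0 << 7) | (8 << 3) | 0 = 0x40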

template <int we, int wm, typename T = float, bool negative_zero_nan = true>
inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) {
#ifdef __HIPCC__
  constexpr bool is_half = std::is_same<T, _Float16>::value;
#else
  constexpr bool is_half = false;
#endif
  constexpr bool is_float = std::is_same<T, float>::value;
  static_assert(is_half || is_float, "only half and float are supported");

  constexpr int weo = is_half ? 5 : 8;
  constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);

  T fInf, fNegInf, fNaN, fNeg0;

#ifdef __HIPCC__
  if (is_half) {
    const uint16_t ihInf = 0x7C00;
    const uint16_t ihNegInf = 0xFC00;
    const uint16_t ihNaN = 0x7C01;
    const uint16_t ihNeg0 = 0x8000;
    fInf = reinterpret_cast<const _Float16&>(ihInf);
    fNegInf = reinterpret_cast<const _Float16&>(ihNegInf);
    fNaN = reinterpret_cast<const _Float16&>(ihNaN);
    fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0);
  } else
#endif
  if (is_float) {
    const uint32_t ifInf = 0x7F800000;
    const uint32_t ifNegInf = 0xFF800000;
    const uint32_t ifNaN = 0x7F800001;
    const uint32_t ifNeg0 = 0x80000000;
    fInf = reinterpret_cast<const float&>(ifInf);
    fNegInf = reinterpret_cast<const float&>(ifNegInf);
    fNaN = reinterpret_cast<const float&>(ifNaN);
    fNeg0 = reinterpret_cast<const float&>(ifNeg0);
  }

  if (x == 0) {
    return 0;
  }

  uint32_t sign = x >> 7;
  uint32_t mantissa = x & ((1 << wm) - 1);
  int exponent = (x & 0x7F) >> wm;
  if (negative_zero_nan) {
    if (x == 0x80) {
      return fNaN;
    }
  } else {
    if (x == 0x80) {
      return fNeg0;
    }
    if (exponent == ((1 << we) - 1)) {
      return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
    }
  }
  typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
  if (we == 5 && is_half && !negative_zero_nan) {
    retval = x << 8;
    return reinterpret_cast<const T&>(retval);
  }

  const int exp_low_cutoff =
      (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);

  // subnormal input
  if (exponent == 0) {
    // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
    int sh = 1 + clz(mantissa) - (32 - wm);
    mantissa <<= sh;
    exponent += 1 - sh;
    mantissa &= ((1 << wm) - 1);
  }
  exponent += exp_low_cutoff - 1;
  mantissa <<= wmo - wm;

  // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
  if (exponent <= 0) {
    mantissa |= 1 << wmo;
    mantissa >>= 1 - exponent;
    exponent = 0;
  }

  if (sizeof(T) == 2) {
    retval = (sign << 15) | (exponent << 10) | mantissa;
  } else {
    retval = (sign << 31) | (exponent << 23) | mantissa;
  }
  return reinterpret_cast<const T&>(retval);
}
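One consequence of the 0x80 branch above worth calling out:

// With negative_zero_nan=true (the convention hip_fp8 uses), 0x80 is the
// single NaN encoding, so from_float8<4, 3, float, true>(0x80) yields NaN;
// with negative_zero_nan=false the same bit pattern decodes to -0.0f.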

}  // namespace hip_fp8_impl

csrc/quantization/fp8/amd/quant_utils.cuh (new file, 575 lines)
@@ -0,0 +1,575 @@
#pragma once
#include "hip_float8.h"

#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include <hip/hip_bfloat16.h>

#include "../../../attention/dtype_fp8.cuh"
#include "../../../attention/dtype_float32.cuh"
#include "../../../attention/dtype_bfloat16.cuh"

namespace vllm {
#ifdef USE_ROCM

namespace fp8 {
#ifdef ENABLE_FP8

template <typename Tout, typename Tin>
__inline__ __device__ Tout vec_conversion(const Tin& x) {
  return x;
}

template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
                                                 const float scale) {
  return x;
}

// fp8 -> half
template <>
__inline__ __device__ uint16_t
vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
  hip_fp8 f8{a, hip_fp8::from_bits()};
  __half_raw res;
  res.data = static_cast<float>(f8);
  return res.x;
}

// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t
vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
#if defined(__HIP__MI300__) && \
    defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
  const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
  union {
    __half2_raw h2r;
    uint32_t ui32;
  } tmp;
  tmp.h2r.x.data = f2[0];
  tmp.h2r.y.data = f2[1];
  return tmp.ui32;
#else
  union {
    uint16_t u16[2];
    uint32_t u32;
  } tmp;

  tmp.u16[0] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a));
  tmp.u16[1] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U));
  return tmp.u32;
#endif
}

// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
  union {
    uint2 u32x2;
    uint32_t u32[2];
  } tmp;
  tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
  tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
  return tmp.u32x2;
}

// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
  union {
    uint4 u64x2;
    uint2 u64[2];
  } tmp;
  tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
  tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
  return tmp.u64x2;
}

using __nv_bfloat16 = __hip_bfloat16;

// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16
vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
  hip_fp8 f8{a, hip_fp8::from_bits()};
  float f{f8};
  return __float2bfloat16(f);
}

using __nv_bfloat162 = __hip_bfloat162;

// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162
vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
  __nv_bfloat162 res;
  res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
  res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
  return res;
}

// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t
vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
  bf16_4_t res;
  res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
  res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
  return res;
}

// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
  bf16_4_t tmp1, tmp2;
  tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
  tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
  bf16_8_t res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

// fp8 -> float
template <>
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
  hip_fp8 fp8{a, hip_fp8::from_bits()};
  return static_cast<float>(fp8);
}

// fp8x2 -> float2
template <>
__inline__ __device__ float2
vec_conversion<float2, uint16_t>(const uint16_t& a) {
#if defined(__HIP__MI300__) && \
    defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
  float2 res;
  const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
  res.x = f2[0];
  res.y = f2[1];
  return res;
#else
  float2 res;
  res.x = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a));
  res.y = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U));
  return res;
#endif
}

// fp8x4 -> float4
template <>
__inline__ __device__ Float4_
vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
  Float4_ res;
  res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
  res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
  return res;
}

// fp8x8 -> float8
template <>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
  Float4_ tmp1, tmp2;
  tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
  tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
  Float8_ res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

// half -> fp8
template <>
__inline__ __device__ uint8_t
vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
  __half_raw tmp;
  tmp.x = a;

  hip_fp8 f8{static_cast<float>(tmp.data)};
  return f8.data;
}

// bf16 -> fp8
template <>
__inline__ __device__ uint8_t
vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
  hip_fp8 res{__bfloat162float(a)};
  return res.data;
}

// float -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
  hip_fp8 f8(a);
  return f8.data;
}

// fp8x4 -> float4
template <>
__inline__ __device__ float4
vec_conversion<float4, uint32_t>(const uint32_t& a) {
  Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
  float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
  return res;
}

// float2 -> half2
template <>
__inline__ __device__ uint32_t
vec_conversion<uint32_t, float2>(const float2& a) {
  union {
    half2 float16;
    uint32_t uint32;
  };

  float16 = __float22half2_rn(a);
  return uint32;
}

// Float4 -> half2x2
template <>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
  uint2 b;
  float2 val;
  val.x = a.x.x;
  val.y = a.x.y;
  b.x = vec_conversion<uint32_t, float2>(val);

  val.x = a.y.x;
  val.y = a.y.y;
  b.y = vec_conversion<uint32_t, float2>(val);
  return b;
}

// Float4 -> float4
template <>
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
  float4 b;
  b.x = a.x.x;
  b.y = a.x.y;
  b.z = a.y.x;
  b.w = a.y.y;
  return b;
}

// Float8 -> half2x4
template <>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
  uint4 b;
  b.x = vec_conversion<uint32_t, float2>(a.x);
  b.y = vec_conversion<uint32_t, float2>(a.y);
  b.z = vec_conversion<uint32_t, float2>(a.z);
  b.w = vec_conversion<uint32_t, float2>(a.w);
  return b;
}

// float2 -> bfloat162
template <>
__inline__ __device__ __nv_bfloat162
vec_conversion<__nv_bfloat162, float2>(const float2& a) {
  __nv_bfloat162 b = __float22bfloat162_rn(a);
  return b;
}

// Float4 -> bfloat162x2
template <>
__inline__ __device__ bf16_4_t
vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
  bf16_4_t b;
  b.x = __float22bfloat162_rn(a.x);
  b.y = __float22bfloat162_rn(a.y);
  return b;
}

// Float8 -> bfloat162x4
template <>
__inline__ __device__ bf16_8_t
vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
  bf16_8_t b;
  b.x = __float22bfloat162_rn(a.x);
  b.y = __float22bfloat162_rn(a.y);
  b.z = __float22bfloat162_rn(a.z);
  b.w = __float22bfloat162_rn(a.w);
  return b;
}

/* Scaled and vectorized conversions, for data exchange between high and low
   precision domains

   Convention of the scale in API, e.g: FP8_data = Quantization(
   High_Precision_data / scale ) s.t.
   Quantize(HP / scale) => FP8
   Dequant(FP8) * scale => HP
 */
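A quick numeric check of this convention (my own example, not from the file):

// With scale = 0.5 and a high-precision value HP = 3.0:
//   quantize:   FP8 = Quantize(3.0 / 0.5) = Quantize(6.0)  (6.0 is exactly representable)
//   dequantize: Dequant(FP8) * scale = 6.0 * 0.5 = 3.0 == HP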

// fp8 -> half
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale) {
  hip_fp8 f8{a, hip_fp8::from_bits()};
  __half_raw res;
  res.data = static_cast<float>(f8) * scale;
  return res.x;
}

// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
    const uint16_t& a, const float scale) {
#if defined(__HIP__MI300__) && \
    defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
  const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
  union {
    __half2_raw h2r;
    uint32_t ui32;
  } tmp;
  tmp.h2r.x.data = f2[0] * scale;
  tmp.h2r.y.data = f2[1] * scale;
  return tmp.ui32;
#else
  union {
    uint16_t u16[2];
    uint32_t u32;
  } tmp;

  tmp.u16[0] =
      scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
  tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(
      static_cast<uint8_t>(a >> 8U), scale);
  return tmp.u32;
#endif
}

// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale) {
  union {
    uint2 u32x2;
    uint32_t u32[2];
  } tmp;
  tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
  tmp.u32[1] =
      scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
  return tmp.u32x2;
}

// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4
scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale) {
  union {
    uint4 u64x2;
    uint2 u64[2];
  } tmp;
  tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
  tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
  return tmp.u64x2;
}

using __nv_bfloat16 = __hip_bfloat16;

// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a,
                                              const float scale) {
  hip_fp8 f8{a, hip_fp8::from_bits()};
  float f{f8};
  return __float2bfloat16(f * scale);
}

using __nv_bfloat162 = __hip_bfloat162;

// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
                                                const float scale) {
  __nv_bfloat162 res;
  res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
  res.y =
      scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
  return res;
}

// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
    const uint32_t& a, const float scale) {
  bf16_4_t res;
  res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
  res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
                                                          scale);
  return res;
}

// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) {
  bf16_4_t tmp1, tmp2;
  tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
  tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
  bf16_8_t res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

// fp8 -> float
template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
    const uint8_t& a, const float scale) {
  hip_fp8 fp8{a, hip_fp8::from_bits()};
  return static_cast<float>(fp8) * scale;
}

// fp8x2 -> float2
template <>
__inline__ __device__ float2
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale) {
#if defined(__HIP__MI300__) && \
    defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
  float2 res;
  const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
  res.x = f2[0] * scale;
  res.y = f2[1] * scale;
  return res;
#else
  float2 res;
  res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale);
  res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U),
                                                scale);
  return res;
#endif
}

// fp8x4 -> float4
template <>
__inline__ __device__ Float4_
scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
  Float4_ res;
  res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
  res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
  return res;
}

// fp8x8 -> float8
template <>
__inline__ __device__ Float8_
scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) {
||||||
|
Float4_ tmp1, tmp2;
|
||||||
|
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
|
||||||
|
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
|
||||||
|
Float8_ res;
|
||||||
|
res.x = tmp1.x;
|
||||||
|
res.y = tmp1.y;
|
||||||
|
res.z = tmp2.x;
|
||||||
|
res.w = tmp2.y;
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Quantize(HP / scale) => FP8 */
|
||||||
|
|
||||||
|
// TODO(Hai): vectorized to add
|
||||||
|
|
||||||
|
// half -> fp8
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint8_t
|
||||||
|
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale) {
|
||||||
|
__half_raw tmp;
|
||||||
|
tmp.x = a;
|
||||||
|
|
||||||
|
hip_fp8 f8{static_cast<float>(tmp.data) / scale};
|
||||||
|
return f8.data;
|
||||||
|
}
|
||||||
|
|
||||||
|
// bf16 -> fp8
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
|
||||||
|
const __nv_bfloat16& a, const float scale) {
|
||||||
|
hip_fp8 res{__bfloat162float(a) / scale};
|
||||||
|
return res.data;
|
||||||
|
}
|
||||||
|
|
||||||
|
// float -> fp8
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ uint8_t
|
||||||
|
scaled_vec_conversion<uint8_t, float>(const float& a, const float scale) {
|
||||||
|
hip_fp8 f8(a / scale);
|
||||||
|
return f8.data;
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp8x4 -> float4
|
||||||
|
template <>
|
||||||
|
__inline__ __device__ float4
|
||||||
|
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale) {
|
||||||
|
Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
|
||||||
|
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
#endif // ENABLE_FP8
|
||||||
|
|
||||||
|
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||||
|
__inline__ __device__ Tout convert(const Tin& x) {
|
||||||
|
#ifdef ENABLE_FP8
|
||||||
|
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
|
||||||
|
return vec_conversion<Tout, Tin>(x);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
assert(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||||
|
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
|
||||||
|
#ifdef ENABLE_FP8
|
||||||
|
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
|
||||||
|
return scaled_vec_conversion<Tout, Tin>(x, scale);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
assert(false);
|
||||||
|
}
|
||||||
|
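// Usage sketch (illustrative, not part of the diff): dequantizing a packed
// fp8x2 value from an E4M3 KV cache into half2, applying a per-tensor scale:
//
//   uint32_t h2 = fp8::scaled_convert<uint32_t, uint16_t,
//                                     Fp8KVCacheDataType::kFp8E4M3>(v, scale);
//
// With kv_dt == kAuto neither branch above is taken and assert(false) fires,
// so callers are expected to dispatch on the cache dtype first (see the
// macro below).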

// The following macro is used to dispatch the conversion function based on
// the data type of the key and value cache. The FN is a macro that calls a
// function with template<typename scalar_t, typename cache_t,
// Fp8KVCacheDataType kv_dt>.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN)                  \
  if (KV_DTYPE == "auto") {                                                  \
    if (SRC_DTYPE == at::ScalarType::Float) {                                \
      FN(float, float, vllm::Fp8KVCacheDataType::kAuto);                     \
    } else if (SRC_DTYPE == at::ScalarType::Half) {                          \
      FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);               \
    } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                      \
      FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);     \
    } else {                                                                 \
      TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
    }                                                                        \
  } else {                                                                   \
    if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") {                       \
      if (SRC_DTYPE == at::ScalarType::Float) {                              \
        FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);              \
      } else if (SRC_DTYPE == at::ScalarType::Half) {                        \
        FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);           \
      } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                    \
        FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);      \
      } else {                                                               \
        TORCH_CHECK(false,                                                   \
                    "Unsupported input type of kv cache: ", SRC_DTYPE);      \
      }                                                                      \
    } else {                                                                 \
      TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE);   \
    }                                                                        \
  }
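// Dispatch usage sketch (the FN macro name below is illustrative only):
//
//   #define CALL_CONVERT_KERNEL(KV_T, CACHE_T, KV_DTYPE)  \
//     my_cache_kernel<KV_T, CACHE_T, KV_DTYPE>            \
//         <<<grid, block, 0, stream>>>(...);
//
//   DISPATCH_BY_KV_CACHE_DTYPE(src_dtype, kv_cache_dtype,
//                              CALL_CONVERT_KERNEL);
//
// Here SRC_DTYPE is an at::ScalarType and KV_DTYPE is a string such as
// "auto", "fp8", or "fp8_e4m3", matching the comparisons above.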

} // namespace fp8
#endif // USE_ROCM
} // namespace vllm
@@ -1,167 +0,0 @@
#pragma once

#ifdef __HIPCC__
#include <hip/hip_runtime.h>
#else
#include <type_traits>
#include <stdint.h>
#include <math.h>
#include <iostream>
#endif

#include "hip_float8_impl.h"

struct alignas(1) hip_fp8
{
    struct from_bits_t
    {
    };
    HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { return from_bits_t(); }
    uint8_t data;

    hip_fp8() = default;
    HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default;
    HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete;
    explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t)
        : data(v)
    {
    }

#ifdef __HIP__MI300__
    // NOTE: ON-DEVICE... always optimal bias
    explicit HIP_FP8_DEVICE hip_fp8(float v)
        : data(hip_fp8_impl::to_fp8_from_fp32(v))
    {
    }

    explicit HIP_FP8_DEVICE hip_fp8(_Float16 v)
        : hip_fp8(static_cast<float>(v))
    {
    }

    // Host only implementation using s/w simulation
    explicit HIP_FP8_HOST
#else // __HIP__MI300__
    // both Host and DEVICE for non-MI300 using s/w simulation
    explicit HIP_FP8_HOST_DEVICE
#endif // __HIP__MI300__
    hip_fp8(float v)
    {
        data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/, true /*clip*/>(v);
    }

    explicit HIP_FP8_HOST_DEVICE hip_fp8(double v)
        : hip_fp8(static_cast<float>(v))
    {
    }

#ifdef __HIP__MI300__
    // upcast using device specific intrinsic
    explicit inline HIP_FP8_DEVICE operator float() const
    {
        float fval;
        uint32_t i32val = static_cast<uint32_t>(data);

        // upcast
        asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));

        return fval;
    }

    explicit inline HIP_FP8_HOST operator float() const
#else // __HIP__MI300__
    explicit inline HIP_FP8_HOST_DEVICE operator float() const
#endif // __HIP__MI300__
    {
        return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(data);
    }
};
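// Host-side usage sketch (illustrative, assumes compilation with hipcc):
//
//   hip_fp8 q(3.14f);             // quantize via the float constructor
//   float back = float(q);        // dequantize via operator float()
//   std::cout << q << std::endl;  // operator<< below prints float(q)
//
// The plain uint8_t constructor is deleted, so a raw bit pattern must be
// tagged explicitly: hip_fp8 raw{0x7E, hip_fp8::from_bits()};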

namespace std
{
inline hip_fp8 sin(hip_fp8 a)
{
    return hip_fp8(sinf(float(a)));
}
inline hip_fp8 cos(hip_fp8 a)
{
    return hip_fp8(cosf(float(a)));
}
HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a)
{
    return a;
}
} // namespace std

// Special operator overloading
inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8)
{
    return os << float(f8);
}

// all + operator overloading with mixed types
// mixed types, always converts to f32, does computation in f32, and returns float
inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b)
{
    return (fa + float(b));
}

inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb)
{
    return (float(a) + fb);
}

inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b)
{
    return hip_fp8(float(a) + float(b));
}

inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b)
{
    return a = hip_fp8(float(a) + float(b));
}

// overloading multiplication, always returns float
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b)
{
    return float(a) * float(b);
}

inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b)
{
    return (a * float(b));
}

inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b)
{
    return (float(a) * b);
}

inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b)
{
    return ((float)a * float(b));
}

inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b)
{
    return ((float)a * float(b));
}

// overloading for compare
inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b)
{
    return (a.data == b.data);
}
inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b)
{
    return (a.data != b.data);
}

inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b)
{
    return static_cast<float>(a) >= static_cast<float>(b);
}
inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b)
{
    return static_cast<float>(a) > static_cast<float>(b);
}
@@ -1,316 +0,0 @@
#pragma once

#if defined(__HIPCC__) && (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#define __HIP__MI300__
#endif

#ifdef __HIPCC__
#define HIP_FP8_HOST_DEVICE __host__ __device__
#define HIP_FP8_HOST __host__
#define HIP_FP8_DEVICE __device__
#else
#define HIP_FP8_HOST_DEVICE
#define HIP_FP8_HOST
#define HIP_FP8_DEVICE
#endif

namespace hip_fp8_impl
{

#ifdef __HIP__MI300__
HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v)
{
    uint8_t i8data;
    union {
        float fval;
        uint32_t i32val;
        uint8_t i8val[4]; // NOTE: not endian independent
    } val;

    uint32_t ival = 0;
    val.fval = v;

    if ((val.i32val & 0x7F800000) != 0x7F800000) { /// propagate NAN/INF, no clipping
        val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
    }
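    // (The fmed3 above clamps finite values to +/-240, the largest finite
    // magnitude of the e4m3fnuz encoding produced by the MI300 conversion
    // instruction; NaN/Inf bit patterns are deliberately left untouched.)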

    ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival,
                                           false); // false -> WORD0
    val.i32val = ival;
    i8data = val.i8val[0];

    return i8data;
}
#endif // __HIP__MI300__

HIP_FP8_HOST inline int clz(uint32_t x)
{
    return __builtin_clz(x);
}
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
HIP_FP8_DEVICE inline int clz(uint32_t x)
{
    return __clz(x);
}
#endif

template <int we, int wm, typename T, bool negative_zero_nan, bool clip>
HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, uint32_t rng = 0)
{
#ifdef __HIPCC__
    constexpr bool is_half = std::is_same<T, _Float16>::value;
#else
    constexpr bool is_half = false;
#endif
    constexpr bool is_float = std::is_same<T, float>::value;
    static_assert(wm + we == 7, "wm+we==7");
    static_assert(is_half || is_float, "Only half and float can be cast to f8");

    const int mfmt = (sizeof(T) == 4) ? 23 : 10;
    uint32_t x;
    if (sizeof(T) == 4) {
        x = reinterpret_cast<uint32_t&>(_x);
    } else {
        x = reinterpret_cast<uint16_t&>(_x);
    }

    uint32_t head, mantissa;
    int exponent, bias;
    uint32_t sign;

    if (sizeof(T) == 4) {
        head = x & 0xFF800000;
        mantissa = x & 0x7FFFFF;
        exponent = (head >> 23) & 0xFF;
        sign = head >> 31;
        bias = 127;
    } else {
        head = x & 0xFC00;
        mantissa = x & 0x3FF;
        exponent = (head >> 10) & 0x1F;
        sign = head >> 15;
        bias = 15;
    }

    uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm);

    // Deal with inf and NaNs
    if (negative_zero_nan) {
        if (sizeof(T) == 4) {
            if ((x & 0x7F800000) == 0x7F800000) {
                return 0x80;
            }
        } else {
            // if(__hisinf(x) || __hisnan(x))
            if ((x & 0x7C00) == 0x7C00) {
                return 0x80;
            }
        }
    } else {
        if (sizeof(T) == 4) {
            if ((x & 0x7F800000) == 0x7F800000) {
                return signed_inf + (mantissa != 0 ? 1 : 0);
            }
        } else {
            if ((x & 0x7C00) == 0x7C00) {
                return signed_inf + (mantissa != 0 ? 1 : 0);
            }
        }
    }
    if (x == 0) {
        return 0;
    }

    // First we need to check whether the value is normal or denormal, since
    // they differ by the implicit 1. Then we adjust the exponent to align
    // with the F8 exponent, shifting the mantissa accordingly. For stochastic
    // rounding we add rng to the mantissa and truncate; for RNE nothing is
    // added. Finally we may need to detect a carry and adjust the exponent
    // and mantissa once more.

    // For IEEE bias mode, the bias is 2^(k-1) - 1 where k is the width of the
    // exponent bits
    const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
    const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
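    // Worked example: for e4m3 (we=4) the IEEE bias is 2^3 - 1 = 7; with
    // negative_zero_nan (NANOO) it becomes 8, so the smallest f8 normal is
    // 2^(1-8) = 2^-7, matching the denormal cutoff discussed below.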
    // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
    // f8_exponent is the converted f8 exponent with bias encoding
    // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent;
    // the difference needs to be adjusted and the mantissa shifted
    int act_exponent, f8_exponent, exponent_diff;

    if (exponent == 0) { // fp32/fp16 is in denormal.
        /* fp32 denormal is below 2^-127 so it is usually not a concern here; we
        mostly care about fp16 here. In this case, f8 is usually in denormal. But
        there could be exceptions. fp16 denormal has exponent bias 15 while bf8
        with NANOO has exponent bias 16. It means that there are some numbers in
        fp16 denormal that are bf8 (NANOO) normals - the smallest bf8 (NANOO)
        normal is 2^-15. fp16 numbers where exponent==0 (actual exponent -14) and
        the highest bit of the mantissa is 1 are bf8 (NANOO) normal. In this case,
        the fp16 mantissa should be shifted left by 1 */
        act_exponent = exponent - bias + 1;
        exponent_diff = f8_denormal_act_exponent - act_exponent; // actual exponent is exponent-bias+1 as it is denormal
    } else { // fp32/fp16 is normal with implicit 1
        act_exponent = exponent - bias;
        if (act_exponent <= f8_denormal_act_exponent) {
            /* This is the case where fp32/fp16 is normal but it is in the f8
            denormal range. For example in fp8 nanoo mode, the denormal exponent
            is -7, but if the fp32/fp16 actual exponent is -7, it is actually
            larger due to the implicit 1. Therefore it needs to be adjusted to -6
            and the mantissa shifted right by 1. So for fp32/fp16, exponent -8 is
            the cut point to convert to fp8 nanoo */
            exponent_diff = f8_denormal_act_exponent - act_exponent;
        } else { // both fp32/fp16 and f8 are in normal range
            exponent_diff = 0; // exponent_diff=0 does not mean there is no
                               // difference for this case; act_exponent could be
                               // larger. It just does not need a mantissa shift.
        }
        mantissa += (1 << mfmt); // Add the implicit 1 into mantissa
    }

    bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) ==
                    static_cast<uint32_t>(1 << (mfmt - wm + exponent_diff - 1));
    /* This part is a bit tricky. The judgment of whether it is a tie needs to
    be done before we shift right, as the shift could rip off some residual
    part and make something that is not a midpoint look like a midpoint. For
    example, the fp16 number 0x1002 (0 00100 0000000010) is larger than the
    midpoint, but after a right shift by 4 bits it would look like one.
    */

    if (exponent_diff > 0) {
        mantissa >>= exponent_diff;
    } else if (exponent_diff == -1) {
        mantissa <<= -exponent_diff;
    }
    bool implicit_one = mantissa & (1 << mfmt);
    // if there is no implicit 1, it means the f8 is denormal and we need to
    // adjust to the denorm exponent
    f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);

    // Now we have the exponent and mantissa adjusted
    uint32_t drop_mask = (1 << (mfmt - wm)) - 1;
    bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit
                                              // that is not truncated is 1
    mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;
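    // Rounding by addition: adding the dropped bits back to the mantissa
    // carries into the kept bits exactly when the residue is at least half an
    // ulp. At an exact tie, the (odd ? mantissa : mantissa - 1) selection adds
    // that half-ulp only when the kept LSB is 1, i.e. round-half-to-even.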

    // Now we deal with overflow
    if (f8_exponent == 0) {
        if ((1 << mfmt) & mantissa) {
            f8_exponent = 1; // denormal overflow to become normal, promote exponent
        }
    } else {
        if ((1 << (mfmt + 1)) & mantissa) {
            mantissa >>= 1;
            f8_exponent++;
        }
    }

    mantissa >>= (mfmt - wm);

    // above range: quantize to maximum possible float of the same sign
    const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2);
    if (f8_exponent > max_exp) {
        if (clip) {
            mantissa = (1 << wm) - 1;
            f8_exponent = max_exp;
        } else {
            return signed_inf;
        }
    }

    if (f8_exponent == 0 && mantissa == 0) {
        return negative_zero_nan ? 0 : (sign << 7);
    }
    mantissa &= (1 << wm) - 1;
    return (sign << 7) | (f8_exponent << wm) | mantissa;
}

template <int we, int wm, typename T = float, bool negative_zero_nan = true>
inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x)
{
#ifdef __HIPCC__
    constexpr bool is_half = std::is_same<T, _Float16>::value;
#else
    constexpr bool is_half = false;
#endif
    constexpr bool is_float = std::is_same<T, float>::value;
    static_assert(is_half || is_float, "only half and float are supported");

    constexpr int weo = is_half ? 5 : 8;
    constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);

    T fInf, fNegInf, fNaN, fNeg0;

#ifdef __HIPCC__
    if (is_half) {
        const uint16_t ihInf = 0x7C00;
        const uint16_t ihNegInf = 0xFC00;
        const uint16_t ihNaN = 0x7C01;
        const uint16_t ihNeg0 = 0x8000;
        fInf = reinterpret_cast<const _Float16&>(ihInf);
        fNegInf = reinterpret_cast<const _Float16&>(ihNegInf);
        fNaN = reinterpret_cast<const _Float16&>(ihNaN);
        fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0);
    } else
#endif
    if (is_float) {
        const uint32_t ifInf = 0x7F800000;
        const uint32_t ifNegInf = 0xFF800000;
        const uint32_t ifNaN = 0x7F800001;
        const uint32_t ifNeg0 = 0x80000000;
        fInf = reinterpret_cast<const float&>(ifInf);
        fNegInf = reinterpret_cast<const float&>(ifNegInf);
        fNaN = reinterpret_cast<const float&>(ifNaN);
        fNeg0 = reinterpret_cast<const float&>(ifNeg0);
    }

    if (x == 0) {
        return 0;
    }

    uint32_t sign = x >> 7;
    uint32_t mantissa = x & ((1 << wm) - 1);
    int exponent = (x & 0x7F) >> wm;
    if (negative_zero_nan) {
        if (x == 0x80) {
            return fNaN;
        }
    } else {
        if (x == 0x80) {
            return fNeg0;
        }
        if (exponent == ((1 << we) - 1)) {
            return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
        }
    }
    typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
    if (we == 5 && is_half && !negative_zero_nan) {
        retval = x << 8;
        return reinterpret_cast<const T&>(retval);
    }

    const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);

    // subnormal input
    if (exponent == 0) {
        // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
        int sh = 1 + clz(mantissa) - (32 - wm);
        mantissa <<= sh;
        exponent += 1 - sh;
        mantissa &= ((1 << wm) - 1);
    }
    exponent += exp_low_cutoff - 1;
    mantissa <<= wmo - wm;

    // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
    if (exponent <= 0) {
        mantissa |= 1 << wmo;
        mantissa >>= 1 - exponent;
        exponent = 0;
    }

    if (sizeof(T) == 2) {
        retval = (sign << 15) | (exponent << 10) | mantissa;
    } else {
        retval = (sign << 31) | (exponent << 23) | mantissa;
    }
    return reinterpret_cast<const T&>(retval);
}

} // namespace hip_fp8_impl
@@ -1,517 +0,0 @@
#pragma once
#include "hip_float8.h"

#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include <hip/hip_bfloat16.h>

#include "../../../attention/dtype_float32.cuh"
#include "../../../attention/dtype_bfloat16.cuh"

namespace vllm
{
namespace fp8_e4m3 {
template <typename Tout, typename Tin>
__inline__ __device__ Tout vec_conversion(const Tin& x)
{
    return x;
}

template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x, const float scale)
{
    return x;
}

// fp8 -> half
template <>
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a)
{
    hip_fp8 f8{a, hip_fp8::from_bits()};
    __half_raw res;
    res.data = static_cast<float>(f8);
    return res.x;
}

// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a)
{
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
    const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
    union {
        __half2_raw h2r;
        uint32_t ui32;
    } tmp;
    tmp.h2r.x.data = f2[0];
    tmp.h2r.y.data = f2[1];
    return tmp.ui32;
#else
    union {
        uint16_t u16[2];
        uint32_t u32;
    } tmp;

    tmp.u16[0] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a));
    tmp.u16[1] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U));
    return tmp.u32;
#endif
}

// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
{
    union {
        uint2 u32x2;
        uint32_t u32[2];
    } tmp;
    tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
    tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
    return tmp.u32x2;
}

// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a)
{
    union {
        uint4 u64x2;
        uint2 u64[2];
    } tmp;
    tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
    tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
    return tmp.u64x2;
}

using __nv_bfloat16 = __hip_bfloat16;

// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a)
{
    hip_fp8 f8{a, hip_fp8::from_bits()};
    float f{f8};
    return __float2bfloat16(f);
}

using __nv_bfloat162 = __hip_bfloat162;

// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a)
{
    __nv_bfloat162 res;
    res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
    res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
    return res;
}

// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a)
{
    bf16_4_t res;
    res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
    res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
    return res;
}

// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
{
    bf16_4_t tmp1, tmp2;
    tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
    tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
    bf16_8_t res;
    res.x = tmp1.x;
    res.y = tmp1.y;
    res.z = tmp2.x;
    res.w = tmp2.y;
    return res;
}

// fp8 -> float
template <>
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a)
{
    hip_fp8 fp8{a, hip_fp8::from_bits()};
    return static_cast<float>(fp8);
}

// fp8x2 -> float2
template <>
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
{
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
    float2 res;
    const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
    res.x = f2[0];
    res.y = f2[1];
    return res;
#else
    float2 res;
    res.x = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a));
    res.y = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U));
    return res;
#endif
}

// fp8x4 -> float4
template <>
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a)
{
    Float4_ res;
    res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
    res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
    return res;
}

// fp8x8 -> float8
template <>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
{
    Float4_ tmp1, tmp2;
    tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
    tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
    Float8_ res;
    res.x = tmp1.x;
    res.y = tmp1.y;
    res.z = tmp2.x;
    res.w = tmp2.y;
    return res;
}

// half -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a)
{
    __half_raw tmp;
    tmp.x = a;

    hip_fp8 f8{static_cast<float>(tmp.data)};
    return f8.data;
}

// bf16 -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a)
{
    hip_fp8 res{__bfloat162float(a)};
    return res.data;
}

// float -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a)
{
    hip_fp8 f8(a);
    return f8.data;
}

// fp8x4 -> float4
template <>
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
{
    Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
    float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
    return res;
}

// float2 -> half2
template <>
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
{
    union {
        half2 float16;
        uint32_t uint32;
    };

    float16 = __float22half2_rn(a);
    return uint32;
}

// Float4 -> half2x2
template <>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
{
    uint2 b;
    float2 val;
    val.x = a.x.x;
    val.y = a.x.y;
    b.x = vec_conversion<uint32_t, float2>(val);

    val.x = a.y.x;
    val.y = a.y.y;
    b.y = vec_conversion<uint32_t, float2>(val);
    return b;
}

// Float4 -> float4
template <>
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
{
    float4 b;
    b.x = a.x.x;
    b.y = a.x.y;
    b.z = a.y.x;
    b.w = a.y.y;
    return b;
}

// Float8 -> half2x4
template <>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
{
    uint4 b;
    b.x = vec_conversion<uint32_t, float2>(a.x);
    b.y = vec_conversion<uint32_t, float2>(a.y);
    b.z = vec_conversion<uint32_t, float2>(a.z);
    b.w = vec_conversion<uint32_t, float2>(a.w);
    return b;
}

// float2 -> bfloat162
template <>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2& a)
{
    __nv_bfloat162 b = __float22bfloat162_rn(a);
    return b;
}

// Float4 -> bfloat162x2
template <>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_& a)
{
    bf16_4_t b;
    b.x = __float22bfloat162_rn(a.x);
    b.y = __float22bfloat162_rn(a.y);
    return b;
}

// Float8 -> bfloat162x4
template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_& a)
{
    bf16_8_t b;
    b.x = __float22bfloat162_rn(a.x);
    b.y = __float22bfloat162_rn(a.y);
    b.z = __float22bfloat162_rn(a.z);
    b.w = __float22bfloat162_rn(a.w);
    return b;
}


/* Scaled and vectorized conversions, for data exchange between high and low precision domains

   Convention of the scale in API, e.g: FP8_data = Quantization( High_Precision_data / scale )
   s.t.
   Quantize(HP / scale) => FP8
   Dequant(FP8) * scale => HP
 */

// fp8 -> half
template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale)
{
    hip_fp8 f8{a, hip_fp8::from_bits()};
    __half_raw res;
    res.data = static_cast<float>(f8) * scale;
    return res.x;
}

// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, const float scale)
{
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
    const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
    union {
        __half2_raw h2r;
        uint32_t ui32;
    } tmp;
    tmp.h2r.x.data = f2[0] * scale;
    tmp.h2r.y.data = f2[1] * scale;
    return tmp.ui32;
#else
    union {
        uint16_t u16[2];
        uint32_t u32;
    } tmp;

    tmp.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
    tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U), scale);
    return tmp.u32;
#endif
}

// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale)
{
    union {
        uint2 u32x2;
        uint32_t u32[2];
    } tmp;
    tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
    tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
    return tmp.u32x2;
}

// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale)
{
    union {
        uint4 u64x2;
        uint2 u64[2];
    } tmp;
    tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
    tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
    return tmp.u64x2;
}

using __nv_bfloat16 = __hip_bfloat16;

// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16 scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, const float scale)
{
    hip_fp8 f8{a, hip_fp8::from_bits()};
    float f{f8};
    return __float2bfloat16(f * scale);
}

using __nv_bfloat162 = __hip_bfloat162;

// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162 scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, const float scale)
{
    __nv_bfloat162 res;
    res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
    res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
    return res;
}

// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, const float scale)
{
    bf16_4_t res;
    res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
    res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), scale);
    return res;
}

// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale)
{
    bf16_4_t tmp1, tmp2;
    tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
    tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
    bf16_8_t res;
    res.x = tmp1.x;
    res.y = tmp1.y;
    res.z = tmp2.x;
    res.w = tmp2.y;
    return res;
}

// fp8 -> float
template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(const uint8_t& a, const float scale)
{
    hip_fp8 fp8{a, hip_fp8::from_bits()};
    return static_cast<float>(fp8) * scale;
}

// fp8x2 -> float2
template <>
__inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale)
{
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
    float2 res;
    const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
    res.x = f2[0] * scale;
    res.y = f2[1] * scale;
    return res;
#else
    float2 res;
    res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale);
    res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U), scale);
    return res;
#endif
}

// fp8x4 -> float4
template <>
__inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale)
{
    Float4_ res;
    res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
    res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
    return res;
}

// fp8x8 -> float8
template <>
__inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale)
{
    Float4_ tmp1, tmp2;
    tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
    tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
    Float8_ res;
    res.x = tmp1.x;
    res.y = tmp1.y;
    res.z = tmp2.x;
    res.w = tmp2.y;
    return res;
}


/* Quantize(HP / scale) => FP8 */

// TODO(Hai): vectorized to add

// half -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale)
{
    __half_raw tmp;
    tmp.x = a;

    hip_fp8 f8{static_cast<float>(tmp.data)/scale};
    return f8.data;
}

// bf16 -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a, const float scale)
{
    hip_fp8 res{__bfloat162float(a)/scale};
    return res.data;
}

// float -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(const float& a, const float scale)
{
    hip_fp8 f8(a/scale);
    return f8.data;
}

// fp8x4 -> float4
template <>
__inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale)
{
    Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
    float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
    return res;
}

}
} // namespace vllm
csrc/quantization/fp8/common.cu (new file, 124 lines)
@@ -0,0 +1,124 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#include <cmath>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace vllm {

__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  float old;
  old = (value >= 0)
            ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
            : __uint_as_float(
                  atomicMin((unsigned int*)addr, __float_as_uint(value)));

  return old;
}
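// Why the int/uint trick above works: for non-negative floats the IEEE-754
// bit pattern increases monotonically under signed-int comparison, so
// atomicMax on the int view yields a float max; for negative values the
// ordering flips, so atomicMin on the unsigned view is used instead.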

#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()

template <typename scalar_t>
__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
    const scalar_t val, const float scale) {
  float x = static_cast<float>(val) / scale;
  float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
  return static_cast<c10::Float8_e4m3fn>(r);
}
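// Worked example of the clamp: with scale = 0.5f an input of 300.0f maps to
// 600.0f, which exceeds FP8_E4M3_MAX (448 for e4m3fn) and is therefore
// saturated to 448 before the fp8 cast.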

// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template <typename scalar_t>
__global__ void segmented_max_reduction(float* __restrict__ scale,
                                        const scalar_t* __restrict__ input,
                                        int64_t num_elems) {
  __shared__ float cache[1024];
  int i = blockDim.x * blockIdx.x + threadIdx.x;

  // First store the maximum of all values processed by
  // the current thread in cache[threadIdx.x]
  scalar_t tmp = 0.0;
  while (i < num_elems) {
    float x = static_cast<float>(input[i]);
    tmp = max(tmp, fabs(x));
    i += blockDim.x * gridDim.x;
  }
  cache[threadIdx.x] = tmp;

  __syncthreads();

  // Now perform parallel reduction within the thread block
  int ib = blockDim.x / 2;
  while (ib != 0) {
    if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
      cache[threadIdx.x] = cache[threadIdx.x + ib];
    }
    __syncthreads();
    ib /= 2;
  }
  // Finally, since cache[0] contains the maximum for this thread block,
  // atomically write the max to the target location
  if (threadIdx.x == 0) {
    atomicMaxFloat(scale,
                   cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
  }
}

template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
                                        const scalar_t* __restrict__ input,
                                        const float* __restrict__ scale,
                                        int64_t num_elems) {
  int i = blockDim.x * blockIdx.x + threadIdx.x;
  while (i < num_elems) {
    out[i] = scaled_fp8_conversion(input[i], *scale);
    i += blockDim.x * gridDim.x;
  }
}

} // namespace vllm

void static_scaled_fp8_quant(torch::Tensor& out,    // [..., d]
                             torch::Tensor& input,  // [..., d]
                             torch::Tensor& scale)  // [1]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int64_t num_elems = input.numel();
  dim3 grid(num_tokens);
  dim3 block(1024);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
        vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
            out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
            scale.data_ptr<float>(), num_elems);
      });
}

void dynamic_scaled_fp8_quant(torch::Tensor& out,    // [..., d]
                              torch::Tensor& input,  // [..., d]
                              torch::Tensor& scale)  // [1]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int64_t num_elems = input.numel();
  dim3 grid(num_tokens);
  dim3 block(1024);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
        vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
            scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
        vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
            out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
            scale.data_ptr<float>(), num_elems);
      });
}
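// Caller sketch (illustrative, from host C++ with these ops built; the exact
// dtype constant name is an assumption):
//
//   auto out = torch::empty_like(input, torch::kFloat8_e4m3fn);
//   auto scale = torch::zeros({1}, input.options().dtype(torch::kFloat32));
//   dynamic_scaled_fp8_quant(out, input, scale);
//
// scale starts at 0.0, satisfying the "initialized to a value <= 0.0"
// requirement of segmented_max_reduction; both kernels are launched on the
// same stream, so the reduction completes before the quantization reads
// *scale.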
@@ -1,126 +0,0 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

#include <cmath>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace vllm {

__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  float old;
  old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) :
                       __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));

  return old;
}

// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template<typename scalar_t>
__global__ void segmented_max_reduction(
  float* __restrict__ scale,
  const scalar_t* __restrict__ input,
  int64_t num_elems) {
  __shared__ float cache[1024];
  int i = blockDim.x * blockIdx.x + threadIdx.x;

  // First store maximum for all values processes by
  // the current thread in cache[threadIdx.x]
  scalar_t tmp = 0.0;
  while (i < num_elems) {
    float x = static_cast<float>(input[i]);
    tmp = max(tmp, fabs(x));
    i += blockDim.x * gridDim.x;
  }
  cache[threadIdx.x] = tmp;

  __syncthreads();

  // Now perform parallel reduction within the thread block
  int ib = blockDim.x / 2;
  while (ib != 0) {
    if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
      cache[threadIdx.x] = cache[threadIdx.x + ib];
    }
    __syncthreads();
    ib /= 2;
  }
  // Finally, since cache[0] contains the maximum for this thread block,
  // atomically write the max to the target location
  if (threadIdx.x == 0) {
    atomicMaxFloat(scale, cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
  }
}

template<typename scalar_t>
__global__ void scaled_fp8_quant_kernel(
  c10::Float8_e4m3fn* __restrict__ out,
  const scalar_t* __restrict__ input,
  const float* __restrict__ scale,
  int64_t num_elems) {
  int i = blockDim.x * blockIdx.x + threadIdx.x;
  while (i < num_elems) {
    out[i] = static_cast<c10::Float8_e4m3fn>(input[i] / *scale);
    i += blockDim.x * gridDim.x;
  }
}

} // namespace vllm

void static_scaled_fp8_quant(
  torch::Tensor& out,    // [..., d]
  torch::Tensor& input,  // [..., d]
  torch::Tensor& scale)  // [1]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int64_t num_elems = input.numel();
  dim3 grid(num_tokens);
  dim3 block(1024);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    input.scalar_type(),
    "scaled_fp8_quant_kernel",
    [&] {
      vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<c10::Float8_e4m3fn>(),
        input.data_ptr<scalar_t>(),
        scale.data_ptr<float>(),
        num_elems);
    });
}

void dynamic_scaled_fp8_quant(
  torch::Tensor& out,    // [..., d]
  torch::Tensor& input,  // [..., d]
  torch::Tensor& scale)  // [1]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int64_t num_elems = input.numel();
  dim3 grid(num_tokens);
  dim3 block(1024);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    input.scalar_type(),
    "scaled_fp8_quant_kernel",
    [&] {
      vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
        scale.data_ptr<float>(),
        input.data_ptr<scalar_t>(),
        num_elems);
      vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<c10::Float8_e4m3fn>(),
        input.data_ptr<scalar_t>(),
        scale.data_ptr<float>(),
        num_elems);
    });
}
csrc/quantization/fp8/nvidia/quant_utils.cuh (new file, 570 lines)
@@ -0,0 +1,570 @@
|
|||||||
|
#pragma once

#include "../../../attention/attention_dtypes.h"
#include <assert.h>
#include <float.h>
#include <stdint.h>
#include <type_traits>

namespace vllm {
#ifndef USE_ROCM

namespace fp8 {
#ifdef ENABLE_FP8

#if 0 // Disable the following code to reduce the binary size.
template <typename Tout, typename Tin>
__inline__ __device__ Tout
vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) {
  return x;
}

// fp8 -> half
template <>
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(
    const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
  __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  return res.x;
}

// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(
    const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint16_t u16[2];
    uint32_t u32;
  } tmp;
  __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
  tmp.u16[0] = res.x;
  tmp.u16[1] = res.y;
  return tmp.u32;
}

// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(
    const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint2 u32x2;
    uint32_t u32[2];
  } tmp;
  tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a, fp8_type);
  tmp.u32[1] =
      vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), fp8_type);
  return tmp.u32x2;
}

// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(
    const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint4 u64x2;
    uint2 u64[2];
  } tmp;
  tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x, fp8_type);
  tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y, fp8_type);
  return tmp.u64x2;
}

// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(
    const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
  // Note there is no direct convert function from fp8 to bf16.
  // fp8 -> half
  __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  // half -> float -> bf16
  float tmp = half_to_float(res.x);
  return __float2bfloat16(tmp);
}

// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(
    const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
  __nv_bfloat162 res;
  res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, fp8_type);
  res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), fp8_type);
  return res;
}

// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(
    const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
  bf16_4_t res;
  res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, fp8_type);
  res.y =
      vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), fp8_type);
  return res;
}

// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(
    const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
  bf16_4_t tmp1, tmp2;
  tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x, fp8_type);
  tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y, fp8_type);
  bf16_8_t res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

// fp8 -> float
template <>
__inline__ __device__ float
vec_conversion<float, uint8_t>(const uint8_t &a,
                               const __nv_fp8_interpretation_t fp8_type) {
  // fp8 -> half
  uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a, fp8_type);
  // half -> float
  return half_to_float(tmp);
}

// fp8x2 -> float2
template <>
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(
    const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
  // fp8x2 -> half2
  uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a, fp8_type);
  // half2 -> float2
  return half2_to_float2(tmp);
}

// fp8x4 -> float4
template <>
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(
    const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
  Float4_ res;
  res.x = vec_conversion<float2, uint16_t>((uint16_t)a, fp8_type);
  res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), fp8_type);
  return res;
}

// fp8x8 -> float8
template <>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(
    const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
  Float4_ tmp1, tmp2;
  tmp1 = vec_conversion<Float4_, uint32_t>(a.x, fp8_type);
  tmp2 = vec_conversion<Float4_, uint32_t>(a.y, fp8_type);
  Float8_ res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

// half -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(
    const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
  __half_raw tmp;
  tmp.x = a;
  __nv_fp8_storage_t res =
      __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
}

// bf16 -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(
    const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(
      __nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
#endif
}

// float -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(
    const float &a, const __nv_fp8_interpretation_t fp8_type) {
  __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
}

// fp8x4 -> float4
template <>
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(
    const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
  Float4_ tmp = vec_conversion<Float4_, uint32_t>(a, fp8_type);
  float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
  return res;
}

template <>
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(
    const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
  union {
    half2 float16;
    uint32_t uint32;
  };

  float16 = __float22half2_rn(a);
  return uint32;
}

template <>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(
    const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
  uint2 b;
  float2 val;
  val.x = a.x.x;
  val.y = a.x.y;
  b.x = vec_conversion<uint32_t, float2>(val, fp8_type);

  val.x = a.y.x;
  val.y = a.y.y;
  b.y = vec_conversion<uint32_t, float2>(val, fp8_type);

  return b;
}

template <>
__inline__ __device__ float4 vec_conversion<float4, Float4_>(
    const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
  float4 b;
  b.x = a.x.x;
  b.y = a.x.y;
  b.z = a.y.x;
  b.w = a.y.y;
  return b;
}

template <>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(
    const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
  uint4 b;
  b.x = vec_conversion<uint32_t, float2>(a.x, fp8_type);
  b.y = vec_conversion<uint32_t, float2>(a.y, fp8_type);
  b.z = vec_conversion<uint32_t, float2>(a.z, fp8_type);
  b.w = vec_conversion<uint32_t, float2>(a.w, fp8_type);
  return b;
}

template <>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(
    const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
  __nv_bfloat162 b;
  from_float(b, a);
  return b;
}

template <>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(
    const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
  bf16_4_t b;
  from_float(b, a);
  return b;
}

template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(
    const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
  bf16_8_t b;
  from_float(b, a);
  return b;
}
#endif

/* Scaled and vectorized conversions, for data exchange between high and low
   precision domains. Convention of the scale in this API, e.g.:
   FP8_data = Quantize(High_Precision_data / scale), s.t.
   Quantize(HP / scale) => FP8 and Dequant(FP8) * scale => HP.
 */
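
/* Worked example of the convention above (an editorial illustration, not part
   of the original file): suppose a tensor's absolute maximum is 896.0f and
   the E4M3 finite maximum is 448.0f, so scale = 896.0f / 448.0f = 2.0f. A
   value HP = 300.0f is then stored as Quantize(300.0f / 2.0f) = fp8(150.0f),
   and dequantization returns Dequant(fp8(150.0f)) * 2.0f ~= 300.0f, up to
   fp8 rounding error. */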

template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(
    const Tin& x, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  return x;
}

// fp8 -> half
template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(
    const uint8_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  __half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  return float_to_half(half_to_float(tmp.x) * scale);
}

// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
    const uint16_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint16_t u16[2];
    uint32_t u32;
  } tmp;
  __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
  tmp.u16[0] = float_to_half(half_to_float(res.x) * scale);
  tmp.u16[1] = float_to_half(half_to_float(res.y) * scale);
  return tmp.u32;
}

// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(
    const uint32_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint2 u32x2;
    uint32_t u32[2];
  } tmp;
  tmp.u32[0] =
      scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale, fp8_type);
  tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U),
                                                         scale, fp8_type);
  return tmp.u32x2;
}

// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4
scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale,
                                    const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint4 u64x2;
    uint2 u64[2];
  } tmp;
  tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale, fp8_type);
  tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale, fp8_type);
  return tmp.u64x2;
}

// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(
    const uint8_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  // Note there is no direct convert function from fp8 to bf16.
  // fp8 -> half
  __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  // half -> float -> bf16
  float tmp = half_to_float(res.x);
  return __float2bfloat16(tmp * scale);
}

// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>(
    const uint16_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  __nv_bfloat162 res;
  res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale,
                                                        fp8_type);
  res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U),
                                                        scale, fp8_type);
  return res;
}

// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
    const uint32_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  bf16_4_t res;
  res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale,
                                                          fp8_type);
  res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
                                                          scale, fp8_type);
  return res;
}

// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(
    const uint2& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  bf16_4_t tmp1, tmp2;
  tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, fp8_type);
  tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale, fp8_type);
  bf16_8_t res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

// fp8 -> float
template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
    const uint8_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  // fp8 -> half
  __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  uint16_t tmp = res.x;

  // half -> float
  return half_to_float(tmp) * scale;
}

// fp8x2 -> float2
template <>
__inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(
    const uint16_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  // fp8x2 -> half2
  uint32_t tmp = scaled_vec_conversion<uint32_t, uint16_t>(a, scale, fp8_type);
  // half2 -> float2
  return half2_to_float2(tmp);
}

// fp8x4 -> float4
template <>
__inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(
    const uint32_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  Float4_ res;
  res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, fp8_type);
  res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale,
                                                  fp8_type);
  return res;
}

// fp8x8 -> float8
template <>
__inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(
    const uint2& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  Float4_ tmp1, tmp2;
  tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, fp8_type);
  tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale, fp8_type);
  Float8_ res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

// half -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(
    const uint16_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  __nv_fp8_storage_t res =
      __nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
}

// bf16 -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
    const __nv_bfloat16& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale,
                                                 __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
#endif
}

// float -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(
    const float& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  __nv_fp8_storage_t res =
      __nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
}

// fp8x4 -> float4
template <>
__inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(
    const uint32_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale, fp8_type);
  float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
  return res;
}
#endif  // ENABLE_FP8

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout convert(const Tin& x) {
#if 0 // Disable the following code to reduce the binary size.
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
    return vec_conversion<Tout, Tin>(x, __NV_E4M3);
  } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
    return vec_conversion<Tout, Tin>(x, __NV_E5M2);
  }
#endif
  assert(false);
}

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
#ifdef ENABLE_FP8
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
    return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3);
  } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
    return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2);
  }
#endif
  assert(false);
}

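// Usage sketch (editorial, not part of the original file): inside an
// attention or cache kernel, one fp8 byte from the KV cache would be
// dequantized to the compute type roughly as
//   uint8_t q = k_cache[idx];
//   scalar_t k = fp8::scaled_convert<scalar_t, uint8_t, kv_dt>(q, kv_scale);
// where scalar_t, kv_dt, and kv_scale are supplied by the dispatch macro
// defined below; k_cache and idx are hypothetical names.
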
// The following macro is used to dispatch the conversion function based on
// the data type of the key and value cache. The FN is a macro that calls a
// function with template<typename scalar_t, typename cache_t,
// Fp8KVCacheDataType kv_dt>.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN)                  \
  if (KV_DTYPE == "auto") {                                                  \
    if (SRC_DTYPE == at::ScalarType::Float) {                                \
      FN(float, float, vllm::Fp8KVCacheDataType::kAuto);                     \
    } else if (SRC_DTYPE == at::ScalarType::Half) {                          \
      FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);               \
    } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                      \
      FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);     \
    } else {                                                                 \
      TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
    }                                                                        \
  } else {                                                                   \
    if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") {                       \
      if (SRC_DTYPE == at::ScalarType::Float) {                              \
        FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);              \
      } else if (SRC_DTYPE == at::ScalarType::Half) {                        \
        FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);           \
      } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                    \
        FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);      \
      } else {                                                               \
        TORCH_CHECK(false,                                                   \
                    "Unsupported input type of kv cache: ", SRC_DTYPE);      \
      }                                                                      \
    } else if (KV_DTYPE == "fp8_e5m2") {                                     \
      if (SRC_DTYPE == at::ScalarType::Float) {                              \
        FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);              \
      } else if (SRC_DTYPE == at::ScalarType::Half) {                        \
        FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);           \
      } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                    \
        FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);      \
      } else {                                                               \
        TORCH_CHECK(false,                                                   \
                    "Unsupported input type of kv cache: ", SRC_DTYPE);      \
      }                                                                      \
    } else {                                                                 \
      TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE);   \
    }                                                                        \
  }

}  // namespace fp8
#endif  // not USE_ROCM
}  // namespace vllm

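How the dispatch macro is consumed, as a hedged sketch: FN receives the (scalar_t, cache_t, kv_dt) triple and typically launches one kernel instantiation. The names below (CALL_KERNEL, my_cache_kernel, and the launch arguments) are hypothetical stand-ins, not code from this change:

// Hypothetical FN macro: one kernel instantiation per dispatched triple.
#define CALL_KERNEL(SCALAR_T, CACHE_T, KV_DT)                     \
  my_cache_kernel<SCALAR_T, CACHE_T, KV_DT>                       \
      <<<grid, block, 0, stream>>>(key_ptr, value_ptr, kv_scale);

// Expands into the dtype-check chain above plus exactly one FN call.
DISPATCH_BY_KV_CACHE_DTYPE(key.scalar_type(), kv_cache_dtype, CALL_KERNEL);
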
@@ -1,277 +0,0 @@
#pragma once

#include <assert.h>
#include <stdint.h>
#include <float.h>
#include <type_traits>
#include "../../attention/attention_dtypes.h"
#include "../../attention/dtype_float32.cuh"
#include "../../attention/dtype_float16.cuh"
#include "../../attention/dtype_bfloat16.cuh"

namespace vllm {
#ifdef ENABLE_FP8_E5M2
namespace fp8_e5m2_unscaled {

template<typename Tout, typename Tin>
__inline__ __device__ Tout vec_conversion(const Tin& x)
{
  return x;
}

// fp8 -> half
template<>
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a)
{
  __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
  return res.x;
}

// fp8x2 -> half2
template<>
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a)
{
  union {
    uint16_t u16[2];
    uint32_t u32;
  } tmp;
  __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, __NV_E5M2);
  tmp.u16[0] = res.x;
  tmp.u16[1] = res.y;
  return tmp.u32;
}

// fp8x4 -> half2x2
template<>
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
{
  union {
    uint2 u32x2;
    uint32_t u32[2];
  } tmp;
  tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
  tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
  return tmp.u32x2;
}

// fp8x8 -> half2x4
template<>
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a)
{
  union {
    uint4 u64x2;
    uint2 u64[2];
  } tmp;
  tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
  tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
  return tmp.u64x2;
}

// fp8 -> __nv_bfloat16
template<>
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a)
{
  // Note there is no direct convert function from fp8 to bf16.
  // fp8 -> half
  __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
  // half -> float -> bf16
  float tmp = half_to_float(res.x);
  return __float2bfloat16(tmp);
}

// fp8x2 -> __nv_bfloat162
template<>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a)
{
  __nv_bfloat162 res;
  res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
  res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
  return res;
}

// fp8x4 -> bf16_4_t
template<>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a)
{
  bf16_4_t res;
  res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
  res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
  return res;
}

// fp8x8 -> bf16_8_t
template<>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
{
  bf16_4_t tmp1, tmp2;
  tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
  tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
  bf16_8_t res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

// fp8 -> float
template<>
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a)
{
  // fp8 -> half
  uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a);
  // half -> float
  return half_to_float(tmp);
}

// fp8x2 -> float2
template<>
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
{
  // fp8x2 -> half2
  uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a);
  // half2 -> float2
  return half2_to_float2(tmp);
}

// fp8x4 -> float4
template<>
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a)
{
  Float4_ res;
  res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
  res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
  return res;
}

// fp8x8 -> float8
template<>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
{
  Float4_ tmp1, tmp2;
  tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
  tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
  Float8_ res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}


// half -> fp8
template<>
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a)
{
  __half_raw tmp;
  tmp.x = a;
  __nv_fp8_storage_t res = __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, __NV_E5M2);
  return (uint8_t)res;
}

// bf16 -> fp8
template<>
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(__nv_bfloat16_raw(a), __NV_SATFINITE, __NV_E5M2);
  return (uint8_t)res;
#endif
}

// float -> fp8
template<>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a)
{
  __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E5M2);
  return (uint8_t)res;
}

// fp8x4 -> float4
template<>
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
{
  Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
  float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
  return res;
}


template<>
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
{
  union {
    half2 float16;
    uint32_t uint32;
  };

  float16 = __float22half2_rn(a);
  return uint32;
}

template<>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
{
  uint2 b;
  float2 val;
  val.x = a.x.x;
  val.y = a.x.y;
  b.x = vec_conversion<uint32_t, float2>(val);

  val.x = a.y.x;
  val.y = a.y.y;
  b.y = vec_conversion<uint32_t, float2>(val);

  return b;
}

template<>
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
{
  float4 b;
  b.x = a.x.x;
  b.y = a.x.y;
  b.z = a.y.x;
  b.w = a.y.y;
  return b;
}

template<>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
{
  uint4 b;
  b.x = vec_conversion<uint32_t, float2>(a.x);
  b.y = vec_conversion<uint32_t, float2>(a.y);
  b.z = vec_conversion<uint32_t, float2>(a.z);
  b.w = vec_conversion<uint32_t, float2>(a.w);
  return b;
}

template<>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) {
  __nv_bfloat162 b;
  from_float(b, a);
  return b;
}

template<>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_ &a) {
  bf16_4_t b;
  from_float(b, a);
  return b;
}

template<>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_ &a) {
  bf16_8_t b;
  from_float(b, a);
  return b;
}

} // namespace fp8_e5m2_unscaled
#endif // ENABLE_FP8_E5M2
} // namespace vllm

@@ -9,40 +9,36 @@ namespace vllm {
 namespace gptq {
 // atomicAdd for half types, to support CC < 7.x
 
-__device__ __forceinline__ void atomicAdd_half(half* address, half val)
-{
-  unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
+__device__ __forceinline__ void atomicAdd_half(half* address, half val) {
+  unsigned int* address_as_ui =
+      (unsigned int*)((char*)address - ((size_t)address & 2));
   unsigned int old = *address_as_ui;
   unsigned int assumed;
 
-  do
-  {
+  do {
     assumed = old;
     __half_raw hsum;
     hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
     half tmpres = __hadd(hsum, val);
     hsum = __half_raw(tmpres);
-    old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
+    old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16)
+                              : (old & 0xffff0000) | hsum.x;
     old = atomicCAS(address_as_ui, assumed, old);
-  }
-  while (assumed != old);
+  } while (assumed != old);
 }
 
 // atomicAdd for half2 types
 
-__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
-{
+__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) {
   unsigned int* address_as_ui = (unsigned int*)address;
   unsigned int old = *address_as_ui;
   unsigned int assumed;
-  do
-  {
+  do {
     assumed = old;
     half2 old_val = *((half2*)&old);
     half2 new_val = __hadd2(old_val, val);
     old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
-  }
-  while (assumed != old);
+  } while (assumed != old);
 }
 
 //
@@ -50,10 +46,14 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
 #if defined(__CUDA_ARCH__) || defined(USE_ROCM)
 #if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
 
-__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
+__device__ __forceinline__ void atomicAdd(half* address, half val) {
+  atomicAdd_half(address, val);
+}
 
 #if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
-__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
+__device__ __forceinline__ void atomicAdd(half2* address, half2 val) {
+  atomicAdd_half2(address, val);
+}
 #endif
 
 #endif

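The hunks above only reformat the pre-Volta half atomics, but the underlying trick deserves a note: a 16-bit add is emulated by a compare-and-swap loop on the aligned 32-bit word that contains the half. Below is a minimal host-side analogue using std::atomic, with uint16_t arithmetic standing in for __hadd (an editorial sketch, not part of the diff):

#include <atomic>
#include <cstdint>
#include <cstdio>

// Add val into the low or high 16-bit lane of a 32-bit word, lock-free.
void add_u16_in_word(std::atomic<uint32_t>& word, bool high_half,
                     uint16_t val) {
  uint32_t expected = word.load();
  uint32_t desired;
  do {
    uint16_t cur = high_half ? (uint16_t)(expected >> 16)
                             : (uint16_t)(expected & 0xffffu);
    uint16_t sum = (uint16_t)(cur + val);  // device code uses __hadd here
    desired = high_half ? ((expected & 0x0000ffffu) | ((uint32_t)sum << 16))
                        : ((expected & 0xffff0000u) | sum);
  } while (!word.compare_exchange_weak(expected, desired));
}

int main() {
  std::atomic<uint32_t> w{0};
  add_u16_in_word(w, false, 3);  // low lane
  add_u16_in_word(w, true, 7);   // high lane
  std::printf("0x%08x\n", (unsigned)w.load());  // prints 0x00070003
}
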
@@ -1,5 +1,6 @@
 /*
-Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama
+Adapted from https://github.com/turboderp/exllamav2 and
+https://github.com/turboderp/exllama
 */
 
 #ifndef _matrix_view_cuh
@@ -13,24 +14,31 @@ Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turbo
 namespace vllm {
 namespace gptq {
 
-class MatrixView_half
-{
+class MatrixView_half {
 public:
   const half* data;
   const int height;
   const int width;
 
-  __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
-      : data(data), height(height), width(width)
-  { }
+  __device__ __forceinline__ MatrixView_half(const half* data, const int height,
+                                             const int width)
+      : data(data), height(height), width(width) {}
 
-  __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
-  __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
-  __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
-  __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
+  __device__ __forceinline__ half item(int row, int column) const {
+    return data[row * width + column];
+  }
+  __device__ __forceinline__ half2 item_half2(int row, int column) const {
+    return ((half2*)data)[(row * width + column) / 2];
+  }
+  __device__ __forceinline__ half2 item_half2half2(int row, int column) const {
+    return __half2half2(data[row * width + column]);
+  }
+  __device__ __forceinline__ const half* item_ptr(int row, int column) const {
+    return &data[row * width + column];
+  }
 
-  __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const
-  {
+  __device__ __forceinline__ void item4(half (&items)[4], int row,
+                                        int column) const {
     half2* ptr = (half2*)item_ptr(row, column);
     half2 i01 = ptr[0];
     half2 i23 = ptr[1];
@@ -39,8 +47,8 @@ public:
     items[2] = __low2half(i23);
     items[3] = __high2half(i23);
   }
-  __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const
-  {
+  __device__ __forceinline__ void item4_f(float (&items)[4], int row,
+                                          int column) const {
     half2* ptr = (half2*)item_ptr(row, column);
     half2 i01 = ptr[0];
     half2 i23 = ptr[1];
@@ -50,8 +58,8 @@ public:
     items[3] = __half2float(__high2half(i23));
   }
 
-  __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const
-  {
+  __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row,
+                                           int column) const {
     half2* ptr = (half2*)item_ptr(row, column);
     half2 i01 = ptr[0];
     half2 i23 = ptr[1];
@@ -62,25 +70,34 @@ public:
   }
 };
 
-class MatrixView_half_rw
-{
+class MatrixView_half_rw {
 public:
   half* data;
   const int height;
   const int width;
 
-  __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
-      : data(data), height(height), width(width)
-  { }
+  __device__ __forceinline__ MatrixView_half_rw(half* data, const int height,
+                                                const int width)
+      : data(data), height(height), width(width) {}
 
-  __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
-  __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
-  __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
-  __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
-  __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
+  __device__ __forceinline__ half item(int row, int column) const {
+    return data[row * width + column];
+  }
+  __device__ __forceinline__ half2 item_half2(int row, int column) const {
+    return ((half2*)data)[(row * width + column) / 2];
+  }
+  __device__ __forceinline__ half* item_ptr(int row, int column) {
+    return &data[row * width + column];
+  }
+  __device__ __forceinline__ void set(int row, int column, half value) {
+    data[row * width + column] = value;
+  }
+  __device__ __forceinline__ void set_half2(int row, int column, half2 value) {
+    ((half2*)data)[(row * width + column) / 2] = value;
+  }
 
-  __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3)
-  {
+  __device__ __forceinline__ void set4(int row, int column, half v0, half v1,
+                                       half v2, half v3) {
     half2 v01 = __halves2half2(v0, v1);
     half2 v23 = __halves2half2(v2, v3);
     half2* ptr = (half2*)item_ptr(row, column);
@@ -89,33 +106,32 @@ public:
   }
 };
 
-class MatrixView_q4_row
-{
+class MatrixView_q4_row {
 public:
   const uint32_t* data;
   const int height;
   const int width;
 
-  __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
-      : data(data), height(height), width(width)
-  { }
+  __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data,
+                                               const int height,
+                                               const int width)
+      : data(data), height(height), width(width) {}
 
-  __device__ __forceinline__ int item(int row, int column) const
-  {
+  __device__ __forceinline__ int item(int row, int column) const {
     int shift = (column & 0x07) * 4;
     return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
   }
 
-  __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
-  {
+  __device__ __forceinline__ void item2(int (&items)[2], int row,
+                                        int column) const {
     int shift = (column & 0x07) * 4;
     uint32_t d = data[row * width / 8 + column / 8] >> shift;
     items[0] = d & 0x0f;
     items[1] = (d >> 4) & 0x0f;
   }
 
-  __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
-  {
+  __device__ __forceinline__ void item4(int (&items)[4], int row,
+                                        int column) const {
     int shift = (column & 0x07) * 4;
     uint32_t d = data[row * width / 8 + column / 8] >> shift;
     items[0] = d & 0x0f;
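The q4 views above all decode the same layout: eight 4-bit weights packed little-end-first into one uint32_t, so item i lives at bit offset 4 * i. A small host-side sketch of that packing (editorial, not part of the diff):

#include <cstdint>
#include <cstdio>

// Pack eight 4-bit values into one word, item i at bits [4*i, 4*i + 4).
uint32_t pack8_q4(const int (&items)[8]) {
  uint32_t w = 0;
  for (int i = 0; i < 8; ++i) w |= (uint32_t)(items[i] & 0x0f) << (i * 4);
  return w;
}

// Mirrors MatrixView_q4_row::item: shift the word, mask off one nibble.
int unpack_q4(uint32_t w, int i) { return (int)((w >> (i * 4)) & 0x0f); }

int main() {
  int items[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  uint32_t w = pack8_q4(items);
  std::printf("%d %d\n", unpack_q4(w, 0), unpack_q4(w, 7));  // prints 1 8
}
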
@@ -125,54 +141,57 @@ public:
   }
 };
 
-class MatrixView_q4_column
-{
+class MatrixView_q4_column {
 public:
   const uint32_t* data;
   const int height;
   const int width;
 
-  __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
-      : data(data), height(height), width(width)
-  { }
+  __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data,
+                                                  const int height,
+                                                  const int width)
+      : data(data), height(height), width(width) {}
 
-  __device__ __forceinline__ int item(int row, int column) const
-  {
+  __device__ __forceinline__ int item(int row, int column) const {
     int shift = (row & 0x07) * 4;
     return (data[row / 8 * width + column] >> shift) & 0x0f;
   }
 
-  __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
-  __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
+  __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) {
+    return data[row / 8 * width + column];
+  }
+  __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row,
+                                                             int column) {
+    return &data[row / 8 * width + column];
+  }
 };
 
-class MatrixView_q2_row
-{
+class MatrixView_q2_row {
 public:
   const uint32_t* data;
   const int height;
   const int width;
 
-  __device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, const int height, const int width)
-      : data(data), height(height), width(width)
-  { }
+  __device__ __forceinline__ MatrixView_q2_row(const uint32_t* data,
+                                               const int height,
+                                               const int width)
+      : data(data), height(height), width(width) {}
 
-  __device__ __forceinline__ int item(int row, int column) const
-  {
+  __device__ __forceinline__ int item(int row, int column) const {
     int shift = (column & 0x0f) * 2;
     return (data[row * width / 16 + column / 16] >> shift) & 0x03;
   }
 
-  __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
-  {
+  __device__ __forceinline__ void item2(int (&items)[2], int row,
+                                        int column) const {
     int shift = (column & 0x0f) * 2;
     uint32_t d = data[row * width / 16 + column / 16] >> shift;
     items[0] = d & 0x03;
     items[1] = (d >> 2) & 0x03;
   }
 
-  __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
-  {
+  __device__ __forceinline__ void item4(int (&items)[4], int row,
+                                        int column) const {
     int shift = (column & 0x0f) * 2;
     uint32_t d = data[row * width / 16 + column / 16] >> shift;
     items[0] = d & 0x03;
@@ -182,26 +201,27 @@ public:
   }
 };
 
-class MatrixView_q3_row
-{
+class MatrixView_q3_row {
 public:
   const uint32_t* data;
   const int height;
   const int width;
 
-  __device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, const int height, const int width)
-      : data(data), height(height), width(width)
-  { }
+  __device__ __forceinline__ MatrixView_q3_row(const uint32_t* data,
+                                               const int height,
+                                               const int width)
+      : data(data), height(height), width(width) {}
 
-  __device__ __forceinline__ int item(int row, int column) const
-  {
+  __device__ __forceinline__ int item(int row, int column) const {
     int z_w = column * 3 / 32;
     int z_mod = column & 0x1f;
 
     if (z_mod == 10) {
-      return (data[row * width * 3 / 32 + z_w] >> 30) | ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4);
+      return (data[row * width * 3 / 32 + z_w] >> 30) |
+             ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4);
     } else if (z_mod == 21) {
-      return (data[row * width * 3 / 32 + z_w] >> 31) | ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6);
+      return (data[row * width * 3 / 32 + z_w] >> 31) |
+             ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6);
     } else if (z_mod < 10) {
       return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07;
     } else if (z_mod < 21) {
@@ -211,18 +231,20 @@ public:
     }
   }
 
-  __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
-  {
+  __device__ __forceinline__ void item4(int (&items)[4], int row,
+                                        int column) const {
     int shift = (column & 0x1f);
     uint32_t d;
     if (shift <= 4) {
       d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3);
     } else if (shift == 8) {
-      d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8);
+      d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) |
+          ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8);
     } else if (shift <= 16) {
      d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32);
     } else if (shift == 20) {
-      d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4);
+      d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) |
+          ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4);
     } else {
       d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64);
     }
@@ -233,33 +255,32 @@ public:
   }
 };
 
-class MatrixView_q8_row
-{
+class MatrixView_q8_row {
 public:
   const uint32_t* data;
   const int height;
   const int width;
 
-  __device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, const int height, const int width)
-      : data(data), height(height), width(width)
-  { }
+  __device__ __forceinline__ MatrixView_q8_row(const uint32_t* data,
+                                               const int height,
+                                               const int width)
+      : data(data), height(height), width(width) {}
 
-  __device__ __forceinline__ int item(int row, int column) const
-  {
+  __device__ __forceinline__ int item(int row, int column) const {
     int shift = (column & 0x03) * 8;
     return (data[row * width / 4 + column / 4] >> shift) & 0xff;
  }
 
-  __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
-  {
+  __device__ __forceinline__ void item2(int (&items)[2], int row,
+                                        int column) const {
     int shift = (column & 0x03) * 8;
     uint32_t d = data[row * width / 4 + column / 4] >> shift;
     items[0] = d & 0xff;
     items[1] = (d >> 8) & 0xff;
   }
 
-  __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
-  {
+  __device__ __forceinline__ void item4(int (&items)[4], int row,
+                                        int column) const {
     int shift = (column & 0x03) * 2;
     uint32_t d = data[row * width / 4 + column / 4] >> shift;
     items[0] = d & 0xff;
File diff suppressed because it is too large
Some files were not shown because too many files have changed in this diff